From 8121e04200a20a7a288d7bb200563806ccf74d7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Navarro=20Ar=C3=A1nguiz?= Date: Tue, 30 May 2023 19:31:30 -0400 Subject: [PATCH] added n_threads functionality for gpt4all (#5427) # Added support for modifying the number of threads in the GPT4All model I have added the capability to modify the number of threads used by the GPT4All model. This allows users to adjust the model's parallel processing capabilities based on their specific requirements. ## Changes Made - Updated the `validate_environment` method to set the number of threads for the GPT4All model using the `values["n_threads"]` parameter from the `GPT4All` class constructor. ## Context Useful in scenarios where users want to optimize the model's performance by leveraging multi-threading capabilities. Please note that the `n_threads` parameter was included in the `GPT4All` class constructor but was previously unused. This change ensures that the specified number of threads is utilized by the model . ## Dependencies There are no new dependencies introduced by this change. It only utilizes existing functionality provided by the GPT4All package. ## Testing Since this is a minor change testing is not required. --------- Co-authored-by: Dev 2049 --- langchain/llms/gpt4all.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/langchain/llms/gpt4all.py b/langchain/llms/gpt4all.py index 166e6853..e7472b68 100644 --- a/langchain/llms/gpt4all.py +++ b/langchain/llms/gpt4all.py @@ -131,24 +131,27 @@ class GPT4All(LLM): """Validate that the python package exists in the environment.""" try: from gpt4all import GPT4All as GPT4AllModel - - full_path = values["model"] - model_path, delimiter, model_name = full_path.rpartition("/") - model_path += delimiter - - values["client"] = GPT4AllModel( - model_name=model_name, - model_path=model_path or None, - model_type=values["backend"], - allow_download=False, - ) - values["backend"] = values["client"].model.model_type - except ImportError: - raise ValueError( + raise ImportError( "Could not import gpt4all python package. " "Please install it with `pip install gpt4all`." ) + + full_path = values["model"] + model_path, delimiter, model_name = full_path.rpartition("/") + model_path += delimiter + + values["client"] = GPT4AllModel( + model_name, + model_path=model_path or None, + model_type=values["backend"], + allow_download=False, + ) + if values["n_threads"] is not None: + # set n_threads + values["client"].model.set_thread_count(values["n_threads"]) + values["backend"] = values["client"].model.model_type + return values @property