diff --git a/langchain/llms/gpt4all.py b/langchain/llms/gpt4all.py index 166e6853..e7472b68 100644 --- a/langchain/llms/gpt4all.py +++ b/langchain/llms/gpt4all.py @@ -131,24 +131,27 @@ class GPT4All(LLM): """Validate that the python package exists in the environment.""" try: from gpt4all import GPT4All as GPT4AllModel - - full_path = values["model"] - model_path, delimiter, model_name = full_path.rpartition("/") - model_path += delimiter - - values["client"] = GPT4AllModel( - model_name=model_name, - model_path=model_path or None, - model_type=values["backend"], - allow_download=False, - ) - values["backend"] = values["client"].model.model_type - except ImportError: - raise ValueError( + raise ImportError( "Could not import gpt4all python package. " "Please install it with `pip install gpt4all`." ) + + full_path = values["model"] + model_path, delimiter, model_name = full_path.rpartition("/") + model_path += delimiter + + values["client"] = GPT4AllModel( + model_name, + model_path=model_path or None, + model_type=values["backend"], + allow_download=False, + ) + if values["n_threads"] is not None: + # set n_threads + values["client"].model.set_thread_count(values["n_threads"]) + values["backend"] = values["client"].model.model_type + return values @property