forked from Archives/langchain
added n_threads functionality for gpt4all (#5427)
# Added support for modifying the number of threads in the GPT4All model

I have added the capability to modify the number of threads used by the GPT4All model. This allows users to adjust the model's parallel processing based on their specific requirements.

## Changes Made

- Updated the `validate_environment` method to set the number of threads for the GPT4All model using the `values["n_threads"]` parameter from the `GPT4All` class constructor.

## Context

This is useful when users want to optimize the model's performance by leveraging multi-threading. Note that the `n_threads` parameter was already accepted by the `GPT4All` class constructor but was previously unused; this change ensures the specified number of threads is actually applied to the model.

## Dependencies

No new dependencies are introduced by this change. It only uses existing functionality provided by the GPT4All package.

## Testing

Since this is a minor change, additional testing is not required.

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
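The mechanics of the change can be sketched in isolation. The stub classes below stand in for `gpt4all.GPT4All` and its inner native model (which exposes `set_thread_count`); the stub names and the sample `values` dict are illustrative, not part of the PR.

```python
# Illustrative stub mirroring the updated validate_environment logic.
# _StubGPT4AllModel is a hypothetical stand-in for gpt4all.GPT4All.

class _StubNativeModel:
    """Stand-in for the native model object held by the gpt4all client."""

    def __init__(self):
        self.thread_count = None
        self.model_type = "llama"

    def set_thread_count(self, n):
        self.thread_count = n


class _StubGPT4AllModel:
    """Stand-in for gpt4all.GPT4All."""

    def __init__(self, model_name, model_path=None, model_type=None,
                 allow_download=False):
        self.model = _StubNativeModel()


def validate_environment(values):
    """Build the client, then apply n_threads if one was provided."""
    full_path = values["model"]
    model_path, delimiter, model_name = full_path.rpartition("/")
    model_path += delimiter
    values["client"] = _StubGPT4AllModel(
        model_name,
        model_path=model_path or None,
        model_type=values["backend"],
        allow_download=False,
    )
    if values["n_threads"] is not None:
        # Previously n_threads was accepted by the constructor but ignored.
        values["client"].model.set_thread_count(values["n_threads"])
    values["backend"] = values["client"].model.model_type
    return values


values = validate_environment(
    {"model": "/models/ggml-model.bin", "backend": "llama", "n_threads": 4}
)
print(values["client"].model.thread_count)  # 4
```

The key point is the ordering: the client must be constructed first, since `set_thread_count` lives on the already-instantiated inner model.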
This commit is contained in:
parent e31705b5ab
commit 8121e04200
```diff
@@ -131,24 +131,27 @@ class GPT4All(LLM):
         """Validate that the python package exists in the environment."""
         try:
             from gpt4all import GPT4All as GPT4AllModel
-
-            full_path = values["model"]
-            model_path, delimiter, model_name = full_path.rpartition("/")
-            model_path += delimiter
-
-            values["client"] = GPT4AllModel(
-                model_name=model_name,
-                model_path=model_path or None,
-                model_type=values["backend"],
-                allow_download=False,
-            )
-            values["backend"] = values["client"].model.model_type
-
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "Could not import gpt4all python package. "
                 "Please install it with `pip install gpt4all`."
             )
+
+        full_path = values["model"]
+        model_path, delimiter, model_name = full_path.rpartition("/")
+        model_path += delimiter
+
+        values["client"] = GPT4AllModel(
+            model_name,
+            model_path=model_path or None,
+            model_type=values["backend"],
+            allow_download=False,
+        )
+        if values["n_threads"] is not None:
+            # set n_threads
+            values["client"].model.set_thread_count(values["n_threads"])
+        values["backend"] = values["client"].model.model_type
+
         return values
 
     @property
```