diff --git a/libs/langchain/langchain/llms/gpt4all.py b/libs/langchain/langchain/llms/gpt4all.py index 67aa6d49a7..3f9a397ffb 100644 --- a/libs/langchain/langchain/llms/gpt4all.py +++ b/libs/langchain/langchain/llms/gpt4all.py @@ -89,6 +89,9 @@ class GPT4All(LLM): allow_download: bool = False """If model does not exist in ~/.cache/gpt4all/, download it.""" + device: Optional[str] = Field("cpu", alias="device") + """Device name: cpu, gpu, nvidia, intel, amd or DeviceName.""" + client: Any = None #: :meta private: class Config: @@ -141,6 +144,7 @@ class GPT4All(LLM): model_path=model_path or None, model_type=values["backend"], allow_download=values["allow_download"], + device=values["device"], ) if values["n_threads"] is not None: # set n_threads