From f39340ff6b29eddae3fdac71b80dcfded8cc9d07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Navarro=20Ar=C3=A1nguiz?= Date: Wed, 31 May 2023 16:32:31 -0400 Subject: [PATCH] Add allow_download as attribute for GPT4All (#5512) # Added support for download GPT4All model if does not exist I've include the class attribute `allow_download` to the GPT4All class. By default, `allow_download` is set to False. ## Changes Made - Added a new attribute `allow_download` to the GPT4All class. - Updated the `validate_environment` method to pass the `allow_download` parameter to the GPT4All model constructor. ## Context This change provides more control over model downloading in the GPT4All class. Previously, if the model file was not found in the cache directory `~/.cache/gpt4all/`, the package returned error "Failed to retrieve model (type=value_error)". Now, if `allow_download` is set as True then it will use GPT4All package to download it . With the addition of the `allow_download` attribute, users can now choose whether the wrapper is allowed to download the model or not. ## Dependencies There are no new dependencies introduced by this change. It only utilizes existing functionality provided by the GPT4All package. ## Testing Since this is a minor change to the existing behavior, the existing test suite for the GPT4All package should cover this scenario Co-authored-by: Vokturz --- langchain/llms/gpt4all.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/langchain/llms/gpt4all.py b/langchain/llms/gpt4all.py index e7472b68..8fc9e84f 100644 --- a/langchain/llms/gpt4all.py +++ b/langchain/llms/gpt4all.py @@ -92,6 +92,9 @@ class GPT4All(LLM): """Leave (n_ctx * context_erase) tokens starting from beginning if the context has run out.""" + allow_download: bool = False + """If model does not exist in ~/.cache/gpt4all/, download it.""" + client: Any = None #: :meta private: class Config: @@ -145,7 +148,7 @@ class GPT4All(LLM): model_name, model_path=model_path or None, model_type=values["backend"], - allow_download=False, + allow_download=values["allow_download"], ) if values["n_threads"] is not None: # set n_threads