From c9e2a0187549f6fa2661b943c13af9d63d44eee1 Mon Sep 17 00:00:00 2001 From: Alexey Nominas <60900649+Chae4ek@users.noreply.github.com> Date: Thu, 18 May 2023 22:38:54 +0600 Subject: [PATCH] Update GPT4ALL integration (#4567) # Update GPT4ALL integration GPT4ALL has completely changed its bindings. The new bindings use a somewhat odd implementation that doesn't fit well into base.py and will probably change again, so this is a temporary solution. Fixes #3839, #4628 --- .../models/llms/integrations/gpt4all.ipynb | 9 +- langchain/llms/gpt4all.py | 103 +++++++----------- 2 files changed, 42 insertions(+), 70 deletions(-) diff --git a/docs/modules/models/llms/integrations/gpt4all.ipynb b/docs/modules/models/llms/integrations/gpt4all.ipynb index e7d3468e..7f467ad3 100644 --- a/docs/modules/models/llms/integrations/gpt4all.ipynb +++ b/docs/modules/models/llms/integrations/gpt4all.ipynb @@ -27,7 +27,7 @@ } ], "source": [ - "%pip install pygpt4all > /dev/null" + "%pip install gpt4all > /dev/null" ] }, { @@ -64,7 +64,7 @@ "source": [ "### Specify Model\n", "\n", - "To run locally, download a compatible ggml-formatted model. For more info, visit https://github.com/nomic-ai/pygpt4all\n", + "To run locally, download a compatible ggml-formatted model. For more info, visit https://github.com/nomic-ai/gpt4all\n", "\n", "For full installation instructions go [here](https://gpt4all.io/index.html).\n", "\n", @@ -102,7 +102,7 @@ "\n", "# Path(local_path).parent.mkdir(parents=True, exist_ok=True)\n", "\n", - "# # Example model. Check https://github.com/nomic-ai/pygpt4all for the latest models.\n", + "# # Example model. Check https://github.com/nomic-ai/gpt4all for the latest models.\n", "# url = 'http://gpt4all.io/models/ggml-gpt4all-l13b-snoozy.bin'\n", "\n", "# # send a GET request to the URL to download the file. 
Stream since it's large\n", @@ -126,7 +126,8 @@ "callbacks = [StreamingStdOutCallbackHandler()]\n", "# Verbose is required to pass to the callback manager\n", "llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)\n", - "# If you want to use GPT4ALL_J model add the backend parameter\n", + "# If you want to use a custom model add the backend parameter\n", + "# Check https://docs.gpt4all.io/gpt4all_python.html for supported backends\n", "llm = GPT4All(model=local_path, backend='gptj', callbacks=callbacks, verbose=True)" ] }, diff --git a/langchain/llms/gpt4all.py b/langchain/llms/gpt4all.py index a46691fa..166e6853 100644 --- a/langchain/llms/gpt4all.py +++ b/langchain/llms/gpt4all.py @@ -12,7 +12,7 @@ from langchain.llms.utils import enforce_stop_tokens class GPT4All(LLM): r"""Wrapper around GPT4All language models. - To use, you should have the ``pygpt4all`` python package installed, the + To use, you should have the ``gpt4all`` python package installed, the pre-trained model file, and the model's config information. 
Example: @@ -28,7 +28,7 @@ class GPT4All(LLM): model: str """Path to the pre-trained GPT4All model file.""" - backend: str = Field("llama", alias="backend") + backend: Optional[str] = Field(None, alias="backend") n_ctx: int = Field(512, alias="n_ctx") """Token context window.""" @@ -88,6 +88,10 @@ class GPT4All(LLM): streaming: bool = False """Whether to stream the results or not.""" + context_erase: float = 0.5 + """Leave (n_ctx * context_erase) tokens + starting from beginning if the context has run out.""" + client: Any = None #: :meta private: class Config: @@ -95,86 +99,55 @@ class GPT4All(LLM): extra = Extra.forbid - def _llama_default_params(self) -> Dict[str, Any]: - """Get the identifying parameters.""" + @staticmethod + def _model_param_names() -> Set[str]: return { - "n_predict": self.n_predict, - "n_threads": self.n_threads, - "repeat_last_n": self.repeat_last_n, - "repeat_penalty": self.repeat_penalty, - "top_k": self.top_k, - "top_p": self.top_p, - "temp": self.temp, + "n_ctx", + "n_predict", + "top_k", + "top_p", + "temp", + "n_batch", + "repeat_penalty", + "repeat_last_n", + "context_erase", } - def _gptj_default_params(self) -> Dict[str, Any]: - """Get the identifying parameters.""" + def _default_params(self) -> Dict[str, Any]: return { + "n_ctx": self.n_ctx, "n_predict": self.n_predict, - "n_threads": self.n_threads, "top_k": self.top_k, "top_p": self.top_p, "temp": self.temp, + "n_batch": self.n_batch, + "repeat_penalty": self.repeat_penalty, + "repeat_last_n": self.repeat_last_n, + "context_erase": self.context_erase, } - @staticmethod - def _llama_param_names() -> Set[str]: - """Get the identifying parameters.""" - return { - "seed", - "n_ctx", - "n_parts", - "f16_kv", - "logits_all", - "vocab_only", - "use_mlock", - "embedding", - } - - @staticmethod - def _gptj_param_names() -> Set[str]: - """Get the identifying parameters.""" - return set() - - @staticmethod - def _model_param_names(backend: str) -> Set[str]: - if backend == "llama": - 
return GPT4All._llama_param_names() - else: - return GPT4All._gptj_param_names() - - def _default_params(self) -> Dict[str, Any]: - if self.backend == "llama": - return self._llama_default_params() - else: - return self._gptj_default_params() - @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that the python package exists in the environment.""" try: - backend = values["backend"] - if backend == "llama": - from pygpt4all import GPT4All as GPT4AllModel - elif backend == "gptj": - from pygpt4all import GPT4All_J as GPT4AllModel - else: - raise ValueError(f"Incorrect gpt4all backend {cls.backend}") - - model_kwargs = { - k: v - for k, v in values.items() - if k in GPT4All._model_param_names(backend) - } + from gpt4all import GPT4All as GPT4AllModel + + full_path = values["model"] + model_path, delimiter, model_name = full_path.rpartition("/") + model_path += delimiter + values["client"] = GPT4AllModel( - model_path=values["model"], - **model_kwargs, + model_name=model_name, + model_path=model_path or None, + model_type=values["backend"], + allow_download=False, ) + values["backend"] = values["client"].model.model_type except ImportError: raise ValueError( - "Could not import pygpt4all python package. " - "Please install it with `pip install pygpt4all`." + "Could not import gpt4all python package. " + "Please install it with `pip install gpt4all`." ) return values @@ -185,9 +158,7 @@ class GPT4All(LLM): "model": self.model, **self._default_params(), **{ - k: v - for k, v in self.__dict__.items() - if k in self._model_param_names(self.backend) + k: v for k, v in self.__dict__.items() if k in self._model_param_names() }, }