From 04b74d0446bdb8fc1f9e544d2f164a59bbd0df0c Mon Sep 17 00:00:00 2001
From: PawelFaron <42373772+PawelFaron@users.noreply.github.com>
Date: Sun, 7 May 2023 00:14:09 +0200
Subject: [PATCH] Adjusted GPT4All llm to streaming API and added support for
 GPT4All_J (#4131)

Fix for these issues:
https://github.com/hwchase17/langchain/issues/4126
https://github.com/hwchase17/langchain/issues/3839#issuecomment-1534258559

---------

Co-authored-by: Pawel Faron
---
 .../models/llms/integrations/gpt4all.ipynb |  4 +-
 langchain/llms/gpt4all.py                  | 67 ++++++++++++++-----
 2 files changed, 54 insertions(+), 17 deletions(-)

diff --git a/docs/modules/models/llms/integrations/gpt4all.ipynb b/docs/modules/models/llms/integrations/gpt4all.ipynb
index c67670e3..e7d3468e 100644
--- a/docs/modules/models/llms/integrations/gpt4all.ipynb
+++ b/docs/modules/models/llms/integrations/gpt4all.ipynb
@@ -125,7 +125,9 @@
     "# Callbacks support token-wise streaming\n",
     "callbacks = [StreamingStdOutCallbackHandler()]\n",
     "# Verbose is required to pass to the callback manager\n",
-    "llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)"
+    "llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)\n",
+    "# If you want to use GPT4ALL_J model add the backend parameter\n",
+    "llm = GPT4All(model=local_path, backend='gptj', callbacks=callbacks, verbose=True)"
    ]
   },
   {
diff --git a/langchain/llms/gpt4all.py b/langchain/llms/gpt4all.py
index 845f9018..a46691fa 100644
--- a/langchain/llms/gpt4all.py
+++ b/langchain/llms/gpt4all.py
@@ -28,6 +28,8 @@ class GPT4All(LLM):
     model: str
     """Path to the pre-trained GPT4All model file."""
 
+    backend: str = Field("llama", alias="backend")
+
     n_ctx: int = Field(512, alias="n_ctx")
     """Token context window."""
 
@@ -93,14 +95,11 @@ class GPT4All(LLM):
 
         extra = Extra.forbid
 
-    @property
-    def _default_params(self) -> Dict[str, Any]:
+    def _llama_default_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
         return {
-            "seed": self.seed,
             "n_predict": self.n_predict,
             "n_threads": self.n_threads,
-            "n_batch": self.n_batch,
             "repeat_last_n": self.repeat_last_n,
             "repeat_penalty": self.repeat_penalty,
             "top_k": self.top_k,
@@ -108,6 +107,16 @@ class GPT4All(LLM):
             "temp": self.temp,
         }
 
+    def _gptj_default_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            "n_predict": self.n_predict,
+            "n_threads": self.n_threads,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "temp": self.temp,
+        }
+
     @staticmethod
     def _llama_param_names() -> Set[str]:
         """Get the identifying parameters."""
@@ -122,14 +131,41 @@ class GPT4All(LLM):
             "embedding",
         }
 
+    @staticmethod
+    def _gptj_param_names() -> Set[str]:
+        """Get the identifying parameters."""
+        return set()
+
+    @staticmethod
+    def _model_param_names(backend: str) -> Set[str]:
+        if backend == "llama":
+            return GPT4All._llama_param_names()
+        else:
+            return GPT4All._gptj_param_names()
+
+    def _default_params(self) -> Dict[str, Any]:
+        if self.backend == "llama":
+            return self._llama_default_params()
+        else:
+            return self._gptj_default_params()
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in the environment."""
         try:
-            from pygpt4all.models.gpt4all import GPT4All as GPT4AllModel
+            backend = values["backend"]
+            if backend == "llama":
+                from pygpt4all import GPT4All as GPT4AllModel
+            elif backend == "gptj":
+                from pygpt4all import GPT4All_J as GPT4AllModel
+            else:
+                raise ValueError(f"Incorrect gpt4all backend {backend}")
 
-            llama_keys = cls._llama_param_names()
-            model_kwargs = {k: v for k, v in values.items() if k in llama_keys}
+            model_kwargs = {
+                k: v
+                for k, v in values.items()
+                if k in GPT4All._model_param_names(backend)
+            }
             values["client"] = GPT4AllModel(
                 model_path=values["model"],
                 **model_kwargs,
@@ -147,11 +183,11 @@ class GPT4All(LLM):
         """Get the identifying parameters."""
         return {
             "model": self.model,
-            **self._default_params,
+            **self._default_params(),
             **{
                 k: v
                 for k, v in self.__dict__.items()
-                if k in GPT4All._llama_param_names()
+                if k in self._model_param_names(self.backend)
             },
         }
 
@@ -181,15 +217,14 @@ class GPT4All(LLM):
                 prompt = "Once upon a time, "
                 response = model(prompt, n_predict=55)
         """
+        text_callback = None
         if run_manager:
             text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
-            text = self.client.generate(
-                prompt,
-                new_text_callback=text_callback,
-                **self._default_params,
-            )
-            text = self.client.generate(prompt, **self._default_params)
+        text = ""
+        for token in self.client.generate(prompt, **self._default_params()):
+            if text_callback:
+                text_callback(token)
+            text += token
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
         return text
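
A note on the streaming change for reviewers: the patched `_call` no longer hands
pygpt4all a `new_text_callback`; it now iterates `client.generate(...)` as a token
generator, forwarding each token to the run manager before appending it to the
final response. Below is a minimal standalone sketch of that consumption pattern;
the names `consume_stream` and `fake_generate` are illustrative stand-ins, not
part of the patch or of pygpt4all.

    from typing import Callable, Iterable, Optional

    def consume_stream(
        tokens: Iterable[str],
        on_token: Optional[Callable[[str], None]] = None,
    ) -> str:
        # Mirrors the patched _call loop: forward every token to the
        # optional callback, then accumulate it into the full response.
        text = ""
        for token in tokens:
            if on_token:
                on_token(token)  # e.g. run_manager.on_llm_new_token
            text += token
        return text

    def fake_generate() -> Iterable[str]:
        # Hypothetical stand-in for self.client.generate(prompt, **params).
        yield from ["Once", " upon", " a", " time"]

    print(consume_stream(fake_generate(), on_token=print))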