Fix GPT4All bug w/ "n_ctx" param (#7093)

Running `GPT4All` per the
[docs](https://python.langchain.com/docs/modules/model_io/models/llms/integrations/gpt4all),
I see:

```
>>> from langchain.llms import GPT4All
>>> model = GPT4All(model=local_path)
>>> model("The capital of France is ", max_tokens=10)
TypeError: generate() got an unexpected keyword argument 'n_ctx'
```

It appears `n_ctx` is [no longer a supported
param](https://docs.gpt4all.io/gpt4all_python.html#gpt4all.gpt4all.GPT4All.generate)
in the GPT4All API as of https://github.com/nomic-ai/gpt4all/pull/1090.

The API now takes `max_tokens` instead, so I switched the wrapper to that parameter.
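
For reference, a minimal sketch of the underlying client call this maps onto (the model filename below is only a placeholder; `generate` is documented at the link above):

```
from gpt4all import GPT4All as GPT4AllClient

client = GPT4AllClient("ggml-gpt4all-l13b-snoozy.bin")  # placeholder model name
# generate() accepts max_tokens; passing the old n_ctx keyword raises the TypeError above
print(client.generate("The capital of France is ", max_tokens=10))
```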

I also aligned the wrapper's other defaults (`temp`, `top_p`, `repeat_penalty`, `n_batch`) with the defaults used in the GPT4All client
[here](https://github.com/nomic-ai/gpt4all/blob/main/gpt4all-bindings/python/gpt4all/gpt4all.py).
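
For illustration, a minimal sketch of those defaults spelled out through the wrapper (values taken from the diff below; in practice they can be left implicit or overridden per call):

```
from langchain.llms import GPT4All

local_path = "./models/ggml-gpt4all-l13b-snoozy.bin"  # any local ggml .bin works

# New defaults written out explicitly; normally you can omit them
llm = GPT4All(
    model=local_path,
    max_tokens=200,       # replaces the removed n_ctx field
    temp=0.7,
    top_p=0.1,
    repeat_penalty=1.18,
    n_batch=8,
)
```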

Confirmed it now works:
```
>>> from langchain.llms import GPT4All
>>> model = GPT4All(model=local_path)
>>> model("The capital of France is ", max_tokens=10)
< Model logging >
"....Paris."
```

---------

Co-authored-by: R. Lance Martin <rlm@Rs-MacBook-Pro.local>
2 changed files with 26 additions and 37 deletions

Changes to the GPT4All integration notebook (the docs page linked above):

```
@@ -32,7 +32,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 1,
+ "execution_count": 2,
 "metadata": {
 "tags": []
 },
@@ -45,7 +45,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
 "metadata": {
 "tags": []
 },
@@ -64,13 +64,20 @@
 "source": [
 "### Specify Model\n",
 "\n",
- "To run locally, download a compatible ggml-formatted model. For more info, visit https://github.com/nomic-ai/gpt4all\n",
+ "To run locally, download a compatible ggml-formatted model. \n",
+ " \n",
+ "**Download option 1**: The [gpt4all page](https://gpt4all.io/index.html) has a useful `Model Explorer` section:\n",
 "\n",
- "For full installation instructions go [here](https://gpt4all.io/index.html).\n",
+ "* Select a model of interest\n",
+ "* Download using the UI and move the `.bin` to the `local_path` (noted below)\n",
 "\n",
- "The GPT4All Chat installer needs to decompress a 3GB LLM model during the installation process!\n",
+ "For more info, visit https://github.com/nomic-ai/gpt4all.\n",
 "\n",
- "Note that new models are uploaded regularly - check the link above for the most recent `.bin` URL"
+ "--- \n",
+ "\n",
+ "**Download option 2**: Uncomment the below block to download a model. \n",
+ "\n",
+ "* You may want to update `url` to a new version, whih can be browsed using the [gpt4all page](https://gpt4all.io/index.html)."
 ]
 },
 {
@@ -81,22 +88,8 @@
 "source": [
 "local_path = (\n",
 " \"./models/ggml-gpt4all-l13b-snoozy.bin\" # replace with your desired local file path\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Uncomment the below block to download a model. You may want to update `url` to a new version."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
+ ")\n",
+ "\n",
 "# import requests\n",
 "\n",
 "# from pathlib import Path\n",
@@ -126,8 +119,10 @@
 "source": [
 "# Callbacks support token-wise streaming\n",
 "callbacks = [StreamingStdOutCallbackHandler()]\n",
+ "\n",
 "# Verbose is required to pass to the callback manager\n",
 "llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)\n",
+ "\n",
 "# If you want to use a custom model add the backend parameter\n",
 "# Check https://docs.gpt4all.io/gpt4all_python.html for supported backends\n",
 "llm = GPT4All(model=local_path, backend=\"gptj\", callbacks=callbacks, verbose=True)"
@@ -170,7 +165,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
- "version": "3.11.2"
+ "version": "3.9.16"
 }
 },
 "nbformat": 4,
```

Changes to the `GPT4All` LLM wrapper (`class GPT4All(LLM)`):

```
@@ -19,7 +19,7 @@ class GPT4All(LLM):
         .. code-block:: python
             from langchain.llms import GPT4All
-            model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8)
+            model = GPT4All(model="./models/gpt4all-model.bin", n_threads=8)
             # Simplest invocation
             response = model("Once upon a time, ")
@@ -30,7 +30,7 @@ class GPT4All(LLM):
     backend: Optional[str] = Field(None, alias="backend")
-    n_ctx: int = Field(512, alias="n_ctx")
+    max_tokens: int = Field(200, alias="max_tokens")
     """Token context window."""
     n_parts: int = Field(-1, alias="n_parts")
@@ -61,10 +61,10 @@ class GPT4All(LLM):
     n_predict: Optional[int] = 256
     """The maximum number of tokens to generate."""
-    temp: Optional[float] = 0.8
+    temp: Optional[float] = 0.7
     """The temperature to use for sampling."""
-    top_p: Optional[float] = 0.95
+    top_p: Optional[float] = 0.1
     """The top-p value to use for sampling."""
     top_k: Optional[int] = 40
@@ -79,19 +79,15 @@ class GPT4All(LLM):
     repeat_last_n: Optional[int] = 64
     "Last n tokens to penalize"
-    repeat_penalty: Optional[float] = 1.3
+    repeat_penalty: Optional[float] = 1.18
     """The penalty to apply to repeated tokens."""
-    n_batch: int = Field(1, alias="n_batch")
+    n_batch: int = Field(8, alias="n_batch")
     """Batch size for prompt processing."""
     streaming: bool = False
     """Whether to stream the results or not."""
-    context_erase: float = 0.5
-    """Leave (n_ctx * context_erase) tokens
-    starting from beginning if the context has run out."""
     allow_download: bool = False
     """If model does not exist in ~/.cache/gpt4all/, download it."""
@@ -105,7 +101,7 @@ class GPT4All(LLM):
     @staticmethod
     def _model_param_names() -> Set[str]:
         return {
-            "n_ctx",
+            "max_tokens",
             "n_predict",
             "top_k",
             "top_p",
@@ -113,12 +109,11 @@ class GPT4All(LLM):
             "n_batch",
             "repeat_penalty",
             "repeat_last_n",
-            "context_erase",
         }
     def _default_params(self) -> Dict[str, Any]:
         return {
-            "n_ctx": self.n_ctx,
+            "max_tokens": self.max_tokens,
             "n_predict": self.n_predict,
             "top_k": self.top_k,
             "top_p": self.top_p,
@@ -126,7 +121,6 @@ class GPT4All(LLM):
             "n_batch": self.n_batch,
             "repeat_penalty": self.repeat_penalty,
             "repeat_last_n": self.repeat_last_n,
-            "context_erase": self.context_erase,
         }
     @root_validator()
```