Mirror of https://github.com/hwchase17/langchain, synced 2024-11-06 03:20:49 +00:00.
Fix GPT4All bug w/ "n_ctx" param (#7093)

Running `GPT4All` per the [docs](https://python.langchain.com/docs/modules/model_io/models/llms/integrations/gpt4all), I see:

```
from langchain.llms import GPT4All
model = GPT4All(model=local_path)
model("The capital of France is ", max_tokens=10)

TypeError: generate() got an unexpected keyword argument 'n_ctx'
```

It appears `n_ctx` is [no longer a supported param](https://docs.gpt4all.io/gpt4all_python.html#gpt4all.gpt4all.GPT4All.generate) in the GPT4All API as of https://github.com/nomic-ai/gpt4all/pull/1090. It now uses `max_tokens`, so I set that. I also set the other defaults used in the GPT4All client [here](https://github.com/nomic-ai/gpt4all/blob/main/gpt4all-bindings/python/gpt4all/gpt4all.py).

Confirm it now works:

```
from langchain.llms import GPT4All
model = GPT4All(model=local_path)
model("The capital of France is ", max_tokens=10)

< Model logging >
"....Paris."
```

---------

Co-authored-by: R. Lance Martin <rlm@Rs-MacBook-Pro.local>

Parent: 6631fd5168
Commit: 265c285057
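For context, the rename is observable directly on the installed `gpt4all` client. A minimal sketch (assuming the `gpt4all` Python package is installed; nothing here is LangChain-specific):

```python
# Inspect which keyword arguments the installed client's generate() accepts.
# On gpt4all versions after nomic-ai/gpt4all#1090, "max_tokens" is accepted
# and "n_ctx" is not -- passing n_ctx is what raised the TypeError above.
import inspect

from gpt4all import GPT4All as GPT4AllClient

accepted = set(inspect.signature(GPT4AllClient.generate).parameters)
print("max_tokens" in accepted)
print("n_ctx" in accepted)
```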
The diff touches two files: the GPT4All integration notebook and the `GPT4All` LLM class. First, the notebook:

```diff
@@ -32,7 +32,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 1,
+    "execution_count": 2,
     "metadata": {
      "tags": []
     },
@@ -45,7 +45,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 2,
+    "execution_count": 3,
     "metadata": {
      "tags": []
     },
@@ -64,13 +64,20 @@
    "source": [
     "### Specify Model\n",
     "\n",
-    "To run locally, download a compatible ggml-formatted model. For more info, visit https://github.com/nomic-ai/gpt4all\n",
+    "To run locally, download a compatible ggml-formatted model. \n",
+    " \n",
+    "**Download option 1**: The [gpt4all page](https://gpt4all.io/index.html) has a useful `Model Explorer` section:\n",
     "\n",
-    "For full installation instructions go [here](https://gpt4all.io/index.html).\n",
+    "* Select a model of interest\n",
+    "* Download using the UI and move the `.bin` to the `local_path` (noted below)\n",
     "\n",
-    "The GPT4All Chat installer needs to decompress a 3GB LLM model during the installation process!\n",
+    "For more info, visit https://github.com/nomic-ai/gpt4all.\n",
     "\n",
-    "Note that new models are uploaded regularly - check the link above for the most recent `.bin` URL"
+    "--- \n",
+    "\n",
+    "**Download option 2**: Uncomment the below block to download a model. \n",
+    "\n",
+    "* You may want to update `url` to a new version, which can be browsed using the [gpt4all page](https://gpt4all.io/index.html)."
    ]
   },
   {
@@ -81,22 +88,8 @@
    "source": [
     "local_path = (\n",
     "    \"./models/ggml-gpt4all-l13b-snoozy.bin\"  # replace with your desired local file path\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Uncomment the below block to download a model. You may want to update `url` to a new version."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
+    ")\n",
+    "\n",
     "# import requests\n",
     "\n",
     "# from pathlib import Path\n",
```
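The commented-out download cell those context lines come from can be fleshed out roughly as follows (a sketch only: the URL and file name are illustrative; check the [gpt4all page](https://gpt4all.io/index.html) for current `.bin` releases):

```python
# Stream a ggml model file to the local_path used in the notebook.
from pathlib import Path

import requests

local_path = Path("./models/ggml-gpt4all-l13b-snoozy.bin")  # illustrative
local_path.parent.mkdir(parents=True, exist_ok=True)

url = "http://gpt4all.io/models/ggml-gpt4all-l13b-snoozy.bin"  # example URL
with requests.get(url, stream=True) as resp:
    resp.raise_for_status()
    with open(local_path, "wb") as f:
        for chunk in resp.iter_content(chunk_size=8192):
            f.write(chunk)
```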
```diff
@@ -126,8 +119,10 @@
    "source": [
     "# Callbacks support token-wise streaming\n",
     "callbacks = [StreamingStdOutCallbackHandler()]\n",
+    "\n",
     "# Verbose is required to pass to the callback manager\n",
     "llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)\n",
+    "\n",
     "# If you want to use a custom model add the backend parameter\n",
     "# Check https://docs.gpt4all.io/gpt4all_python.html for supported backends\n",
     "llm = GPT4All(model=local_path, backend=\"gptj\", callbacks=callbacks, verbose=True)"
@@ -170,7 +165,7 @@
     "name": "python",
     "nbconvert_exporter": "python",
     "pygments_lexer": "ipython3",
-   "version": "3.11.2"
+   "version": "3.9.16"
    }
   },
   "nbformat": 4,
```
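Putting the notebook's pieces together, a minimal usage sketch (assuming a LangChain build from this era and a downloaded model at `local_path`; the prompt mirrors the one in the commit message):

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import GPT4All

local_path = "./models/ggml-gpt4all-l13b-snoozy.bin"  # your downloaded model

# Callbacks support token-wise streaming to stdout.
callbacks = [StreamingStdOutCallbackHandler()]
llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)

# max_tokens (not the removed n_ctx) caps the generation length.
print(llm("The capital of France is ", max_tokens=10))
```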
Second, the `GPT4All` LLM class:

```diff
@@ -19,7 +19,7 @@ class GPT4All(LLM):
         .. code-block:: python

             from langchain.llms import GPT4All
-            model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8)
+            model = GPT4All(model="./models/gpt4all-model.bin", n_threads=8)

             # Simplest invocation
             response = model("Once upon a time, ")
@@ -30,7 +30,7 @@ class GPT4All(LLM):

     backend: Optional[str] = Field(None, alias="backend")

-    n_ctx: int = Field(512, alias="n_ctx")
+    max_tokens: int = Field(200, alias="max_tokens")
     """Token context window."""

     n_parts: int = Field(-1, alias="n_parts")
@@ -61,10 +61,10 @@ class GPT4All(LLM):
     n_predict: Optional[int] = 256
     """The maximum number of tokens to generate."""

-    temp: Optional[float] = 0.8
+    temp: Optional[float] = 0.7
     """The temperature to use for sampling."""

-    top_p: Optional[float] = 0.95
+    top_p: Optional[float] = 0.1
     """The top-p value to use for sampling."""

     top_k: Optional[int] = 40
```
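The new defaults track the upstream client. A hedged sketch showing them passed explicitly (the values come from the diff above; the model path is a placeholder):

```python
from langchain.llms import GPT4All

# Explicitly pin the sampling defaults this commit adopts from the
# upstream gpt4all client.
llm = GPT4All(
    model="./models/ggml-gpt4all-l13b-snoozy.bin",  # placeholder path
    max_tokens=200,       # replaces the removed n_ctx=512
    temp=0.7,             # was 0.8
    top_p=0.1,            # was 0.95
    repeat_penalty=1.18,  # was 1.3
    n_batch=8,            # was 1
)
```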
```diff
@@ -79,19 +79,15 @@ class GPT4All(LLM):
     repeat_last_n: Optional[int] = 64
     "Last n tokens to penalize"

-    repeat_penalty: Optional[float] = 1.3
+    repeat_penalty: Optional[float] = 1.18
     """The penalty to apply to repeated tokens."""

-    n_batch: int = Field(1, alias="n_batch")
+    n_batch: int = Field(8, alias="n_batch")
     """Batch size for prompt processing."""

     streaming: bool = False
     """Whether to stream the results or not."""

-    context_erase: float = 0.5
-    """Leave (n_ctx * context_erase) tokens
-    starting from beginning if the context has run out."""
-
     allow_download: bool = False
     """If model does not exist in ~/.cache/gpt4all/, download it."""

```
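Per the `allow_download` docstring, the client looks for the model under `~/.cache/gpt4all/`. A small illustrative check (the file name is hypothetical):

```python
from pathlib import Path

model_name = "ggml-gpt4all-l13b-snoozy.bin"  # hypothetical file name
cache_path = Path.home() / ".cache" / "gpt4all" / model_name

# With allow_download=False, a missing cached model must be supplied locally;
# with allow_download=True, the client may fetch it into this directory.
print("cached" if cache_path.exists() else "not cached", cache_path)
```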
```diff
@@ -105,7 +101,7 @@ class GPT4All(LLM):
     @staticmethod
     def _model_param_names() -> Set[str]:
         return {
-            "n_ctx",
+            "max_tokens",
             "n_predict",
             "top_k",
             "top_p",
@@ -113,12 +109,11 @@ class GPT4All(LLM):
             "n_batch",
             "repeat_penalty",
             "repeat_last_n",
-            "context_erase",
         }

     def _default_params(self) -> Dict[str, Any]:
         return {
-            "n_ctx": self.n_ctx,
+            "max_tokens": self.max_tokens,
             "n_predict": self.n_predict,
             "top_k": self.top_k,
             "top_p": self.top_p,
@@ -126,7 +121,6 @@ class GPT4All(LLM):
             "n_batch": self.n_batch,
             "repeat_penalty": self.repeat_penalty,
             "repeat_last_n": self.repeat_last_n,
-            "context_erase": self.context_erase,
         }

     @root_validator()
```
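`_model_param_names()` and `_default_params()` are what keep unsupported keys away from the client. A hedged sketch of the filtering pattern they enable (illustrative helper, not the library's exact code):

```python
from typing import Any, Dict, Set

def model_param_names() -> Set[str]:
    # Mirrors _model_param_names() after this commit: n_ctx and
    # context_erase are gone, max_tokens is in.
    return {"max_tokens", "n_predict", "top_k", "top_p",
            "temp", "n_batch", "repeat_penalty", "repeat_last_n"}

def filter_params(params: Dict[str, Any]) -> Dict[str, Any]:
    # Dropping unknown keys is what avoids
    # "TypeError: generate() got an unexpected keyword argument 'n_ctx'".
    return {k: v for k, v in params.items() if k in model_param_names()}

print(filter_params({"max_tokens": 10, "n_ctx": 512}))  # {'max_tokens': 10}
```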