feat: Add with_history option for chatglm (#8048)

In certain 0-shot scenarios, the existing stateful language model can
unintentionally send/accumulate the .history.

This commit adds the "with_history" option to chatglm, allowing users to
control the behavior of .history and prevent unintended accumulation.

Possible reviewers @hwchase17 @baskaryan @mlot

Refer to discussion over this thread:
https://twitter.com/wey_gu/status/1681996149543276545?s=20
This commit is contained in:
Wey Gu 2023-07-21 13:25:37 +08:00 committed by GitHub
parent 1f3b987860
commit cf60cff1ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 3 deletions

View File

@ -95,6 +95,22 @@
"\n",
"llm_chain.run(question)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By Default, ChatGLM is statful to keep track of the conversation history and send the accumulated context to the model. To enable stateless mode, we could set ChatGLM.with_history as `False` explicitly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm.with_history = False"
]
}
],
"metadata": {

View File

@ -1,3 +1,4 @@
import logging
from typing import Any, List, Mapping, Optional
import requests
@ -6,6 +7,8 @@ from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
logger = logging.getLogger(__name__)
class ChatGLM(LLM):
"""ChatGLM LLM service.
@ -34,6 +37,8 @@ class ChatGLM(LLM):
"""History of the conversation"""
top_p: float = 0.7
"""Top P for nucleus sampling from 0 to 1"""
with_history: bool = True
"""Whether to use history or not"""
@property
def _llm_type(self) -> str:
@ -85,7 +90,7 @@ class ChatGLM(LLM):
payload.update(_model_kwargs)
payload.update(kwargs)
# print("ChatGLM payload:", payload)
logger.debug(f"ChatGLM payload: {payload}")
# call api
try:
@ -93,7 +98,7 @@ class ChatGLM(LLM):
except requests.exceptions.RequestException as e:
raise ValueError(f"Error raised by inference endpoint: {e}")
# print("ChatGLM resp:", response)
logger.debug(f"ChatGLM response: {response}")
if response.status_code != 200:
raise ValueError(f"Failed with response: {response}")
@ -119,5 +124,6 @@ class ChatGLM(LLM):
if stop is not None:
text = enforce_stop_tokens(text, stop)
if self.with_history:
self.history = self.history + [[None, parsed_response["response"]]]
return text