mirror of https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
feat: Add with_history option for chatglm (#8048)
In certain zero-shot scenarios, the existing stateful ChatGLM wrapper can unintentionally send and accumulate `.history` across calls. This commit adds a `with_history` option to ChatGLM so users can control that behavior and prevent unintended accumulation. Possible reviewers: @hwchase17 @baskaryan @mlot. Refer to the discussion in this thread: https://twitter.com/wey_gu/status/1681996149543276545?s=20
This commit is contained in:
parent 1f3b987860
commit cf60cff1ef
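For context, a minimal usage sketch of the new flag (illustrative only: the endpoint URL, sampling parameters, and prompts below are assumptions, not taken from this commit):

from langchain.llms import ChatGLM

# The endpoint URL is an assumption; point it at your own ChatGLM API server.
llm = ChatGLM(
    endpoint_url="http://127.0.0.1:8000",
    top_p=0.7,
    with_history=False,  # the new flag: skip sending/accumulating .history
)

# With with_history=False each call is independent: nothing is appended
# to llm.history and no accumulated context is sent to the model.
print(llm("What is the capital of France?"))
print(llm("And of Germany?"))  # answered without the previous turn as context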
@@ -95,6 +95,22 @@
     "\n",
     "llm_chain.run(question)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "By default, ChatGLM is stateful: it keeps track of the conversation history and sends the accumulated context to the model. To enable stateless mode, set `ChatGLM.with_history` to `False` explicitly."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm.with_history = False"
+   ]
+  }
 ],
 "metadata": {
@@ -1,3 +1,4 @@
+import logging
 from typing import Any, List, Mapping, Optional
 
 import requests
@@ -6,6 +7,8 @@ from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 
+logger = logging.getLogger(__name__)
+
 
 class ChatGLM(LLM):
     """ChatGLM LLM service.
@@ -34,6 +37,8 @@ class ChatGLM(LLM):
     """History of the conversation"""
     top_p: float = 0.7
     """Top P for nucleus sampling from 0 to 1"""
+    with_history: bool = True
+    """Whether to use history or not"""
 
     @property
     def _llm_type(self) -> str:
@@ -85,7 +90,7 @@ class ChatGLM(LLM):
         payload.update(_model_kwargs)
         payload.update(kwargs)
 
-        # print("ChatGLM payload:", payload)
+        logger.debug(f"ChatGLM payload: {payload}")
 
         # call api
         try:
@@ -93,7 +98,7 @@ class ChatGLM(LLM):
         except requests.exceptions.RequestException as e:
             raise ValueError(f"Error raised by inference endpoint: {e}")
 
-        # print("ChatGLM resp:", response)
+        logger.debug(f"ChatGLM response: {response}")
 
         if response.status_code != 200:
             raise ValueError(f"Failed with response: {response}")
@@ -119,5 +124,6 @@ class ChatGLM(LLM):
 
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
-        self.history = self.history + [[None, parsed_response["response"]]]
+        if self.with_history:
+            self.history = self.history + [[None, parsed_response["response"]]]
         return text
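A short sketch of the resulting behavior after this change (again illustrative; the endpoint URL and prompts are assumptions):

from langchain.llms import ChatGLM

llm = ChatGLM(endpoint_url="http://127.0.0.1:8000")  # with_history defaults to True

llm("Hi there")        # a [None, response] pair is appended to llm.history
llm("Tell me a joke")  # the accumulated history is sent with this request

llm.with_history = False
llm("Hi again")        # stateless: llm.history is no longer touched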