community[patch]: truncate zhipuai `temperature` and `top_p` parameters to [0.01, 0.99] (#20261)

ZhipuAI API only accepts `temperature` parameter between `(0, 1)` open
interval, and if `0` is passed, it responds with status code `400`.

However, 0 and 1 is often accepted by other APIs, for example, OpenAI
allows `[0, 2]` for temperature closed range.

This PR truncates temperature parameter passed to `[0.01, 0.99]` to
improve the compatibility between langchain's ecosystem's and ZhipuAI
(e.g., ragas `evaluate` often generates temperature 0, which results in
a lot of 400 invalid responses). The PR also truncates `top_p` parameter
since it has the same restriction.

Reference: [glm-4 doc](https://open.bigmodel.cn/dev/api#glm-4) (which
unfortunately is in Chinese though).

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
pull/20633/head^2
Congyu 3 months ago committed by GitHub
parent d5c22b80a5
commit dd5139e304
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -148,6 +148,20 @@ def _convert_delta_to_message_chunk(
return default_class(content=content)
def _truncate_params(payload: Dict[str, Any]) -> None:
"""Truncate temperature and top_p parameters between [0.01, 0.99].
ZhipuAI only support temperature / top_p between (0, 1) open interval,
so we truncate them to [0.01, 0.99].
"""
temperature = payload.get("temperature")
top_p = payload.get("top_p")
if temperature is not None:
payload["temperature"] = max(0.01, min(0.99, temperature))
if top_p is not None:
payload["top_p"] = max(0.01, min(0.99, top_p))
class ChatZhipuAI(BaseChatModel):
"""
`ZhipuAI` large language chat models API.
@ -213,7 +227,7 @@ class ChatZhipuAI(BaseChatModel):
model_name: Optional[str] = Field(default="glm-4", alias="model")
"""
Model name to use, see 'https://open.bigmodel.cn/dev/api#language'.
or you can use any finetune model of glm series.
Alternatively, you can use any fine-tuned model from the GLM series.
"""
temperature: float = 0.95
@ -309,6 +323,7 @@ class ChatZhipuAI(BaseChatModel):
"messages": message_dicts,
"stream": False,
}
_truncate_params(payload)
headers = {
"Authorization": _get_jwt_token(self.zhipuai_api_key),
"Accept": "application/json",
@ -334,6 +349,7 @@ class ChatZhipuAI(BaseChatModel):
raise ValueError("Did not find zhipu_api_base.")
message_dicts, params = self._create_message_dicts(messages, stop)
payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
_truncate_params(payload)
headers = {
"Authorization": _get_jwt_token(self.zhipuai_api_key),
"Accept": "application/json",
@ -394,6 +410,7 @@ class ChatZhipuAI(BaseChatModel):
"messages": message_dicts,
"stream": False,
}
_truncate_params(payload)
headers = {
"Authorization": _get_jwt_token(self.zhipuai_api_key),
"Accept": "application/json",
@ -418,6 +435,7 @@ class ChatZhipuAI(BaseChatModel):
raise ValueError("Did not find zhipu_api_base.")
message_dicts, params = self._create_message_dicts(messages, stop)
payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
_truncate_params(payload)
headers = {
"Authorization": _get_jwt_token(self.zhipuai_api_key),
"Accept": "application/json",

Loading…
Cancel
Save