community[patch]: Update root_validators ChatModels: ChatBaichuan, QianfanChatEndpoint, MiniMaxChat, ChatSparkLLM, ChatZhipuAI (#22853)

This PR updates root validators for:

- ChatModels: ChatBaichuan, QianfanChatEndpoint, MiniMaxChat,
ChatSparkLLM, ChatZhipuAI

Issues #22819

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
pull/23240/head
maang-h 2 months ago committed by GitHub
parent cb6cf4b631
commit bc4cd9c5cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -89,7 +89,7 @@ class ChatBaichuan(BaseChatModel):
baichuan_api_base: str = Field(default=DEFAULT_API_BASE)
"""Baichuan custom endpoints"""
baichuan_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
baichuan_api_key: SecretStr = Field(alias="api_key")
"""Baichuan API Key"""
baichuan_secret_key: Optional[SecretStr] = None
"""[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key"""
@ -142,7 +142,7 @@ class ChatBaichuan(BaseChatModel):
values["model_kwargs"] = extra
return values
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
values["baichuan_api_base"] = get_from_dict_or_env(
values,
@ -153,11 +153,10 @@ class ChatBaichuan(BaseChatModel):
values["baichuan_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"baichuan_api_key",
["baichuan_api_key", "api_key"],
"BAICHUAN_API_KEY",
)
)
return values
@property

@ -135,7 +135,7 @@ class QianfanChatEndpoint(BaseChatModel):
client: Any #: :meta private:
qianfan_ak: Optional[SecretStr] = Field(default=None, alias="api_key")
qianfan_ak: SecretStr = Field(alias="api_key")
"""Qianfan API KEY"""
qianfan_sk: Optional[SecretStr] = Field(default=None, alias="secret_key")
"""Qianfan SECRET KEY"""
@ -171,35 +171,43 @@ class QianfanChatEndpoint(BaseChatModel):
allow_population_by_field_name = True
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
values["qianfan_ak"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"qianfan_ak",
["qianfan_ak", "api_key"],
"QIANFAN_AK",
default="",
)
)
values["qianfan_sk"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"qianfan_sk",
["qianfan_sk", "secret_key"],
"QIANFAN_SK",
default="",
)
)
default_values = {
name: field.default
for name, field in cls.__fields__.items()
if field.default is not None
}
default_values.update(values)
params = {
**values.get("init_kwargs", {}),
"model": values["model"],
"stream": values["streaming"],
"model": default_values.get("model"),
"stream": default_values.get("streaming"),
}
if values["qianfan_ak"].get_secret_value() != "":
params["ak"] = values["qianfan_ak"].get_secret_value()
if values["qianfan_sk"].get_secret_value() != "":
params["sk"] = values["qianfan_sk"].get_secret_value()
if values["endpoint"] is not None and values["endpoint"] != "":
params["endpoint"] = values["endpoint"]
if (
default_values.get("endpoint") is not None
and default_values["endpoint"] != ""
):
params["endpoint"] = default_values["endpoint"]
try:
import qianfan

@ -166,7 +166,7 @@ class MiniMaxChat(BaseChatModel):
)
minimax_group_id: Optional[str] = Field(default=None, alias="group_id")
"""[DEPRECATED, keeping it for for backward compatibility] Group Id"""
minimax_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
minimax_api_key: SecretStr = Field(alias="api_key")
"""Minimax API Key"""
streaming: bool = False
"""Whether to stream the results or not."""
@ -176,14 +176,18 @@ class MiniMaxChat(BaseChatModel):
allow_population_by_field_name = True
@root_validator(allow_reuse=True)
@root_validator(pre=True, allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["minimax_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "minimax_api_key", "MINIMAX_API_KEY")
get_from_dict_or_env(
values,
["minimax_api_key", "api_key"],
"MINIMAX_API_KEY",
)
)
values["minimax_group_id"] = get_from_dict_or_env(
values, "minimax_group_id", "MINIMAX_GROUP_ID"
values, ["minimax_group_id", "group_id"], "MINIMAX_GROUP_ID"
)
# Get custom api url from environment.
values["minimax_api_host"] = get_from_dict_or_env(

@ -195,21 +195,21 @@ class ChatSparkLLM(BaseChatModel):
return values
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
values["spark_app_id"] = get_from_dict_or_env(
values,
"spark_app_id",
["spark_app_id", "app_id"],
"IFLYTEK_SPARK_APP_ID",
)
values["spark_api_key"] = get_from_dict_or_env(
values,
"spark_api_key",
["spark_api_key", "api_key"],
"IFLYTEK_SPARK_API_KEY",
)
values["spark_api_secret"] = get_from_dict_or_env(
values,
"spark_api_secret",
["spark_api_secret", "api_secret"],
"IFLYTEK_SPARK_API_SECRET",
)
values["spark_api_url"] = get_from_dict_or_env(
@ -224,9 +224,15 @@ class ChatSparkLLM(BaseChatModel):
"IFLYTEK_SPARK_LLM_DOMAIN",
SPARK_LLM_DOMAIN,
)
# put extra params into model_kwargs
values["model_kwargs"]["temperature"] = values["temperature"] or cls.temperature
values["model_kwargs"]["top_k"] = values["top_k"] or cls.top_k
default_values = {
name: field.default
for name, field in cls.__fields__.items()
if field.default is not None
}
values["model_kwargs"]["temperature"] = default_values.get("temperature")
values["model_kwargs"]["top_k"] = default_values.get("top_k")
values["client"] = _SparkLLMClient(
app_id=values["spark_app_id"],

@ -377,10 +377,10 @@ class ChatZhipuAI(BaseChatModel):
allow_population_by_field_name = True
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["zhipuai_api_key"] = get_from_dict_or_env(
values, "zhipuai_api_key", "ZHIPUAI_API_KEY"
values, ["zhipuai_api_key", "api_key"], "ZHIPUAI_API_KEY"
)
values["zhipuai_api_base"] = get_from_dict_or_env(
values, "zhipuai_api_base", "ZHIPUAI_API_BASE", default=ZHIPUAI_API_BASE

@ -7,7 +7,7 @@ from langchain_community.chat_models.baichuan import ChatBaichuan
def test_chat_baichuan_default() -> None:
chat = ChatBaichuan(streaming=True)
chat = ChatBaichuan(streaming=True) # type: ignore[call-arg]
message = HumanMessage(content="请完整背诵将进酒背诵5遍")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
@ -15,7 +15,7 @@ def test_chat_baichuan_default() -> None:
def test_chat_baichuan_default_non_streaming() -> None:
chat = ChatBaichuan()
chat = ChatBaichuan() # type: ignore[call-arg]
message = HumanMessage(content="请完整背诵将进酒背诵5遍")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
@ -39,7 +39,7 @@ def test_chat_baichuan_turbo_non_streaming() -> None:
def test_chat_baichuan_with_temperature() -> None:
chat = ChatBaichuan(temperature=1.0)
chat = ChatBaichuan(temperature=1.0) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
@ -47,7 +47,7 @@ def test_chat_baichuan_with_temperature() -> None:
def test_chat_baichuan_with_kwargs() -> None:
chat = ChatBaichuan()
chat = ChatBaichuan() # type: ignore[call-arg]
message = HumanMessage(content="百川192K API是什么时候上线的")
response = chat.invoke(
[message], temperature=0.88, top_p=0.7, with_search_enhance=True
@ -58,7 +58,7 @@ def test_chat_baichuan_with_kwargs() -> None:
def test_extra_kwargs() -> None:
chat = ChatBaichuan(temperature=0.88, top_p=0.7, with_search_enhance=True)
chat = ChatBaichuan(temperature=0.88, top_p=0.7, with_search_enhance=True) # type: ignore[call-arg]
assert chat.temperature == 0.88
assert chat.top_p == 0.7
assert chat.with_search_enhance is True

@ -107,7 +107,7 @@ def test_baichuan_key_masked_when_passed_from_env(
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("BAICHUAN_API_KEY", "test-api-key")
chat = ChatBaichuan()
chat = ChatBaichuan() # type: ignore[call-arg]
print(chat.baichuan_api_key, end="") # noqa: T201
captured = capsys.readouterr()
assert captured.out == "**********"

Loading…
Cancel
Save