community[patch]: Update root_validators ChatModels: ChatBaichuan, QianfanChatEndpoint, MiniMaxChat, ChatSparkLLM, ChatZhipuAI (#22853)

This PR updates root validators for:

- ChatModels: ChatBaichuan, QianfanChatEndpoint, MiniMaxChat,
ChatSparkLLM, ChatZhipuAI

Issues #22819

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
maang-h 2024-06-21 00:36:41 +08:00 committed by GitHub
parent cb6cf4b631
commit bc4cd9c5cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 49 additions and 32 deletions

View File

@ -89,7 +89,7 @@ class ChatBaichuan(BaseChatModel):
baichuan_api_base: str = Field(default=DEFAULT_API_BASE) baichuan_api_base: str = Field(default=DEFAULT_API_BASE)
"""Baichuan custom endpoints""" """Baichuan custom endpoints"""
baichuan_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") baichuan_api_key: SecretStr = Field(alias="api_key")
"""Baichuan API Key""" """Baichuan API Key"""
baichuan_secret_key: Optional[SecretStr] = None baichuan_secret_key: Optional[SecretStr] = None
"""[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key""" """[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key"""
@ -142,7 +142,7 @@ class ChatBaichuan(BaseChatModel):
values["model_kwargs"] = extra values["model_kwargs"] = extra
return values return values
@root_validator() @root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
values["baichuan_api_base"] = get_from_dict_or_env( values["baichuan_api_base"] = get_from_dict_or_env(
values, values,
@ -153,11 +153,10 @@ class ChatBaichuan(BaseChatModel):
values["baichuan_api_key"] = convert_to_secret_str( values["baichuan_api_key"] = convert_to_secret_str(
get_from_dict_or_env( get_from_dict_or_env(
values, values,
"baichuan_api_key", ["baichuan_api_key", "api_key"],
"BAICHUAN_API_KEY", "BAICHUAN_API_KEY",
) )
) )
return values return values
@property @property

View File

@ -135,7 +135,7 @@ class QianfanChatEndpoint(BaseChatModel):
client: Any #: :meta private: client: Any #: :meta private:
qianfan_ak: Optional[SecretStr] = Field(default=None, alias="api_key") qianfan_ak: SecretStr = Field(alias="api_key")
"""Qianfan API KEY""" """Qianfan API KEY"""
qianfan_sk: Optional[SecretStr] = Field(default=None, alias="secret_key") qianfan_sk: Optional[SecretStr] = Field(default=None, alias="secret_key")
"""Qianfan SECRET KEY""" """Qianfan SECRET KEY"""
@ -171,35 +171,43 @@ class QianfanChatEndpoint(BaseChatModel):
allow_population_by_field_name = True allow_population_by_field_name = True
@root_validator() @root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
values["qianfan_ak"] = convert_to_secret_str( values["qianfan_ak"] = convert_to_secret_str(
get_from_dict_or_env( get_from_dict_or_env(
values, values,
"qianfan_ak", ["qianfan_ak", "api_key"],
"QIANFAN_AK", "QIANFAN_AK",
default="",
) )
) )
values["qianfan_sk"] = convert_to_secret_str( values["qianfan_sk"] = convert_to_secret_str(
get_from_dict_or_env( get_from_dict_or_env(
values, values,
"qianfan_sk", ["qianfan_sk", "secret_key"],
"QIANFAN_SK", "QIANFAN_SK",
default="",
) )
) )
default_values = {
name: field.default
for name, field in cls.__fields__.items()
if field.default is not None
}
default_values.update(values)
params = { params = {
**values.get("init_kwargs", {}), **values.get("init_kwargs", {}),
"model": values["model"], "model": default_values.get("model"),
"stream": values["streaming"], "stream": default_values.get("streaming"),
} }
if values["qianfan_ak"].get_secret_value() != "": if values["qianfan_ak"].get_secret_value() != "":
params["ak"] = values["qianfan_ak"].get_secret_value() params["ak"] = values["qianfan_ak"].get_secret_value()
if values["qianfan_sk"].get_secret_value() != "": if values["qianfan_sk"].get_secret_value() != "":
params["sk"] = values["qianfan_sk"].get_secret_value() params["sk"] = values["qianfan_sk"].get_secret_value()
if values["endpoint"] is not None and values["endpoint"] != "": if (
params["endpoint"] = values["endpoint"] default_values.get("endpoint") is not None
and default_values["endpoint"] != ""
):
params["endpoint"] = default_values["endpoint"]
try: try:
import qianfan import qianfan

View File

@ -166,7 +166,7 @@ class MiniMaxChat(BaseChatModel):
) )
minimax_group_id: Optional[str] = Field(default=None, alias="group_id") minimax_group_id: Optional[str] = Field(default=None, alias="group_id")
"""[DEPRECATED, keeping it for for backward compatibility] Group Id""" """[DEPRECATED, keeping it for for backward compatibility] Group Id"""
minimax_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") minimax_api_key: SecretStr = Field(alias="api_key")
"""Minimax API Key""" """Minimax API Key"""
streaming: bool = False streaming: bool = False
"""Whether to stream the results or not.""" """Whether to stream the results or not."""
@ -176,14 +176,18 @@ class MiniMaxChat(BaseChatModel):
allow_population_by_field_name = True allow_population_by_field_name = True
@root_validator(allow_reuse=True) @root_validator(pre=True, allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
values["minimax_api_key"] = convert_to_secret_str( values["minimax_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "minimax_api_key", "MINIMAX_API_KEY") get_from_dict_or_env(
values,
["minimax_api_key", "api_key"],
"MINIMAX_API_KEY",
)
) )
values["minimax_group_id"] = get_from_dict_or_env( values["minimax_group_id"] = get_from_dict_or_env(
values, "minimax_group_id", "MINIMAX_GROUP_ID" values, ["minimax_group_id", "group_id"], "MINIMAX_GROUP_ID"
) )
# Get custom api url from environment. # Get custom api url from environment.
values["minimax_api_host"] = get_from_dict_or_env( values["minimax_api_host"] = get_from_dict_or_env(

View File

@ -195,21 +195,21 @@ class ChatSparkLLM(BaseChatModel):
return values return values
@root_validator() @root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
values["spark_app_id"] = get_from_dict_or_env( values["spark_app_id"] = get_from_dict_or_env(
values, values,
"spark_app_id", ["spark_app_id", "app_id"],
"IFLYTEK_SPARK_APP_ID", "IFLYTEK_SPARK_APP_ID",
) )
values["spark_api_key"] = get_from_dict_or_env( values["spark_api_key"] = get_from_dict_or_env(
values, values,
"spark_api_key", ["spark_api_key", "api_key"],
"IFLYTEK_SPARK_API_KEY", "IFLYTEK_SPARK_API_KEY",
) )
values["spark_api_secret"] = get_from_dict_or_env( values["spark_api_secret"] = get_from_dict_or_env(
values, values,
"spark_api_secret", ["spark_api_secret", "api_secret"],
"IFLYTEK_SPARK_API_SECRET", "IFLYTEK_SPARK_API_SECRET",
) )
values["spark_api_url"] = get_from_dict_or_env( values["spark_api_url"] = get_from_dict_or_env(
@ -224,9 +224,15 @@ class ChatSparkLLM(BaseChatModel):
"IFLYTEK_SPARK_LLM_DOMAIN", "IFLYTEK_SPARK_LLM_DOMAIN",
SPARK_LLM_DOMAIN, SPARK_LLM_DOMAIN,
) )
# put extra params into model_kwargs # put extra params into model_kwargs
values["model_kwargs"]["temperature"] = values["temperature"] or cls.temperature default_values = {
values["model_kwargs"]["top_k"] = values["top_k"] or cls.top_k name: field.default
for name, field in cls.__fields__.items()
if field.default is not None
}
values["model_kwargs"]["temperature"] = default_values.get("temperature")
values["model_kwargs"]["top_k"] = default_values.get("top_k")
values["client"] = _SparkLLMClient( values["client"] = _SparkLLMClient(
app_id=values["spark_app_id"], app_id=values["spark_app_id"],

View File

@ -377,10 +377,10 @@ class ChatZhipuAI(BaseChatModel):
allow_population_by_field_name = True allow_population_by_field_name = True
@root_validator() @root_validator(pre=True)
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]: def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["zhipuai_api_key"] = get_from_dict_or_env( values["zhipuai_api_key"] = get_from_dict_or_env(
values, "zhipuai_api_key", "ZHIPUAI_API_KEY" values, ["zhipuai_api_key", "api_key"], "ZHIPUAI_API_KEY"
) )
values["zhipuai_api_base"] = get_from_dict_or_env( values["zhipuai_api_base"] = get_from_dict_or_env(
values, "zhipuai_api_base", "ZHIPUAI_API_BASE", default=ZHIPUAI_API_BASE values, "zhipuai_api_base", "ZHIPUAI_API_BASE", default=ZHIPUAI_API_BASE

View File

@ -7,7 +7,7 @@ from langchain_community.chat_models.baichuan import ChatBaichuan
def test_chat_baichuan_default() -> None: def test_chat_baichuan_default() -> None:
chat = ChatBaichuan(streaming=True) chat = ChatBaichuan(streaming=True) # type: ignore[call-arg]
message = HumanMessage(content="请完整背诵将进酒背诵5遍") message = HumanMessage(content="请完整背诵将进酒背诵5遍")
response = chat.invoke([message]) response = chat.invoke([message])
assert isinstance(response, AIMessage) assert isinstance(response, AIMessage)
@ -15,7 +15,7 @@ def test_chat_baichuan_default() -> None:
def test_chat_baichuan_default_non_streaming() -> None: def test_chat_baichuan_default_non_streaming() -> None:
chat = ChatBaichuan() chat = ChatBaichuan() # type: ignore[call-arg]
message = HumanMessage(content="请完整背诵将进酒背诵5遍") message = HumanMessage(content="请完整背诵将进酒背诵5遍")
response = chat.invoke([message]) response = chat.invoke([message])
assert isinstance(response, AIMessage) assert isinstance(response, AIMessage)
@ -39,7 +39,7 @@ def test_chat_baichuan_turbo_non_streaming() -> None:
def test_chat_baichuan_with_temperature() -> None: def test_chat_baichuan_with_temperature() -> None:
chat = ChatBaichuan(temperature=1.0) chat = ChatBaichuan(temperature=1.0) # type: ignore[call-arg]
message = HumanMessage(content="Hello") message = HumanMessage(content="Hello")
response = chat.invoke([message]) response = chat.invoke([message])
assert isinstance(response, AIMessage) assert isinstance(response, AIMessage)
@ -47,7 +47,7 @@ def test_chat_baichuan_with_temperature() -> None:
def test_chat_baichuan_with_kwargs() -> None: def test_chat_baichuan_with_kwargs() -> None:
chat = ChatBaichuan() chat = ChatBaichuan() # type: ignore[call-arg]
message = HumanMessage(content="百川192K API是什么时候上线的") message = HumanMessage(content="百川192K API是什么时候上线的")
response = chat.invoke( response = chat.invoke(
[message], temperature=0.88, top_p=0.7, with_search_enhance=True [message], temperature=0.88, top_p=0.7, with_search_enhance=True
@ -58,7 +58,7 @@ def test_chat_baichuan_with_kwargs() -> None:
def test_extra_kwargs() -> None: def test_extra_kwargs() -> None:
chat = ChatBaichuan(temperature=0.88, top_p=0.7, with_search_enhance=True) chat = ChatBaichuan(temperature=0.88, top_p=0.7, with_search_enhance=True) # type: ignore[call-arg]
assert chat.temperature == 0.88 assert chat.temperature == 0.88
assert chat.top_p == 0.7 assert chat.top_p == 0.7
assert chat.with_search_enhance is True assert chat.with_search_enhance is True

View File

@ -107,7 +107,7 @@ def test_baichuan_key_masked_when_passed_from_env(
"""Test initialization with an API key provided via an env variable""" """Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("BAICHUAN_API_KEY", "test-api-key") monkeypatch.setenv("BAICHUAN_API_KEY", "test-api-key")
chat = ChatBaichuan() chat = ChatBaichuan() # type: ignore[call-arg]
print(chat.baichuan_api_key, end="") # noqa: T201 print(chat.baichuan_api_key, end="") # noqa: T201
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == "**********" assert captured.out == "**********"