From bc4cd9c5ccb46dac499f757cc6e2a00b3379c50f Mon Sep 17 00:00:00 2001 From: maang-h <55082429+maang-h@users.noreply.github.com> Date: Fri, 21 Jun 2024 00:36:41 +0800 Subject: [PATCH] community[patch]: Update root_validators ChatModels: ChatBaichuan, QianfanChatEndpoint, MiniMaxChat, ChatSparkLLM, ChatZhipuAI (#22853) This PR updates root validators for: - ChatModels: ChatBaichuan, QianfanChatEndpoint, MiniMaxChat, ChatSparkLLM, ChatZhipuAI Issues #22819 --------- Co-authored-by: Eugene Yurtsev --- .../chat_models/baichuan.py | 7 ++--- .../chat_models/baidu_qianfan_endpoint.py | 28 ++++++++++++------- .../chat_models/minimax.py | 12 +++++--- .../chat_models/sparkllm.py | 18 ++++++++---- .../chat_models/zhipuai.py | 4 +-- .../chat_models/test_baichuan.py | 10 +++---- .../unit_tests/chat_models/test_baichuan.py | 2 +- 7 files changed, 49 insertions(+), 32 deletions(-) diff --git a/libs/community/langchain_community/chat_models/baichuan.py b/libs/community/langchain_community/chat_models/baichuan.py index e881b36abb..91d1f76dfe 100644 --- a/libs/community/langchain_community/chat_models/baichuan.py +++ b/libs/community/langchain_community/chat_models/baichuan.py @@ -89,7 +89,7 @@ class ChatBaichuan(BaseChatModel): baichuan_api_base: str = Field(default=DEFAULT_API_BASE) """Baichuan custom endpoints""" - baichuan_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") + baichuan_api_key: SecretStr = Field(alias="api_key") """Baichuan API Key""" baichuan_secret_key: Optional[SecretStr] = None """[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key""" @@ -142,7 +142,7 @@ class ChatBaichuan(BaseChatModel): values["model_kwargs"] = extra return values - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: values["baichuan_api_base"] = get_from_dict_or_env( values, @@ -153,11 +153,10 @@ class ChatBaichuan(BaseChatModel): values["baichuan_api_key"] = convert_to_secret_str( get_from_dict_or_env( values, - "baichuan_api_key", + ["baichuan_api_key", "api_key"], "BAICHUAN_API_KEY", ) ) - return values @property diff --git a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py index 9003140497..019acffd15 100644 --- a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py @@ -135,7 +135,7 @@ class QianfanChatEndpoint(BaseChatModel): client: Any #: :meta private: - qianfan_ak: Optional[SecretStr] = Field(default=None, alias="api_key") + qianfan_ak: SecretStr = Field(alias="api_key") """Qianfan API KEY""" qianfan_sk: Optional[SecretStr] = Field(default=None, alias="secret_key") """Qianfan SECRET KEY""" @@ -171,35 +171,43 @@ class QianfanChatEndpoint(BaseChatModel): allow_population_by_field_name = True - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: values["qianfan_ak"] = convert_to_secret_str( get_from_dict_or_env( values, - "qianfan_ak", + ["qianfan_ak", "api_key"], "QIANFAN_AK", - default="", ) ) values["qianfan_sk"] = convert_to_secret_str( get_from_dict_or_env( values, - "qianfan_sk", + ["qianfan_sk", "secret_key"], "QIANFAN_SK", - default="", ) ) + + default_values = { + name: field.default + for name, field in cls.__fields__.items() + if field.default is not None + } + default_values.update(values) params = { **values.get("init_kwargs", {}), - "model": values["model"], - "stream": values["streaming"], + "model": default_values.get("model"), + "stream": default_values.get("streaming"), } if values["qianfan_ak"].get_secret_value() != "": params["ak"] = values["qianfan_ak"].get_secret_value() if values["qianfan_sk"].get_secret_value() != "": params["sk"] = values["qianfan_sk"].get_secret_value() - if values["endpoint"] is not None and values["endpoint"] != "": - params["endpoint"] = values["endpoint"] + if ( + default_values.get("endpoint") is not None + and default_values["endpoint"] != "" + ): + params["endpoint"] = default_values["endpoint"] try: import qianfan diff --git a/libs/community/langchain_community/chat_models/minimax.py b/libs/community/langchain_community/chat_models/minimax.py index 8761b04d39..a2cf559c41 100644 --- a/libs/community/langchain_community/chat_models/minimax.py +++ b/libs/community/langchain_community/chat_models/minimax.py @@ -166,7 +166,7 @@ class MiniMaxChat(BaseChatModel): ) minimax_group_id: Optional[str] = Field(default=None, alias="group_id") """[DEPRECATED, keeping it for for backward compatibility] Group Id""" - minimax_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") + minimax_api_key: SecretStr = Field(alias="api_key") """Minimax API Key""" streaming: bool = False """Whether to stream the results or not.""" @@ -176,14 +176,18 @@ class MiniMaxChat(BaseChatModel): allow_population_by_field_name = True - @root_validator(allow_reuse=True) + @root_validator(pre=True, allow_reuse=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["minimax_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "minimax_api_key", "MINIMAX_API_KEY") + get_from_dict_or_env( + values, + ["minimax_api_key", "api_key"], + "MINIMAX_API_KEY", + ) ) values["minimax_group_id"] = get_from_dict_or_env( - values, "minimax_group_id", "MINIMAX_GROUP_ID" + values, ["minimax_group_id", "group_id"], "MINIMAX_GROUP_ID" ) # Get custom api url from environment. values["minimax_api_host"] = get_from_dict_or_env( diff --git a/libs/community/langchain_community/chat_models/sparkllm.py b/libs/community/langchain_community/chat_models/sparkllm.py index 20dc1380c9..b0e207b126 100644 --- a/libs/community/langchain_community/chat_models/sparkllm.py +++ b/libs/community/langchain_community/chat_models/sparkllm.py @@ -195,21 +195,21 @@ class ChatSparkLLM(BaseChatModel): return values - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: values["spark_app_id"] = get_from_dict_or_env( values, - "spark_app_id", + ["spark_app_id", "app_id"], "IFLYTEK_SPARK_APP_ID", ) values["spark_api_key"] = get_from_dict_or_env( values, - "spark_api_key", + ["spark_api_key", "api_key"], "IFLYTEK_SPARK_API_KEY", ) values["spark_api_secret"] = get_from_dict_or_env( values, - "spark_api_secret", + ["spark_api_secret", "api_secret"], "IFLYTEK_SPARK_API_SECRET", ) values["spark_api_url"] = get_from_dict_or_env( @@ -224,9 +224,15 @@ class ChatSparkLLM(BaseChatModel): "IFLYTEK_SPARK_LLM_DOMAIN", SPARK_LLM_DOMAIN, ) + # put extra params into model_kwargs - values["model_kwargs"]["temperature"] = values["temperature"] or cls.temperature - values["model_kwargs"]["top_k"] = values["top_k"] or cls.top_k + default_values = { + name: field.default + for name, field in cls.__fields__.items() + if field.default is not None + } + values["model_kwargs"]["temperature"] = default_values.get("temperature") + values["model_kwargs"]["top_k"] = default_values.get("top_k") values["client"] = _SparkLLMClient( app_id=values["spark_app_id"], diff --git a/libs/community/langchain_community/chat_models/zhipuai.py b/libs/community/langchain_community/chat_models/zhipuai.py index 062878ed2d..fd0648e912 100644 --- a/libs/community/langchain_community/chat_models/zhipuai.py +++ b/libs/community/langchain_community/chat_models/zhipuai.py @@ -377,10 +377,10 @@ class ChatZhipuAI(BaseChatModel): allow_population_by_field_name = True - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["zhipuai_api_key"] = get_from_dict_or_env( - values, "zhipuai_api_key", "ZHIPUAI_API_KEY" + values, ["zhipuai_api_key", "api_key"], "ZHIPUAI_API_KEY" ) values["zhipuai_api_base"] = get_from_dict_or_env( values, "zhipuai_api_base", "ZHIPUAI_API_BASE", default=ZHIPUAI_API_BASE diff --git a/libs/community/tests/integration_tests/chat_models/test_baichuan.py b/libs/community/tests/integration_tests/chat_models/test_baichuan.py index 62008a8ab2..715391c196 100644 --- a/libs/community/tests/integration_tests/chat_models/test_baichuan.py +++ b/libs/community/tests/integration_tests/chat_models/test_baichuan.py @@ -7,7 +7,7 @@ from langchain_community.chat_models.baichuan import ChatBaichuan def test_chat_baichuan_default() -> None: - chat = ChatBaichuan(streaming=True) + chat = ChatBaichuan(streaming=True) # type: ignore[call-arg] message = HumanMessage(content="请完整背诵将进酒,背诵5遍") response = chat.invoke([message]) assert isinstance(response, AIMessage) @@ -15,7 +15,7 @@ def test_chat_baichuan_default() -> None: def test_chat_baichuan_default_non_streaming() -> None: - chat = ChatBaichuan() + chat = ChatBaichuan() # type: ignore[call-arg] message = HumanMessage(content="请完整背诵将进酒,背诵5遍") response = chat.invoke([message]) assert isinstance(response, AIMessage) @@ -39,7 +39,7 @@ def test_chat_baichuan_turbo_non_streaming() -> None: def test_chat_baichuan_with_temperature() -> None: - chat = ChatBaichuan(temperature=1.0) + chat = ChatBaichuan(temperature=1.0) # type: ignore[call-arg] message = HumanMessage(content="Hello") response = chat.invoke([message]) assert isinstance(response, AIMessage) @@ -47,7 +47,7 @@ def test_chat_baichuan_with_temperature() -> None: def test_chat_baichuan_with_kwargs() -> None: - chat = ChatBaichuan() + chat = ChatBaichuan() # type: ignore[call-arg] message = HumanMessage(content="百川192K API是什么时候上线的?") response = chat.invoke( [message], temperature=0.88, top_p=0.7, with_search_enhance=True @@ -58,7 +58,7 @@ def test_chat_baichuan_with_kwargs() -> None: def test_extra_kwargs() -> None: - chat = ChatBaichuan(temperature=0.88, top_p=0.7, with_search_enhance=True) + chat = ChatBaichuan(temperature=0.88, top_p=0.7, with_search_enhance=True) # type: ignore[call-arg] assert chat.temperature == 0.88 assert chat.top_p == 0.7 assert chat.with_search_enhance is True diff --git a/libs/community/tests/unit_tests/chat_models/test_baichuan.py b/libs/community/tests/unit_tests/chat_models/test_baichuan.py index e79efac726..def23d56e5 100644 --- a/libs/community/tests/unit_tests/chat_models/test_baichuan.py +++ b/libs/community/tests/unit_tests/chat_models/test_baichuan.py @@ -107,7 +107,7 @@ def test_baichuan_key_masked_when_passed_from_env( """Test initialization with an API key provided via an env variable""" monkeypatch.setenv("BAICHUAN_API_KEY", "test-api-key") - chat = ChatBaichuan() + chat = ChatBaichuan() # type: ignore[call-arg] print(chat.baichuan_api_key, end="") # noqa: T201 captured = capsys.readouterr() assert captured.out == "**********"