Mirror of https://github.com/hwchase17/langchain, synced 2024-11-06 03:20:49 +00:00
community[patch]: Update root_validators ChatModels: ChatBaichuan, QianfanChatEndpoint, MiniMaxChat, ChatSparkLLM, ChatZhipuAI (#22853)

This PR updates the root validators for the following ChatModels: ChatBaichuan, QianfanChatEndpoint, MiniMaxChat, ChatSparkLLM, and ChatZhipuAI.

Issue: #22819

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Parent: cb6cf4b631
Commit: bc4cd9c5cc
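All five models receive the same two changes: each environment validator becomes @root_validator(pre=True) so it runs on the raw constructor kwargs before field parsing, and each get_from_dict_or_env call now takes a list of keys so both the field name and its constructor alias (for example api_key) are accepted before falling back to the environment variable. Below is a minimal sketch of that shared pattern, assuming pydantic v1-style validators as provided by langchain_core.pydantic_v1; the SimpleChat class, its field names, the SIMPLE_API_KEY variable, and the simplified lookup helper are illustrative only, not part of this diff.

# Minimal sketch of the shared pattern (illustrative names, not library code).
import os
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field, SecretStr, root_validator


def lookup(
    data: Dict, keys: Union[str, List[str]], env_key: str, default: Optional[str] = None
) -> str:
    # Simplified stand-in for langchain_core.utils.get_from_dict_or_env:
    # try each dict key in order, then the environment variable, then the default.
    for key in ([keys] if isinstance(keys, str) else keys):
        if data.get(key):
            return data[key]
    value = os.environ.get(env_key, default)
    if value is None:
        raise ValueError(f"Set {env_key} or pass one of {keys}.")
    return value


class SimpleChat(BaseModel):
    api_key: SecretStr = Field(alias="key")  # now required, no None default

    class Config:
        allow_population_by_field_name = True

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        # pre=True: runs before field parsing, so the raw kwargs (including the
        # "key" alias) are visible and the secret can be filled from the env.
        values["api_key"] = lookup(values, ["api_key", "key"], "SIMPLE_API_KEY")
        return values


os.environ["SIMPLE_API_KEY"] = "placeholder"
assert SimpleChat().api_key.get_secret_value() == "placeholder"
assert SimpleChat(key="explicit").api_key.get_secret_value() == "explicit"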
@@ -89,7 +89,7 @@ class ChatBaichuan(BaseChatModel):
     baichuan_api_base: str = Field(default=DEFAULT_API_BASE)
     """Baichuan custom endpoints"""
-    baichuan_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
+    baichuan_api_key: SecretStr = Field(alias="api_key")
     """Baichuan API Key"""
     baichuan_secret_key: Optional[SecretStr] = None
     """[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key"""
@@ -142,7 +142,7 @@ class ChatBaichuan(BaseChatModel):
         values["model_kwargs"] = extra
         return values

-    @root_validator()
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         values["baichuan_api_base"] = get_from_dict_or_env(
             values,
@@ -153,11 +153,10 @@ class ChatBaichuan(BaseChatModel):
         values["baichuan_api_key"] = convert_to_secret_str(
             get_from_dict_or_env(
                 values,
-                "baichuan_api_key",
+                ["baichuan_api_key", "api_key"],
                 "BAICHUAN_API_KEY",
             )
         )

         return values

     @property
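Because the validator now runs with pre=True and the lookup accepts both the field name and its alias, ChatBaichuan can be configured either with the api_key keyword or through BAICHUAN_API_KEY. A small usage sketch of the post-change behavior with placeholder credentials; construction itself should not issue any network request.

import os
from langchain_community.chat_models import ChatBaichuan

# Explicit alias, matched by the ["baichuan_api_key", "api_key"] lookup:
chat = ChatBaichuan(api_key="placeholder-key")  # type: ignore[call-arg]

# Environment fallback, resolved inside validate_environment:
os.environ["BAICHUAN_API_KEY"] = "placeholder-key"
chat_from_env = ChatBaichuan()  # type: ignore[call-arg]
assert chat_from_env.baichuan_api_key.get_secret_value() == "placeholder-key"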
@@ -135,7 +135,7 @@ class QianfanChatEndpoint(BaseChatModel):

     client: Any  #: :meta private:

-    qianfan_ak: Optional[SecretStr] = Field(default=None, alias="api_key")
+    qianfan_ak: SecretStr = Field(alias="api_key")
     """Qianfan API KEY"""
     qianfan_sk: Optional[SecretStr] = Field(default=None, alias="secret_key")
     """Qianfan SECRET KEY"""
@@ -171,35 +171,43 @@ class QianfanChatEndpoint(BaseChatModel):

         allow_population_by_field_name = True

-    @root_validator()
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         values["qianfan_ak"] = convert_to_secret_str(
             get_from_dict_or_env(
                 values,
-                "qianfan_ak",
+                ["qianfan_ak", "api_key"],
                 "QIANFAN_AK",
-                default="",
             )
         )
         values["qianfan_sk"] = convert_to_secret_str(
             get_from_dict_or_env(
                 values,
-                "qianfan_sk",
+                ["qianfan_sk", "secret_key"],
                 "QIANFAN_SK",
-                default="",
             )
         )

+        default_values = {
+            name: field.default
+            for name, field in cls.__fields__.items()
+            if field.default is not None
+        }
+        default_values.update(values)
         params = {
             **values.get("init_kwargs", {}),
-            "model": values["model"],
-            "stream": values["streaming"],
+            "model": default_values.get("model"),
+            "stream": default_values.get("streaming"),
         }
         if values["qianfan_ak"].get_secret_value() != "":
             params["ak"] = values["qianfan_ak"].get_secret_value()
         if values["qianfan_sk"].get_secret_value() != "":
             params["sk"] = values["qianfan_sk"].get_secret_value()
-        if values["endpoint"] is not None and values["endpoint"] != "":
-            params["endpoint"] = values["endpoint"]
+        if (
+            default_values.get("endpoint") is not None
+            and default_values["endpoint"] != ""
+        ):
+            params["endpoint"] = default_values["endpoint"]
         try:
             import qianfan
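The default_values block is needed because a pre=True validator executes before pydantic applies field defaults: values only contains kwargs the caller actually passed, so indexing values["model"] or values["endpoint"] directly would raise KeyError for a default construction. The comprehension over cls.__fields__ rebuilds the defaults, and default_values.update(values) lets explicit kwargs win. A small sketch of that timing, assuming pydantic v1 field objects (as exposed via langchain_core.pydantic_v1); the Demo class and its field values are illustrative.

from typing import Dict

from pydantic import BaseModel, root_validator


class Demo(BaseModel):
    model: str = "default-model"
    streaming: bool = False

    @root_validator(pre=True)
    def check(cls, values: Dict) -> Dict:
        # With pre=True, defaults have not been applied yet: Demo() passes an
        # empty dict here, so values["model"] would raise KeyError.
        default_values = {
            name: field.default
            for name, field in cls.__fields__.items()
            if field.default is not None
        }
        default_values.update(values)  # explicit kwargs override the defaults
        # default_values now always contains "model" and "streaming".
        return values


Demo()                      # the validator sees {}
Demo(model="custom-model")  # the validator sees {"model": "custom-model"}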
@@ -166,7 +166,7 @@ class MiniMaxChat(BaseChatModel):
     )
     minimax_group_id: Optional[str] = Field(default=None, alias="group_id")
     """[DEPRECATED, keeping it for for backward compatibility] Group Id"""
-    minimax_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
+    minimax_api_key: SecretStr = Field(alias="api_key")
     """Minimax API Key"""
     streaming: bool = False
     """Whether to stream the results or not."""
@@ -176,14 +176,18 @@ class MiniMaxChat(BaseChatModel):

         allow_population_by_field_name = True

-    @root_validator(allow_reuse=True)
+    @root_validator(pre=True, allow_reuse=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         values["minimax_api_key"] = convert_to_secret_str(
-            get_from_dict_or_env(values, "minimax_api_key", "MINIMAX_API_KEY")
+            get_from_dict_or_env(
+                values,
+                ["minimax_api_key", "api_key"],
+                "MINIMAX_API_KEY",
+            )
         )
         values["minimax_group_id"] = get_from_dict_or_env(
-            values, "minimax_group_id", "MINIMAX_GROUP_ID"
+            values, ["minimax_group_id", "group_id"], "MINIMAX_GROUP_ID"
         )
         # Get custom api url from environment.
         values["minimax_api_host"] = get_from_dict_or_env(
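Tightening minimax_api_key from Optional[SecretStr] with a None default to a required SecretStr means a missing key now fails when the model is constructed instead of when the first request is sent, and the value stays masked in reprs and logs; convert_to_secret_str wraps plain strings for the same reason. A brief illustrative sketch of that masking behavior, not the library code:

from pydantic import SecretStr

secret = SecretStr("placeholder-key")
print(secret)  # prints ********** -- masked in str() and repr()
assert secret.get_secret_value() == "placeholder-key"  # explicit unmasking only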
@@ -195,21 +195,21 @@ class ChatSparkLLM(BaseChatModel):

         return values

-    @root_validator()
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         values["spark_app_id"] = get_from_dict_or_env(
             values,
-            "spark_app_id",
+            ["spark_app_id", "app_id"],
             "IFLYTEK_SPARK_APP_ID",
         )
         values["spark_api_key"] = get_from_dict_or_env(
             values,
-            "spark_api_key",
+            ["spark_api_key", "api_key"],
             "IFLYTEK_SPARK_API_KEY",
         )
         values["spark_api_secret"] = get_from_dict_or_env(
             values,
-            "spark_api_secret",
+            ["spark_api_secret", "api_secret"],
             "IFLYTEK_SPARK_API_SECRET",
         )
         values["spark_api_url"] = get_from_dict_or_env(
@@ -224,9 +224,15 @@ class ChatSparkLLM(BaseChatModel):
             "IFLYTEK_SPARK_LLM_DOMAIN",
             SPARK_LLM_DOMAIN,
         )

         # put extra params into model_kwargs
-        values["model_kwargs"]["temperature"] = values["temperature"] or cls.temperature
-        values["model_kwargs"]["top_k"] = values["top_k"] or cls.top_k
+        default_values = {
+            name: field.default
+            for name, field in cls.__fields__.items()
+            if field.default is not None
+        }
+        values["model_kwargs"]["temperature"] = default_values.get("temperature")
+        values["model_kwargs"]["top_k"] = default_values.get("top_k")

         values["client"] = _SparkLLMClient(
             app_id=values["spark_app_id"],
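For ChatSparkLLM the same key-list change applies to three credentials at once: the extra entries (app_id, api_key, api_secret) let raw constructor kwargs reach the pre=True validator, while the IFLYTEK_SPARK_* environment variables remain the fallback. A hedged usage sketch of the environment-based path with placeholder credentials; it assumes the websocket dependency used by _SparkLLMClient is installed and that construction itself issues no request:

import os
from langchain_community.chat_models import ChatSparkLLM

os.environ["IFLYTEK_SPARK_APP_ID"] = "placeholder-app-id"
os.environ["IFLYTEK_SPARK_API_KEY"] = "placeholder-api-key"
os.environ["IFLYTEK_SPARK_API_SECRET"] = "placeholder-api-secret"

chat = ChatSparkLLM()  # type: ignore[call-arg]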
@@ -377,10 +377,10 @@ class ChatZhipuAI(BaseChatModel):

         allow_population_by_field_name = True

-    @root_validator()
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         values["zhipuai_api_key"] = get_from_dict_or_env(
-            values, "zhipuai_api_key", "ZHIPUAI_API_KEY"
+            values, ["zhipuai_api_key", "api_key"], "ZHIPUAI_API_KEY"
         )
         values["zhipuai_api_base"] = get_from_dict_or_env(
             values, "zhipuai_api_base", "ZHIPUAI_API_BASE", default=ZHIPUAI_API_BASE
@@ -7,7 +7,7 @@ from langchain_community.chat_models.baichuan import ChatBaichuan


 def test_chat_baichuan_default() -> None:
-    chat = ChatBaichuan(streaming=True)
+    chat = ChatBaichuan(streaming=True)  # type: ignore[call-arg]
     message = HumanMessage(content="请完整背诵将进酒,背诵5遍")
     response = chat.invoke([message])
     assert isinstance(response, AIMessage)
@@ -15,7 +15,7 @@ def test_chat_baichuan_default() -> None:


 def test_chat_baichuan_default_non_streaming() -> None:
-    chat = ChatBaichuan()
+    chat = ChatBaichuan()  # type: ignore[call-arg]
     message = HumanMessage(content="请完整背诵将进酒,背诵5遍")
     response = chat.invoke([message])
     assert isinstance(response, AIMessage)
@@ -39,7 +39,7 @@ def test_chat_baichuan_turbo_non_streaming() -> None:


 def test_chat_baichuan_with_temperature() -> None:
-    chat = ChatBaichuan(temperature=1.0)
+    chat = ChatBaichuan(temperature=1.0)  # type: ignore[call-arg]
     message = HumanMessage(content="Hello")
     response = chat.invoke([message])
     assert isinstance(response, AIMessage)
@@ -47,7 +47,7 @@ def test_chat_baichuan_with_temperature() -> None:


 def test_chat_baichuan_with_kwargs() -> None:
-    chat = ChatBaichuan()
+    chat = ChatBaichuan()  # type: ignore[call-arg]
     message = HumanMessage(content="百川192K API是什么时候上线的?")
     response = chat.invoke(
         [message], temperature=0.88, top_p=0.7, with_search_enhance=True
@@ -58,7 +58,7 @@ def test_chat_baichuan_with_kwargs() -> None:


 def test_extra_kwargs() -> None:
-    chat = ChatBaichuan(temperature=0.88, top_p=0.7, with_search_enhance=True)
+    chat = ChatBaichuan(temperature=0.88, top_p=0.7, with_search_enhance=True)  # type: ignore[call-arg]
     assert chat.temperature == 0.88
     assert chat.top_p == 0.7
     assert chat.with_search_enhance is True
@@ -107,7 +107,7 @@ def test_baichuan_key_masked_when_passed_from_env(
     """Test initialization with an API key provided via an env variable"""
     monkeypatch.setenv("BAICHUAN_API_KEY", "test-api-key")

-    chat = ChatBaichuan()
+    chat = ChatBaichuan()  # type: ignore[call-arg]
     print(chat.baichuan_api_key, end="")  # noqa: T201
     captured = capsys.readouterr()
     assert captured.out == "**********"
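The new # type: ignore[call-arg] comments follow from the field changes above: once baichuan_api_key is declared as a required SecretStr, a mypy setup that derives pydantic __init__ signatures (for example the pydantic mypy plugin) flags ChatBaichuan() as missing an argument, even though the pre=True validator fills it in from BAICHUAN_API_KEY at runtime. A minimal sketch of that situation with illustrative names, assuming pydantic v1-style validators:

import os
from typing import Dict

from pydantic import BaseModel, SecretStr, root_validator


class Model(BaseModel):
    api_key: SecretStr  # required in the declared signature

    @root_validator(pre=True)
    def from_env(cls, values: Dict) -> Dict:
        # Fill the "missing" argument from the environment before validation.
        values.setdefault("api_key", os.environ["MODEL_API_KEY"])
        return values


os.environ["MODEL_API_KEY"] = "test-api-key"
m = Model()  # type: ignore[call-arg]  -- fine at runtime, flagged statically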