diff --git a/libs/langchain/langchain/chat_models/baichuan.py b/libs/langchain/langchain/chat_models/baichuan.py index 39b14d2d11..91be5c38cf 100644 --- a/libs/langchain/langchain/chat_models/baichuan.py +++ b/libs/langchain/langchain/chat_models/baichuan.py @@ -2,13 +2,13 @@ import hashlib import json import logging import time -from typing import Any, Dict, Iterator, List, Mapping, Optional, Type +from typing import Any, Dict, Iterator, List, Mapping, Optional, Type, Union import requests from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chat_models.base import BaseChatModel, _generate_from_stream -from langchain.pydantic_v1 import Field, root_validator +from langchain.pydantic_v1 import Field, SecretStr, root_validator from langchain.schema import ( AIMessage, BaseMessage, @@ -29,7 +29,7 @@ from langchain.utils import get_from_dict_or_env, get_pydantic_field_names logger = logging.getLogger(__name__) -def convert_message_to_dict(message: BaseMessage) -> dict: +def _convert_message_to_dict(message: BaseMessage) -> dict: message_dict: Dict[str, Any] if isinstance(message, ChatMessage): message_dict = {"role": message.role, "content": message.content} @@ -69,6 +69,21 @@ def _convert_delta_to_message_chunk( return default_class(content=content) +def _to_secret(value: Union[SecretStr, str]) -> SecretStr: + """Convert a string to a SecretStr if needed.""" + if isinstance(value, SecretStr): + return value + return SecretStr(value) + + +# signature generation +def _signature(secret_key: SecretStr, payload: Dict[str, Any], timestamp: int) -> str: + input_str = secret_key.get_secret_value() + json.dumps(payload) + str(timestamp) + md5 = hashlib.md5() + md5.update(input_str.encode("utf-8")) + return md5.hexdigest() + + class ChatBaichuan(BaseChatModel): """Baichuan chat models API by Baichuan Intelligent Technology. @@ -90,21 +105,25 @@ class ChatBaichuan(BaseChatModel): """Baichuan custom endpoints""" baichuan_api_key: Optional[str] = None """Baichuan API Key""" - baichuan_secret_key: Optional[str] = None + baichuan_secret_key: Optional[SecretStr] = None """Baichuan Secret Key""" - streaming: Optional[bool] = False - """streaming mode.""" - request_timeout: Optional[int] = 60 + streaming: bool = False + """Whether to stream the results or not.""" + request_timeout: int = 60 """request timeout for chat http requests""" model = "Baichuan2-53B" """model name of Baichuan, default is `Baichuan2-53B`.""" temperature: float = 0.3 + """What sampling temperature to use.""" top_k: int = 5 + """What search sampling control to use.""" top_p: float = 0.85 + """What probability mass to use.""" with_search_enhance: bool = False """Whether to use search enhance, default is False.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any model parameters valid for API call not explicitly specified.""" class Config: """Configuration for this pydantic object.""" @@ -149,10 +168,12 @@ class ChatBaichuan(BaseChatModel): "baichuan_api_key", "BAICHUAN_API_KEY", ) - values["baichuan_secret_key"] = get_from_dict_or_env( - values, - "baichuan_secret_key", - "BAICHUAN_SECRET_KEY", + values["baichuan_secret_key"] = _to_secret( + get_from_dict_or_env( + values, + "baichuan_secret_key", + "BAICHUAN_SECRET_KEY", + ) ) return values @@ -169,15 +190,6 @@ class ChatBaichuan(BaseChatModel): return {**normal_params, **self.model_kwargs} - def _signature(self, data: Dict[str, Any], timestamp: int) -> str: - if self.baichuan_secret_key is None: - raise ValueError("Baichuan secret key is not set.") - - input_str = self.baichuan_secret_key + json.dumps(data) + str(timestamp) - md5 = hashlib.md5() - md5.update(input_str.encode("utf-8")) - return md5.hexdigest() - def _generate( self, messages: List[BaseMessage], @@ -224,6 +236,9 @@ class ChatBaichuan(BaseChatModel): run_manager.on_llm_new_token(chunk.content) def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response: + if self.baichuan_secret_key is None: + raise ValueError("Baichuan secret key is not set.") + parameters = {**self._default_params, **kwargs} model = parameters.pop("model") @@ -231,7 +246,7 @@ class ChatBaichuan(BaseChatModel): payload = { "model": model, - "messages": [convert_message_to_dict(m) for m in messages], + "messages": [_convert_message_to_dict(m) for m in messages], "parameters": parameters, } @@ -249,7 +264,11 @@ class ChatBaichuan(BaseChatModel): "Content-Type": "application/json", "Authorization": f"Bearer {self.baichuan_api_key}", "X-BC-Timestamp": str(timestamp), - "X-BC-Signature": self._signature(payload, timestamp), + "X-BC-Signature": _signature( + secret_key=self.baichuan_secret_key, + payload=payload, + timestamp=timestamp, + ), "X-BC-Sign-Algo": "MD5", **headers, }, diff --git a/libs/langchain/tests/integration_tests/chat_models/test_baichuan.py b/libs/langchain/tests/integration_tests/chat_models/test_baichuan.py new file mode 100644 index 0000000000..58b5dd5aa2 --- /dev/null +++ b/libs/langchain/tests/integration_tests/chat_models/test_baichuan.py @@ -0,0 +1,40 @@ +from langchain.chat_models.baichuan import ChatBaichuan +from langchain.schema.messages import AIMessage, HumanMessage + + +def test_chat_baichuan() -> None: + chat = ChatBaichuan() + message = HumanMessage(content="Hello") + response = chat([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + + +def test_chat_baichuan_with_model() -> None: + chat = ChatBaichuan(model="Baichuan2-13B") + message = HumanMessage(content="Hello") + response = chat([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + + +def test_chat_baichuan_with_temperature() -> None: + chat = ChatBaichuan(model="Baichuan2-13B", temperature=1.0) + message = HumanMessage(content="Hello") + response = chat([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + + +def test_chat_baichuan_with_kwargs() -> None: + chat = ChatBaichuan() + message = HumanMessage(content="Hello") + response = chat([message], temperature=0.88, top_p=0.7) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + + +def test_extra_kwargs() -> None: + chat = ChatBaichuan(temperature=0.88, top_p=0.7) + assert chat.temperature == 0.88 + assert chat.top_p == 0.7 diff --git a/libs/langchain/tests/unit_tests/chat_models/test_baichuan.py b/libs/langchain/tests/unit_tests/chat_models/test_baichuan.py new file mode 100644 index 0000000000..000771b7dc --- /dev/null +++ b/libs/langchain/tests/unit_tests/chat_models/test_baichuan.py @@ -0,0 +1,99 @@ +import pytest + +from langchain.chat_models.baichuan import ( + _convert_delta_to_message_chunk, + _convert_dict_to_message, + _convert_message_to_dict, + _signature, +) +from langchain.pydantic_v1 import SecretStr +from langchain.schema.messages import ( + AIMessage, + AIMessageChunk, + ChatMessage, + FunctionMessage, + HumanMessage, + HumanMessageChunk, + SystemMessage, +) + + +def test__convert_message_to_dict_human() -> None: + message = HumanMessage(content="foo") + result = _convert_message_to_dict(message) + expected_output = {"role": "user", "content": "foo"} + assert result == expected_output + + +def test__convert_message_to_dict_ai() -> None: + message = AIMessage(content="foo") + result = _convert_message_to_dict(message) + expected_output = {"role": "assistant", "content": "foo"} + assert result == expected_output + + +def test__convert_message_to_dict_system() -> None: + message = SystemMessage(content="foo") + with pytest.raises(TypeError) as e: + _convert_message_to_dict(message) + assert "Got unknown type" in str(e) + + +def test__convert_message_to_dict_function() -> None: + message = FunctionMessage(name="foo", content="bar") + with pytest.raises(TypeError) as e: + _convert_message_to_dict(message) + assert "Got unknown type" in str(e) + + +def test__convert_dict_to_message_human() -> None: + message_dict = {"role": "user", "content": "foo"} + result = _convert_dict_to_message(message_dict) + expected_output = HumanMessage(content="foo") + assert result == expected_output + + +def test__convert_dict_to_message_ai() -> None: + message_dict = {"role": "assistant", "content": "foo"} + result = _convert_dict_to_message(message_dict) + expected_output = AIMessage(content="foo") + assert result == expected_output + + +def test__convert_dict_to_message_other_role() -> None: + message_dict = {"role": "system", "content": "foo"} + result = _convert_dict_to_message(message_dict) + expected_output = ChatMessage(role="system", content="foo") + assert result == expected_output + + +def test__convert_delta_to_message_assistant() -> None: + delta = {"role": "assistant", "content": "foo"} + result = _convert_delta_to_message_chunk(delta, AIMessageChunk) + expected_output = AIMessageChunk(content="foo") + assert result == expected_output + + +def test__convert_delta_to_message_human() -> None: + delta = {"role": "user", "content": "foo"} + result = _convert_delta_to_message_chunk(delta, HumanMessageChunk) + expected_output = HumanMessageChunk(content="foo") + assert result == expected_output + + +def test__signature() -> None: + secret_key = SecretStr("YOUR_SECRET_KEY") + + result = _signature( + secret_key=secret_key, + payload={ + "model": "Baichuan2-53B", + "messages": [{"role": "user", "content": "Hi"}], + }, + timestamp=1697734335, + ) + + # The signature was generated by the demo provided by Baichuan. + # https://platform.baichuan-ai.com/docs/api#4 + expected_output = "24a50b2db1648e25a244c67c5ab57d3f" + assert result == expected_output