`baichuan_secret_key` now uses pydantic.types.SecretStr, and add Baichuan tests (#12031)

### Description
- `baichuan_secret_key` now uses `pydantic.types.SecretStr`
- Add Baichuan tests
pull/12041/head^2
John Mai 9 months ago committed by GitHub
parent 85bac75729
commit 8eb40b5fe2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,13 +2,13 @@ import hashlib
import json
import logging
import time
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type, Union
import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel, _generate_from_stream
from langchain.pydantic_v1 import Field, root_validator
from langchain.pydantic_v1 import Field, SecretStr, root_validator
from langchain.schema import (
AIMessage,
BaseMessage,
@ -29,7 +29,7 @@ from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
logger = logging.getLogger(__name__)
def convert_message_to_dict(message: BaseMessage) -> dict:
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
@ -69,6 +69,21 @@ def _convert_delta_to_message_chunk(
return default_class(content=content)
def _to_secret(value: Union[SecretStr, str]) -> SecretStr:
    """Return ``value`` as a ``SecretStr``.

    A value that already is a ``SecretStr`` is handed back unchanged;
    anything else is wrapped in a new ``SecretStr``.
    """
    return value if isinstance(value, SecretStr) else SecretStr(value)
# signature generation
def _signature(secret_key: SecretStr, payload: Dict[str, Any], timestamp: int) -> str:
    """Compute the hex MD5 digest over the key, payload and timestamp.

    The digest input is the secret key's plain value, the JSON-serialised
    payload and the decimal timestamp, concatenated in that order and
    encoded as UTF-8.
    """
    signed_parts = (
        secret_key.get_secret_value(),
        json.dumps(payload),
        str(timestamp),
    )
    return hashlib.md5("".join(signed_parts).encode("utf-8")).hexdigest()
class ChatBaichuan(BaseChatModel):
"""Baichuan chat models API by Baichuan Intelligent Technology.
@ -90,21 +105,25 @@ class ChatBaichuan(BaseChatModel):
"""Baichuan custom endpoints"""
baichuan_api_key: Optional[str] = None
"""Baichuan API Key"""
baichuan_secret_key: Optional[str] = None
baichuan_secret_key: Optional[SecretStr] = None
"""Baichuan Secret Key"""
streaming: Optional[bool] = False
"""streaming mode."""
request_timeout: Optional[int] = 60
streaming: bool = False
"""Whether to stream the results or not."""
request_timeout: int = 60
"""request timeout for chat http requests"""
model = "Baichuan2-53B"
"""model name of Baichuan, default is `Baichuan2-53B`."""
temperature: float = 0.3
"""What sampling temperature to use."""
top_k: int = 5
"""What search sampling control to use."""
top_p: float = 0.85
"""What probability mass to use."""
with_search_enhance: bool = False
"""Whether to use search enhance, default is False."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for API call not explicitly specified."""
class Config:
"""Configuration for this pydantic object."""
@ -149,10 +168,12 @@ class ChatBaichuan(BaseChatModel):
"baichuan_api_key",
"BAICHUAN_API_KEY",
)
values["baichuan_secret_key"] = get_from_dict_or_env(
values,
"baichuan_secret_key",
"BAICHUAN_SECRET_KEY",
values["baichuan_secret_key"] = _to_secret(
get_from_dict_or_env(
values,
"baichuan_secret_key",
"BAICHUAN_SECRET_KEY",
)
)
return values
@ -169,15 +190,6 @@ class ChatBaichuan(BaseChatModel):
return {**normal_params, **self.model_kwargs}
def _signature(self, data: Dict[str, Any], timestamp: int) -> str:
if self.baichuan_secret_key is None:
raise ValueError("Baichuan secret key is not set.")
input_str = self.baichuan_secret_key + json.dumps(data) + str(timestamp)
md5 = hashlib.md5()
md5.update(input_str.encode("utf-8"))
return md5.hexdigest()
def _generate(
self,
messages: List[BaseMessage],
@ -224,6 +236,9 @@ class ChatBaichuan(BaseChatModel):
run_manager.on_llm_new_token(chunk.content)
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
if self.baichuan_secret_key is None:
raise ValueError("Baichuan secret key is not set.")
parameters = {**self._default_params, **kwargs}
model = parameters.pop("model")
@ -231,7 +246,7 @@ class ChatBaichuan(BaseChatModel):
payload = {
"model": model,
"messages": [convert_message_to_dict(m) for m in messages],
"messages": [_convert_message_to_dict(m) for m in messages],
"parameters": parameters,
}
@ -249,7 +264,11 @@ class ChatBaichuan(BaseChatModel):
"Content-Type": "application/json",
"Authorization": f"Bearer {self.baichuan_api_key}",
"X-BC-Timestamp": str(timestamp),
"X-BC-Signature": self._signature(payload, timestamp),
"X-BC-Signature": _signature(
secret_key=self.baichuan_secret_key,
payload=payload,
timestamp=timestamp,
),
"X-BC-Sign-Algo": "MD5",
**headers,
},

@ -0,0 +1,40 @@
from langchain.chat_models.baichuan import ChatBaichuan
from langchain.schema.messages import AIMessage, HumanMessage
def test_chat_baichuan() -> None:
    """A default ChatBaichuan answers a human message with a string AIMessage."""
    llm = ChatBaichuan()
    reply = llm([HumanMessage(content="Hello")])
    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
def test_chat_baichuan_with_model() -> None:
    """Selecting the "Baichuan2-13B" model still yields a string AIMessage."""
    llm = ChatBaichuan(model="Baichuan2-13B")
    reply = llm([HumanMessage(content="Hello")])
    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
def test_chat_baichuan_with_temperature() -> None:
    """A constructor-level temperature override still yields a string AIMessage."""
    llm = ChatBaichuan(model="Baichuan2-13B", temperature=1.0)
    reply = llm([HumanMessage(content="Hello")])
    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
def test_chat_baichuan_with_kwargs() -> None:
    """Per-call sampling keyword arguments are accepted by the chat call."""
    llm = ChatBaichuan()
    reply = llm([HumanMessage(content="Hello")], temperature=0.88, top_p=0.7)
    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
def test_extra_kwargs() -> None:
    """Sampling parameters passed to the constructor are stored on the model."""
    llm = ChatBaichuan(temperature=0.88, top_p=0.7)
    assert (llm.temperature, llm.top_p) == (0.88, 0.7)

@ -0,0 +1,99 @@
import pytest
from langchain.chat_models.baichuan import (
_convert_delta_to_message_chunk,
_convert_dict_to_message,
_convert_message_to_dict,
_signature,
)
from langchain.pydantic_v1 import SecretStr
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
ChatMessage,
FunctionMessage,
HumanMessage,
HumanMessageChunk,
SystemMessage,
)
def test__convert_message_to_dict_human() -> None:
    """A HumanMessage converts to a dict carrying the "user" role."""
    converted = _convert_message_to_dict(HumanMessage(content="foo"))
    assert converted == {"role": "user", "content": "foo"}
def test__convert_message_to_dict_ai() -> None:
    """An AIMessage converts to a dict carrying the "assistant" role."""
    converted = _convert_message_to_dict(AIMessage(content="foo"))
    assert converted == {"role": "assistant", "content": "foo"}
def test__convert_message_to_dict_system() -> None:
    """Converting a SystemMessage is expected to raise TypeError."""
    message = SystemMessage(content="foo")
    with pytest.raises(TypeError) as excinfo:
        _convert_message_to_dict(message)
    # Inspect the raised exception's own message rather than the string form
    # of pytest's ExceptionInfo wrapper, which is not the exception text.
    assert "Got unknown type" in str(excinfo.value)
def test__convert_message_to_dict_function() -> None:
    """Converting a FunctionMessage is expected to raise TypeError."""
    message = FunctionMessage(name="foo", content="bar")
    with pytest.raises(TypeError) as excinfo:
        _convert_message_to_dict(message)
    # Inspect the raised exception's own message rather than the string form
    # of pytest's ExceptionInfo wrapper, which is not the exception text.
    assert "Got unknown type" in str(excinfo.value)
def test__convert_dict_to_message_human() -> None:
    """A dict with the "user" role converts to a HumanMessage."""
    converted = _convert_dict_to_message({"role": "user", "content": "foo"})
    assert converted == HumanMessage(content="foo")
def test__convert_dict_to_message_ai() -> None:
    """A dict with the "assistant" role converts to an AIMessage."""
    converted = _convert_dict_to_message({"role": "assistant", "content": "foo"})
    assert converted == AIMessage(content="foo")
def test__convert_dict_to_message_other_role() -> None:
    """A dict with the "system" role converts to a ChatMessage keeping that role."""
    converted = _convert_dict_to_message({"role": "system", "content": "foo"})
    assert converted == ChatMessage(role="system", content="foo")
def test__convert_delta_to_message_assistant() -> None:
    """An "assistant" delta converts to an AIMessageChunk with the same content."""
    chunk = _convert_delta_to_message_chunk(
        {"role": "assistant", "content": "foo"},
        AIMessageChunk,
    )
    assert chunk == AIMessageChunk(content="foo")
def test__convert_delta_to_message_human() -> None:
    """A "user" delta converts to a HumanMessageChunk with the same content."""
    chunk = _convert_delta_to_message_chunk(
        {"role": "user", "content": "foo"},
        HumanMessageChunk,
    )
    assert chunk == HumanMessageChunk(content="foo")
def test__signature() -> None:
    """_signature returns the known digest for a fixed key, payload and timestamp."""
    request_payload = {
        "model": "Baichuan2-53B",
        "messages": [{"role": "user", "content": "Hi"}],
    }
    digest = _signature(
        secret_key=SecretStr("YOUR_SECRET_KEY"),
        payload=request_payload,
        timestamp=1697734335,
    )
    # The signature was generated by the demo provided by Baichuan.
    # https://platform.baichuan-ai.com/docs/api#4
    assert digest == "24a50b2db1648e25a244c67c5ab57d3f"
Loading…
Cancel
Save