From 3fb5e4d18544bbde2da91317862395120aa7a3e6 Mon Sep 17 00:00:00 2001 From: John Mai Date: Tue, 17 Oct 2023 13:30:57 -0500 Subject: [PATCH] Add Baichuan chat model (#11923) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Description: A large language model developed by Baichuan Intelligent Technology, https://www.baichuan-ai.com/home Issue: None Dependencies: None Tag maintainer: Twitter handle: --- docs/docs/integrations/chat/baichuan.ipynb | 157 ++++++++++ .../langchain/chat_models/__init__.py | 2 + .../langchain/chat_models/baichuan.py | 274 ++++++++++++++++++ 3 files changed, 433 insertions(+) create mode 100644 docs/docs/integrations/chat/baichuan.ipynb create mode 100644 libs/langchain/langchain/chat_models/baichuan.py diff --git a/docs/docs/integrations/chat/baichuan.ipynb b/docs/docs/integrations/chat/baichuan.ipynb new file mode 100644 index 0000000000..9718734cb3 --- /dev/null +++ b/docs/docs/integrations/chat/baichuan.ipynb @@ -0,0 +1,157 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Baichuan Chat\n", + "\n", + "Baichuan chat models API by Baichuan Intelligent Technology. 
For more information, see [https://platform.baichuan-ai.com/docs/api](https://platform.baichuan-ai.com/docs/api)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2023-10-17T15:14:24.186131Z", + "start_time": "2023-10-17T15:14:23.831767Z" + } + }, + "outputs": [], + "source": [ + "from langchain.chat_models import ChatBaichuan\n", + "from langchain.schema import HumanMessage" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2023-10-17T15:14:24.191123Z", + "start_time": "2023-10-17T15:14:24.186330Z" + } + }, + "outputs": [], + "source": [ + "chat = ChatBaichuan(\n", + " baichuan_api_key='YOUR_API_KEY',\n", + " baichuan_secret_key='YOUR_SECRET_KEY'\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "or you can set `api_key` and `secret_key` in your environment variables\n", + "```bash\n", + "export BAICHUAN_API_KEY=YOUR_API_KEY\n", + "export BAICHUAN_SECRET_KEY=YOUR_SECRET_KEY\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2023-10-17T15:14:25.853218Z", + "start_time": "2023-10-17T15:14:24.192408Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "AIMessage(content='首先,我们需要确定闰年的二月有多少天。闰年的二月有29天。\\n\\n然后,我们可以计算你的月薪:\\n\\n日薪 = 月薪 / (当月天数)\\n\\n所以,你的月薪 = 日薪 * 当月天数\\n\\n将数值代入公式:\\n\\n月薪 = 8元/天 * 29天 = 232元\\n\\n因此,你在闰年的二月的月薪是232元。')" + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat([\n", + " HumanMessage(content='我日薪8块钱,请问在闰年的二月,我月薪多少')\n", + "])" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## For ChatBaichuan with Streaming" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [], + "source": [ + "chat = ChatBaichuan(\n", + " baichuan_api_key='YOUR_API_KEY',\n", + " 
baichuan_secret_key='YOUR_SECRET_KEY',\n", + " streaming=True\n", + ")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-17T15:14:25.870044Z", + "start_time": "2023-10-17T15:14:25.863381Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "data": { + "text/plain": "AIMessageChunk(content='首先,我们需要确定闰年的二月有多少天。闰年的二月有29天。\\n\\n然后,我们可以计算你的月薪:\\n\\n日薪 = 月薪 / (当月天数)\\n\\n所以,你的月薪 = 日薪 * 当月天数\\n\\n将数值代入公式:\\n\\n月薪 = 8元/天 * 29天 = 232元\\n\\n因此,你在闰年的二月的月薪是232元。')" + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat([\n", + " HumanMessage(content='我日薪8块钱,请问在闰年的二月,我月薪多少')\n", + "])" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-17T15:14:27.153546Z", + "start_time": "2023-10-17T15:14:25.868470Z" + } + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py index 885243e054..03a42bda26 100644 --- a/libs/langchain/langchain/chat_models/__init__.py +++ b/libs/langchain/langchain/chat_models/__init__.py @@ -20,6 +20,7 @@ an interface where "chat messages" are the inputs and outputs. 
from langchain.chat_models.anthropic import ChatAnthropic from langchain.chat_models.anyscale import ChatAnyscale from langchain.chat_models.azure_openai import AzureChatOpenAI +from langchain.chat_models.baichuan import ChatBaichuan from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint from langchain.chat_models.bedrock import BedrockChat from langchain.chat_models.cohere import ChatCohere @@ -65,4 +66,5 @@ __all__ = [ "QianfanChatEndpoint", "ChatFireworks", "ChatYandexGPT", + "ChatBaichuan", ] diff --git a/libs/langchain/langchain/chat_models/baichuan.py b/libs/langchain/langchain/chat_models/baichuan.py new file mode 100644 index 0000000000..39b14d2d11 --- /dev/null +++ b/libs/langchain/langchain/chat_models/baichuan.py @@ -0,0 +1,274 @@ +import hashlib +import json +import logging +import time +from typing import Any, Dict, Iterator, List, Mapping, Optional, Type + +import requests + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.chat_models.base import BaseChatModel, _generate_from_stream +from langchain.pydantic_v1 import Field, root_validator +from langchain.schema import ( + AIMessage, + BaseMessage, + ChatGeneration, + ChatMessage, + ChatResult, + HumanMessage, +) +from langchain.schema.messages import ( + AIMessageChunk, + BaseMessageChunk, + ChatMessageChunk, + HumanMessageChunk, +) +from langchain.schema.output import ChatGenerationChunk +from langchain.utils import get_from_dict_or_env, get_pydantic_field_names + +logger = logging.getLogger(__name__) + + +def convert_message_to_dict(message: BaseMessage) -> dict: + message_dict: Dict[str, Any] + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + else: + raise TypeError(f"Got unknown 
type {message}") + + return message_dict + + +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + return AIMessage(content=_dict.get("content", "") or "") + else: + return ChatMessage(content=_dict["content"], role=role) + + +def _convert_delta_to_message_chunk( + _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk] +) -> BaseMessageChunk: + role = _dict.get("role") + content = _dict.get("content") or "" + + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content) + elif role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk(content=content) + elif role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role) + else: + return default_class(content=content) + + +class ChatBaichuan(BaseChatModel): + """Baichuan chat models API by Baichuan Intelligent Technology. 
+ + For more information, see https://platform.baichuan-ai.com/docs/api + """ + + @property + def lc_secrets(self) -> Dict[str, str]: + return { + "baichuan_api_key": "BAICHUAN_API_KEY", + "baichuan_secret_key": "BAICHUAN_SECRET_KEY", + } + + @property + def lc_serializable(self) -> bool: + return True + + baichuan_api_base: str = "https://api.baichuan-ai.com" + """Baichuan custom endpoints""" + baichuan_api_key: Optional[str] = None + """Baichuan API Key""" + baichuan_secret_key: Optional[str] = None + """Baichuan Secret Key""" + streaming: Optional[bool] = False + """streaming mode.""" + request_timeout: Optional[int] = 60 + """request timeout for chat http requests""" + + model = "Baichuan2-53B" + """model name of Baichuan, default is `Baichuan2-53B`.""" + temperature: float = 0.3 + top_k: int = 5 + top_p: float = 0.85 + with_search_enhance: bool = False + """Whether to use search enhance, default is False.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + + @root_validator(pre=True) + def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = get_pydantic_field_names(cls) + extra = values.get("model_kwargs", {}) + for field_name in list(values): + if field_name in extra: + raise ValueError(f"Found {field_name} supplied twice.") + if field_name not in all_required_field_names: + logger.warning( + f"""WARNING! {field_name} is not default parameter. + {field_name} was transferred to model_kwargs. + Please confirm that {field_name} is what you intended.""" + ) + extra[field_name] = values.pop(field_name) + + invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) + if invalid_model_kwargs: + raise ValueError( + f"Parameters {invalid_model_kwargs} should be specified explicitly. 
" + f"Instead they were passed in as part of `model_kwargs` parameter." + ) + + values["model_kwargs"] = extra + return values + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + values["baichuan_api_base"] = get_from_dict_or_env( + values, + "baichuan_api_base", + "BAICHUAN_API_BASE", + ) + values["baichuan_api_key"] = get_from_dict_or_env( + values, + "baichuan_api_key", + "BAICHUAN_API_KEY", + ) + values["baichuan_secret_key"] = get_from_dict_or_env( + values, + "baichuan_secret_key", + "BAICHUAN_SECRET_KEY", + ) + + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling Baichuan API.""" + normal_params = { + "model": self.model, + "top_p": self.top_p, + "top_k": self.top_k, + "with_search_enhance": self.with_search_enhance, + } + + return {**normal_params, **self.model_kwargs} + + def _signature(self, data: Dict[str, Any], timestamp: int) -> str: + if self.baichuan_secret_key is None: + raise ValueError("Baichuan secret key is not set.") + + input_str = self.baichuan_secret_key + json.dumps(data) + str(timestamp) + md5 = hashlib.md5() + md5.update(input_str.encode("utf-8")) + return md5.hexdigest() + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + if self.streaming: + stream_iter = self._stream( + messages=messages, stop=stop, run_manager=run_manager, **kwargs + ) + return _generate_from_stream(stream_iter) + + res = self._chat(messages, **kwargs) + + response = res.json() + + if response.get("code") != 0: + raise ValueError(f"Error from Baichuan api response: {response}") + + return self._create_chat_result(response) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + res = 
self._chat(messages, **kwargs) + + default_chunk_class = AIMessageChunk + for chunk in res.iter_lines(): + response = json.loads(chunk) + if response.get("code") != 0: + raise ValueError(f"Error from Baichuan api response: {response}") + + data = response.get("data") + for m in data.get("messages"): + chunk = _convert_delta_to_message_chunk(m, default_chunk_class) + default_chunk_class = chunk.__class__ + yield ChatGenerationChunk(message=chunk) + if run_manager: + run_manager.on_llm_new_token(chunk.content) + + def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response: + parameters = {**self._default_params, **kwargs} + + model = parameters.pop("model") + headers = parameters.pop("headers", {}) + + payload = { + "model": model, + "messages": [convert_message_to_dict(m) for m in messages], + "parameters": parameters, + } + + timestamp = int(time.time()) + + url = f"{self.baichuan_api_base}/v1" + if self.streaming: + url = f"{url}/stream" + url = f"{url}/chat" + + res = requests.post( + url=url, + timeout=self.request_timeout, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {self.baichuan_api_key}", + "X-BC-Timestamp": str(timestamp), + "X-BC-Signature": self._signature(payload, timestamp), + "X-BC-Sign-Algo": "MD5", + **headers, + }, + json=payload, + stream=self.streaming, + ) + return res + + def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + generations = [] + for m in response["data"]["messages"]: + message = _convert_dict_to_message(m) + gen = ChatGeneration(message=message) + generations.append(gen) + + token_usage = response["usage"] + llm_output = {"token_usage": token_usage, "model": self.model} + return ChatResult(generations=generations, llm_output=llm_output) + + @property + def _llm_type(self) -> str: + return "baichuan-chat"