diff --git a/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py b/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py index df035464e5..e9526f0634 100644 --- a/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py +++ b/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py @@ -22,7 +22,6 @@ from langchain.schema.messages import ( AIMessage, AIMessageChunk, BaseMessage, - BaseMessageChunk, ChatMessage, FunctionMessage, HumanMessage, @@ -34,13 +33,6 @@ from langchain.utils import get_from_dict_or_env logger = logging.getLogger(__name__) -def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk: - return AIMessageChunk( - content=resp["result"], - role="assistant", - ) - - def convert_message_to_dict(message: BaseMessage) -> dict: """Convert a message to a dictionary that can be passed to the API.""" message_dict: Dict[str, Any] @@ -51,7 +43,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict: elif isinstance(message, AIMessage): message_dict = {"role": "assistant", "content": message.content} if "function_call" in message.additional_kwargs: - message_dict["functions"] = message.additional_kwargs["function_call"] + message_dict["function_call"] = message.additional_kwargs["function_call"] # If function call only, content is None not empty string if message_dict["content"] == "": message_dict["content"] = None @@ -67,6 +59,21 @@ def convert_message_to_dict(message: BaseMessage) -> dict: return message_dict +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage: + content = _dict.get("result", "") or "" + if _dict.get("function_call"): + additional_kwargs = {"function_call": dict(_dict["function_call"])} + if "thoughts" in additional_kwargs["function_call"]: + # align to api sample, which affects the llm function_call output + additional_kwargs["function_call"].pop("thoughts") + else: + additional_kwargs = {} + return AIMessage( + content=content, + additional_kwargs={**_dict.get("body", {}), **additional_kwargs}, + ) + + class QianfanChatEndpoint(BaseChatModel): """Baidu Qianfan chat models. @@ -164,6 +171,8 @@ class QianfanChatEndpoint(BaseChatModel): def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling OpenAI API.""" normal_params = { + "model": self.model, + "endpoint": self.endpoint, "stream": self.streaming, "request_timeout": self.request_timeout, "top_p": self.top_p, @@ -243,10 +252,13 @@ class QianfanChatEndpoint(BaseChatModel): ) params = self._convert_prompt_msg_params(messages, **kwargs) response_payload = self.client.do(**params) - lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={}) + lc_msg = _convert_dict_to_message(response_payload) gen = ChatGeneration( message=lc_msg, - generation_info=dict(finish_reason="stop"), + generation_info={ + "finish_reason": "stop", + **response_payload.get("body", {}), + }, ) token_usage = response_payload.get("usage", {}) llm_output = {"token_usage": token_usage, "model_name": self.model} @@ -276,11 +288,14 @@ class QianfanChatEndpoint(BaseChatModel): ) params = self._convert_prompt_msg_params(messages, **kwargs) response_payload = await self.client.ado(**params) - lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={}) + lc_msg = _convert_dict_to_message(response_payload) generations = [] gen = ChatGeneration( message=lc_msg, - generation_info=dict(finish_reason="stop"), + generation_info={ + "finish_reason": "stop", + **response_payload.get("body", {}), + }, ) generations.append(gen) token_usage = response_payload.get("usage", {}) @@ -297,9 +312,14 @@ class QianfanChatEndpoint(BaseChatModel): params = self._convert_prompt_msg_params(messages, **kwargs) for res in self.client.do(**params): if res: + msg = _convert_dict_to_message(res) chunk = ChatGenerationChunk( text=res["result"], - message=_convert_resp_to_message_chunk(res), + message=AIMessageChunk( + content=msg.content, + role="assistant", + additional_kwargs=msg.additional_kwargs, + ), ) yield chunk if run_manager: @@ -315,9 +335,14 @@ class QianfanChatEndpoint(BaseChatModel): params = self._convert_prompt_msg_params(messages, **kwargs) async for res in await self.client.ado(**params): if res: + msg = _convert_dict_to_message(res) chunk = ChatGenerationChunk( text=res["result"], - message=_convert_resp_to_message_chunk(res), + message=AIMessageChunk( + content=msg.content, + role="assistant", + additional_kwargs=msg.additional_kwargs, + ), ) yield chunk if run_manager: diff --git a/libs/langchain/langchain/llms/baidu_qianfan_endpoint.py b/libs/langchain/langchain/llms/baidu_qianfan_endpoint.py index 8b548a7cd9..5546170af4 100644 --- a/libs/langchain/langchain/llms/baidu_qianfan_endpoint.py +++ b/libs/langchain/langchain/llms/baidu_qianfan_endpoint.py @@ -118,6 +118,8 @@ class QianfanLLMEndpoint(LLM): def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling OpenAI API.""" normal_params = { + "model": self.model, + "endpoint": self.endpoint, "stream": self.streaming, "request_timeout": self.request_timeout, "top_p": self.top_p, diff --git a/libs/langchain/tests/integration_tests/chat_models/test_qianfan_endpoint.py b/libs/langchain/tests/integration_tests/chat_models/test_qianfan_endpoint.py index 41300688bc..3c82ed10ef 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_qianfan_endpoint.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_qianfan_endpoint.py @@ -1,16 +1,87 @@ """Test Baidu Qianfan Chat Endpoint.""" +from typing import Any + from langchain.callbacks.manager import CallbackManager +from langchain.chains.openai_functions import ( + create_openai_fn_chain, +) from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint +from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate from langchain.schema import ( AIMessage, BaseMessage, ChatGeneration, + FunctionMessage, HumanMessage, LLMResult, ) from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +_FUNCTIONS: Any = [ + { + "name": "format_person_info", + "description": ( + "Output formatter. Should always be used to format your response to the" + " user." + ), + "parameters": { + "title": "Person", + "description": "Identifying information about a person.", + "type": "object", + "properties": { + "name": { + "title": "Name", + "description": "The person's name", + "type": "string", + }, + "age": { + "title": "Age", + "description": "The person's age", + "type": "integer", + }, + "fav_food": { + "title": "Fav Food", + "description": "The person's favorite food", + "type": "string", + }, + }, + "required": ["name", "age"], + }, + }, + { + "name": "get_current_temperature", + "description": ("Used to get the location's temperature."), + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "city name", + }, + "unit": { + "type": "string", + "enum": ["centigrade", "Fahrenheit"], + }, + }, + "required": ["location", "unit"], + }, + "responses": { + "type": "object", + "properties": { + "temperature": { + "type": "integer", + "description": "city temperature", + }, + "unit": { + "type": "string", + "enum": ["centigrade", "Fahrenheit"], + }, + }, + }, + }, +] + def test_default_call() -> None: """Test default model(`ERNIE-Bot`) call.""" @@ -28,6 +99,14 @@ def test_model() -> None: assert isinstance(response.content, str) +def test_model_param() -> None: + """Test model params works.""" + chat = QianfanChatEndpoint() + response = chat(model="BLOOMZ-7B", messages=[HumanMessage(content="Hello")]) + assert isinstance(response, BaseMessage) + assert isinstance(response.content, str) + + def test_endpoint() -> None: """Test user custom model deployments like some open source models.""" chat = QianfanChatEndpoint(endpoint="qianfan_bloomz_7b_compressed") @@ -36,6 +115,18 @@ def test_endpoint() -> None: assert isinstance(response.content, str) +def test_endpoint_param() -> None: + """Test user custom model deployments like some open source models.""" + chat = QianfanChatEndpoint() + response = chat( + messages=[ + HumanMessage(endpoint="qianfan_bloomz_7b_compressed", content="Hello") + ] + ) + assert isinstance(response, BaseMessage) + assert isinstance(response.content, str) + + def test_multiple_history() -> None: """Tests multiple history works.""" chat = QianfanChatEndpoint() @@ -83,3 +174,60 @@ def test_multiple_messages() -> None: assert isinstance(generation, ChatGeneration) assert isinstance(generation.text, str) assert generation.text == generation.message.content + + +def test_functions_call_thoughts() -> None: + chat = QianfanChatEndpoint(model="ERNIE-Bot") + + prompt_tmpl = "Use the given functions to answer following question: {input}" + prompt_msgs = [ + HumanMessagePromptTemplate.from_template(prompt_tmpl), + ] + prompt = ChatPromptTemplate(messages=prompt_msgs) + + chain = create_openai_fn_chain( + _FUNCTIONS, + chat, + prompt, + output_parser=None, + ) + + message = HumanMessage(content="What's the temperature in Shanghai today?") + response = chain.generate([{"input": message}]) + assert isinstance(response.generations[0][0], ChatGeneration) + assert isinstance(response.generations[0][0].message, AIMessage) + assert "function_call" in response.generations[0][0].message.additional_kwargs + + +def test_functions_call() -> None: + chat = QianfanChatEndpoint(model="ERNIE-Bot") + + prompt = ChatPromptTemplate( + messages=[ + HumanMessage(content="What's the temperature in Shanghai today?"), + AIMessage( + content="", + additional_kwargs={ + "function_call": { + "name": "get_current_temperature", + "thoughts": "i will use get_current_temperature " + "to resolve the questions", + "arguments": '{"location":"Shanghai","unit":"centigrade"}', + } + }, + ), + FunctionMessage( + name="get_current_weather", + content='{"temperature": "25", \ + "unit": "摄氏度", "description": "晴朗"}', + ), + ] + ) + llm_chain = create_openai_fn_chain( + _FUNCTIONS, + chat, + prompt, + output_parser=None, + ) + resp = llm_chain.generate([{}]) + assert isinstance(resp, LLMResult)