feat: support ChatModels Qianfan QianfanChatEndpoint function_call (#11107)

- **Description:** 
    * add function_call support to `QianfanChatEndpoint`, with an integration test for it
    * support `model` and `endpoint` in the calling params
    * surface the raw response body in the returned chat message's `additional_kwargs` and `generation_info`
- **Issue:** 
    * #10867 
    * #11105 
    * #10215
- **Dependencies:** none
- **Tag maintainer:** @baskaryan
- **Twitter handle:** none
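
A minimal usage sketch of the new function_call ability (assumes Qianfan credentials, e.g. `QIANFAN_AK`/`QIANFAN_SK`, are set in the environment; the function schema and question are illustrative, not part of this PR):

```python
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.schema import HumanMessage

functions = [
    {
        "name": "get_current_temperature",
        "description": "Used to get the location's temperature.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "city name"},
            },
            "required": ["location"],
        },
    }
]

chat = QianfanChatEndpoint(model="ERNIE-Bot")
# Extra call kwargs such as `functions` are forwarded to the Qianfan client.
response = chat(
    messages=[HumanMessage(content="What's the temperature in Shanghai today?")],
    functions=functions,
)
# With this PR, a function call is surfaced in additional_kwargs.
print(response.additional_kwargs.get("function_call"))
```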
DanielZzz 2023-10-18 04:33:55 +08:00 committed by GitHub
parent 67300567d3
commit b647505280
3 changed files with 190 additions and 15 deletions

View File: langchain/chat_models/baidu_qianfan_endpoint.py

@@ -22,7 +22,6 @@ from langchain.schema.messages import (
     AIMessage,
     AIMessageChunk,
     BaseMessage,
-    BaseMessageChunk,
     ChatMessage,
     FunctionMessage,
     HumanMessage,
@@ -34,13 +33,6 @@ from langchain.utils import get_from_dict_or_env

 logger = logging.getLogger(__name__)

-def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk:
-    return AIMessageChunk(
-        content=resp["result"],
-        role="assistant",
-    )

 def convert_message_to_dict(message: BaseMessage) -> dict:
     """Convert a message to a dictionary that can be passed to the API."""
     message_dict: Dict[str, Any]
@@ -51,7 +43,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
     elif isinstance(message, AIMessage):
         message_dict = {"role": "assistant", "content": message.content}
         if "function_call" in message.additional_kwargs:
-            message_dict["functions"] = message.additional_kwargs["function_call"]
+            message_dict["function_call"] = message.additional_kwargs["function_call"]
             # If function call only, content is None not empty string
             if message_dict["content"] == "":
                 message_dict["content"] = None
@@ -67,6 +59,21 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
     return message_dict

+def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
+    content = _dict.get("result", "") or ""
+    if _dict.get("function_call"):
+        additional_kwargs = {"function_call": dict(_dict["function_call"])}
+        if "thoughts" in additional_kwargs["function_call"]:
+            # drop "thoughts" to align with the API sample; it otherwise
+            # leaks into the llm function_call output
+            additional_kwargs["function_call"].pop("thoughts")
+    else:
+        additional_kwargs = {}
+    return AIMessage(
+        content=content,
+        additional_kwargs={**_dict.get("body", {}), **additional_kwargs},
+    )

 class QianfanChatEndpoint(BaseChatModel):
     """Baidu Qianfan chat models.
@@ -164,6 +171,8 @@ class QianfanChatEndpoint(BaseChatModel):
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling the Qianfan API."""
         normal_params = {
+            "model": self.model,
+            "endpoint": self.endpoint,
             "stream": self.streaming,
             "request_timeout": self.request_timeout,
             "top_p": self.top_p,
@@ -243,10 +252,13 @@ class QianfanChatEndpoint(BaseChatModel):
         )
         params = self._convert_prompt_msg_params(messages, **kwargs)
         response_payload = self.client.do(**params)
-        lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={})
+        lc_msg = _convert_dict_to_message(response_payload)
         gen = ChatGeneration(
             message=lc_msg,
-            generation_info=dict(finish_reason="stop"),
+            generation_info={
+                "finish_reason": "stop",
+                **response_payload.get("body", {}),
+            },
         )
         token_usage = response_payload.get("usage", {})
         llm_output = {"token_usage": token_usage, "model_name": self.model}
@@ -276,11 +288,14 @@ class QianfanChatEndpoint(BaseChatModel):
         )
         params = self._convert_prompt_msg_params(messages, **kwargs)
         response_payload = await self.client.ado(**params)
-        lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={})
+        lc_msg = _convert_dict_to_message(response_payload)
         generations = []
         gen = ChatGeneration(
             message=lc_msg,
-            generation_info=dict(finish_reason="stop"),
+            generation_info={
+                "finish_reason": "stop",
+                **response_payload.get("body", {}),
+            },
         )
         generations.append(gen)
         token_usage = response_payload.get("usage", {})
@@ -297,9 +312,14 @@ class QianfanChatEndpoint(BaseChatModel):
         params = self._convert_prompt_msg_params(messages, **kwargs)
         for res in self.client.do(**params):
             if res:
+                msg = _convert_dict_to_message(res)
                 chunk = ChatGenerationChunk(
                     text=res["result"],
-                    message=_convert_resp_to_message_chunk(res),
+                    message=AIMessageChunk(
+                        content=msg.content,
+                        role="assistant",
+                        additional_kwargs=msg.additional_kwargs,
+                    ),
                 )
                 yield chunk
                 if run_manager:
@@ -315,9 +335,14 @@ class QianfanChatEndpoint(BaseChatModel):
         params = self._convert_prompt_msg_params(messages, **kwargs)
         async for res in await self.client.ado(**params):
             if res:
+                msg = _convert_dict_to_message(res)
                 chunk = ChatGenerationChunk(
                     text=res["result"],
-                    message=_convert_resp_to_message_chunk(res),
+                    message=AIMessageChunk(
+                        content=msg.content,
+                        role="assistant",
+                        additional_kwargs=msg.additional_kwargs,
+                    ),
                 )
                 yield chunk
                 if run_manager:
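
Since streamed chunks now go through `_convert_dict_to_message`, every `AIMessageChunk` keeps the raw body (and any `function_call`) in `additional_kwargs`. A minimal streaming sketch, assuming credentials are configured:

```python
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.schema import HumanMessage

chat = QianfanChatEndpoint(streaming=True)
for chunk in chat.stream([HumanMessage(content="Hello")]):
    # content is the incremental text; additional_kwargs carries the raw body
    print(chunk.content, chunk.additional_kwargs)
```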

View File: langchain/llms/baidu_qianfan_endpoint.py

@@ -118,6 +118,8 @@ class QianfanLLMEndpoint(LLM):
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling the Qianfan API."""
         normal_params = {
+            "model": self.model,
+            "endpoint": self.endpoint,
             "stream": self.streaming,
             "request_timeout": self.request_timeout,
             "top_p": self.top_p,

View File: integration tests for the Qianfan chat endpoint

@@ -1,16 +1,87 @@
"""Test Baidu Qianfan Chat Endpoint."""
from typing import Any

from langchain.callbacks.manager import CallbackManager
from langchain.chains.openai_functions import (
    create_openai_fn_chain,
)
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    FunctionMessage,
    HumanMessage,
    LLMResult,
)

from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
_FUNCTIONS: Any = [
    {
        "name": "format_person_info",
        "description": (
            "Output formatter. Should always be used to format your response to the"
            " user."
        ),
        "parameters": {
            "title": "Person",
            "description": "Identifying information about a person.",
            "type": "object",
            "properties": {
                "name": {
                    "title": "Name",
                    "description": "The person's name",
                    "type": "string",
                },
                "age": {
                    "title": "Age",
                    "description": "The person's age",
                    "type": "integer",
                },
                "fav_food": {
                    "title": "Fav Food",
                    "description": "The person's favorite food",
                    "type": "string",
                },
            },
            "required": ["name", "age"],
        },
    },
    {
        "name": "get_current_temperature",
        "description": "Used to get the location's temperature.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "city name",
                },
                "unit": {
                    "type": "string",
                    "enum": ["centigrade", "Fahrenheit"],
                },
            },
            "required": ["location", "unit"],
        },
        "responses": {
            "type": "object",
            "properties": {
                "temperature": {
                    "type": "integer",
                    "description": "city temperature",
                },
                "unit": {
                    "type": "string",
                    "enum": ["centigrade", "Fahrenheit"],
                },
            },
        },
    },
]
def test_default_call() -> None:
    """Test a call with the default model (`ERNIE-Bot`)."""
@@ -28,6 +99,14 @@ def test_model() -> None:
    assert isinstance(response.content, str)

def test_model_param() -> None:
    """Test that passing `model` as a call param works."""
    chat = QianfanChatEndpoint()
    response = chat(model="BLOOMZ-7B", messages=[HumanMessage(content="Hello")])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)
def test_endpoint() -> None:
    """Test user-deployed custom endpoints, e.g. for open-source models."""
    chat = QianfanChatEndpoint(endpoint="qianfan_bloomz_7b_compressed")
@@ -36,6 +115,18 @@ def test_endpoint() -> None:
    assert isinstance(response.content, str)
def test_endpoint_param() -> None:
    """Test that passing `endpoint` as a call param works for user deployments."""
    chat = QianfanChatEndpoint()
    response = chat(
        messages=[
            HumanMessage(endpoint="qianfan_bloomz_7b_compressed", content="Hello")
        ]
    )
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)
def test_multiple_history() -> None:
    """Test that multiple-turn chat history works."""
    chat = QianfanChatEndpoint()
@@ -83,3 +174,60 @@ def test_multiple_messages() -> None:
    assert isinstance(generation, ChatGeneration)
    assert isinstance(generation.text, str)
    assert generation.text == generation.message.content
def test_functions_call_thoughts() -> None:
    chat = QianfanChatEndpoint(model="ERNIE-Bot")

    prompt_tmpl = "Use the given functions to answer the following question: {input}"
    prompt_msgs = [
        HumanMessagePromptTemplate.from_template(prompt_tmpl),
    ]
    prompt = ChatPromptTemplate(messages=prompt_msgs)

    chain = create_openai_fn_chain(
        _FUNCTIONS,
        chat,
        prompt,
        output_parser=None,
    )

    message = HumanMessage(content="What's the temperature in Shanghai today?")
    response = chain.generate([{"input": message}])
    assert isinstance(response.generations[0][0], ChatGeneration)
    assert isinstance(response.generations[0][0].message, AIMessage)
    assert "function_call" in response.generations[0][0].message.additional_kwargs
def test_functions_call() -> None:
    chat = QianfanChatEndpoint(model="ERNIE-Bot")

    prompt = ChatPromptTemplate(
        messages=[
            HumanMessage(content="What's the temperature in Shanghai today?"),
            AIMessage(
                content="",
                additional_kwargs={
                    "function_call": {
                        "name": "get_current_temperature",
                        "thoughts": "I will use get_current_temperature "
                        "to resolve the question",
                        "arguments": '{"location":"Shanghai","unit":"centigrade"}',
                    }
                },
            ),
            FunctionMessage(
                name="get_current_temperature",
                content='{"temperature": "25", "unit": "centigrade", "description": "sunny"}',
            ),
        ]
    )
    llm_chain = create_openai_fn_chain(
        _FUNCTIONS,
        chat,
        prompt,
        output_parser=None,
    )
    resp = llm_chain.generate([{}])
    assert isinstance(resp, LLMResult)