mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
feat: support ChatModels Qianfan QianfanChatEndpoint
function_call (#11107)
- **Description:** * feature for `QianfanChatEndpoint` function_call ability, add integration_test for it * add `model`, `endpoint` supported in calling params * add raw response in ChatModel Message - **Issue:** * #10867 * #11105 * #10215 - **Dependencies:** no - **Tag maintainer:** @baskaryan - **Twitter handle:** no
This commit is contained in:
parent
67300567d3
commit
b647505280
@ -22,7 +22,6 @@ from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
@ -34,13 +33,6 @@ from langchain.utils import get_from_dict_or_env
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk:
|
||||
return AIMessageChunk(
|
||||
content=resp["result"],
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
||||
def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
"""Convert a message to a dictionary that can be passed to the API."""
|
||||
message_dict: Dict[str, Any]
|
||||
@ -51,7 +43,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["functions"] = message.additional_kwargs["function_call"]
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
# If function call only, content is None not empty string
|
||||
if message_dict["content"] == "":
|
||||
message_dict["content"] = None
|
||||
@ -67,6 +59,21 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
return message_dict
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
|
||||
content = _dict.get("result", "") or ""
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs = {"function_call": dict(_dict["function_call"])}
|
||||
if "thoughts" in additional_kwargs["function_call"]:
|
||||
# align to api sample, which affects the llm function_call output
|
||||
additional_kwargs["function_call"].pop("thoughts")
|
||||
else:
|
||||
additional_kwargs = {}
|
||||
return AIMessage(
|
||||
content=content,
|
||||
additional_kwargs={**_dict.get("body", {}), **additional_kwargs},
|
||||
)
|
||||
|
||||
|
||||
class QianfanChatEndpoint(BaseChatModel):
|
||||
"""Baidu Qianfan chat models.
|
||||
|
||||
@ -164,6 +171,8 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
normal_params = {
|
||||
"model": self.model,
|
||||
"endpoint": self.endpoint,
|
||||
"stream": self.streaming,
|
||||
"request_timeout": self.request_timeout,
|
||||
"top_p": self.top_p,
|
||||
@ -243,10 +252,13 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
)
|
||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||
response_payload = self.client.do(**params)
|
||||
lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={})
|
||||
lc_msg = _convert_dict_to_message(response_payload)
|
||||
gen = ChatGeneration(
|
||||
message=lc_msg,
|
||||
generation_info=dict(finish_reason="stop"),
|
||||
generation_info={
|
||||
"finish_reason": "stop",
|
||||
**response_payload.get("body", {}),
|
||||
},
|
||||
)
|
||||
token_usage = response_payload.get("usage", {})
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model}
|
||||
@ -276,11 +288,14 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
)
|
||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||
response_payload = await self.client.ado(**params)
|
||||
lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={})
|
||||
lc_msg = _convert_dict_to_message(response_payload)
|
||||
generations = []
|
||||
gen = ChatGeneration(
|
||||
message=lc_msg,
|
||||
generation_info=dict(finish_reason="stop"),
|
||||
generation_info={
|
||||
"finish_reason": "stop",
|
||||
**response_payload.get("body", {}),
|
||||
},
|
||||
)
|
||||
generations.append(gen)
|
||||
token_usage = response_payload.get("usage", {})
|
||||
@ -297,9 +312,14 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||
for res in self.client.do(**params):
|
||||
if res:
|
||||
msg = _convert_dict_to_message(res)
|
||||
chunk = ChatGenerationChunk(
|
||||
text=res["result"],
|
||||
message=_convert_resp_to_message_chunk(res),
|
||||
message=AIMessageChunk(
|
||||
content=msg.content,
|
||||
role="assistant",
|
||||
additional_kwargs=msg.additional_kwargs,
|
||||
),
|
||||
)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
@ -315,9 +335,14 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
params = self._convert_prompt_msg_params(messages, **kwargs)
|
||||
async for res in await self.client.ado(**params):
|
||||
if res:
|
||||
msg = _convert_dict_to_message(res)
|
||||
chunk = ChatGenerationChunk(
|
||||
text=res["result"],
|
||||
message=_convert_resp_to_message_chunk(res),
|
||||
message=AIMessageChunk(
|
||||
content=msg.content,
|
||||
role="assistant",
|
||||
additional_kwargs=msg.additional_kwargs,
|
||||
),
|
||||
)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
|
@ -118,6 +118,8 @@ class QianfanLLMEndpoint(LLM):
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
normal_params = {
|
||||
"model": self.model,
|
||||
"endpoint": self.endpoint,
|
||||
"stream": self.streaming,
|
||||
"request_timeout": self.request_timeout,
|
||||
"top_p": self.top_p,
|
||||
|
@ -1,16 +1,87 @@
|
||||
"""Test Baidu Qianfan Chat Endpoint."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.chains.openai_functions import (
|
||||
create_openai_fn_chain,
|
||||
)
|
||||
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
|
||||
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatGeneration,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
LLMResult,
|
||||
)
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
_FUNCTIONS: Any = [
|
||||
{
|
||||
"name": "format_person_info",
|
||||
"description": (
|
||||
"Output formatter. Should always be used to format your response to the"
|
||||
" user."
|
||||
),
|
||||
"parameters": {
|
||||
"title": "Person",
|
||||
"description": "Identifying information about a person.",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"title": "Name",
|
||||
"description": "The person's name",
|
||||
"type": "string",
|
||||
},
|
||||
"age": {
|
||||
"title": "Age",
|
||||
"description": "The person's age",
|
||||
"type": "integer",
|
||||
},
|
||||
"fav_food": {
|
||||
"title": "Fav Food",
|
||||
"description": "The person's favorite food",
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": ["name", "age"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "get_current_temperature",
|
||||
"description": ("Used to get the location's temperature."),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "city name",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["centigrade", "Fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["location", "unit"],
|
||||
},
|
||||
"responses": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"temperature": {
|
||||
"type": "integer",
|
||||
"description": "city temperature",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["centigrade", "Fahrenheit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_default_call() -> None:
|
||||
"""Test default model(`ERNIE-Bot`) call."""
|
||||
@ -28,6 +99,14 @@ def test_model() -> None:
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_model_param() -> None:
|
||||
"""Test model params works."""
|
||||
chat = QianfanChatEndpoint()
|
||||
response = chat(model="BLOOMZ-7B", messages=[HumanMessage(content="Hello")])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_endpoint() -> None:
|
||||
"""Test user custom model deployments like some open source models."""
|
||||
chat = QianfanChatEndpoint(endpoint="qianfan_bloomz_7b_compressed")
|
||||
@ -36,6 +115,18 @@ def test_endpoint() -> None:
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_endpoint_param() -> None:
|
||||
"""Test user custom model deployments like some open source models."""
|
||||
chat = QianfanChatEndpoint()
|
||||
response = chat(
|
||||
messages=[
|
||||
HumanMessage(endpoint="qianfan_bloomz_7b_compressed", content="Hello")
|
||||
]
|
||||
)
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_multiple_history() -> None:
|
||||
"""Tests multiple history works."""
|
||||
chat = QianfanChatEndpoint()
|
||||
@ -83,3 +174,60 @@ def test_multiple_messages() -> None:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
def test_functions_call_thoughts() -> None:
|
||||
chat = QianfanChatEndpoint(model="ERNIE-Bot")
|
||||
|
||||
prompt_tmpl = "Use the given functions to answer following question: {input}"
|
||||
prompt_msgs = [
|
||||
HumanMessagePromptTemplate.from_template(prompt_tmpl),
|
||||
]
|
||||
prompt = ChatPromptTemplate(messages=prompt_msgs)
|
||||
|
||||
chain = create_openai_fn_chain(
|
||||
_FUNCTIONS,
|
||||
chat,
|
||||
prompt,
|
||||
output_parser=None,
|
||||
)
|
||||
|
||||
message = HumanMessage(content="What's the temperature in Shanghai today?")
|
||||
response = chain.generate([{"input": message}])
|
||||
assert isinstance(response.generations[0][0], ChatGeneration)
|
||||
assert isinstance(response.generations[0][0].message, AIMessage)
|
||||
assert "function_call" in response.generations[0][0].message.additional_kwargs
|
||||
|
||||
|
||||
def test_functions_call() -> None:
|
||||
chat = QianfanChatEndpoint(model="ERNIE-Bot")
|
||||
|
||||
prompt = ChatPromptTemplate(
|
||||
messages=[
|
||||
HumanMessage(content="What's the temperature in Shanghai today?"),
|
||||
AIMessage(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
"name": "get_current_temperature",
|
||||
"thoughts": "i will use get_current_temperature "
|
||||
"to resolve the questions",
|
||||
"arguments": '{"location":"Shanghai","unit":"centigrade"}',
|
||||
}
|
||||
},
|
||||
),
|
||||
FunctionMessage(
|
||||
name="get_current_weather",
|
||||
content='{"temperature": "25", \
|
||||
"unit": "摄氏度", "description": "晴朗"}',
|
||||
),
|
||||
]
|
||||
)
|
||||
llm_chain = create_openai_fn_chain(
|
||||
_FUNCTIONS,
|
||||
chat,
|
||||
prompt,
|
||||
output_parser=None,
|
||||
)
|
||||
resp = llm_chain.generate([{}])
|
||||
assert isinstance(resp, LLMResult)
|
||||
|
Loading…
Reference in New Issue
Block a user