mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
community: Improve QianfanChatEndpoint tool result to model (#24466)
- **Description:** When `QianfanChatEndpoint` uses a tool result to answer a question, the tool's content is required to be in Dict format. We could instead require users to return a Dict from the tool itself, but to stay consistent with the other Chat Models, this change wraps non-Dict tool output automatically.
This commit is contained in:
parent
02f0a29293
commit
721f709dec
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from operator import itemgetter
|
||||
@ -65,7 +66,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
elif isinstance(message, (FunctionMessage, ToolMessage)):
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"content": _create_tool_content(message.content),
|
||||
"name": message.name or message.additional_kwargs.get("name"),
|
||||
}
|
||||
else:
|
||||
@ -74,6 +75,20 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
return message_dict
|
||||
|
||||
|
||||
def _create_tool_content(content: Union[str, List[Union[str, Dict[Any, Any]]]]) -> str:
|
||||
"""Convert tool content to dict scheme."""
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
if isinstance(json.loads(content), dict):
|
||||
return content
|
||||
else:
|
||||
return json.dumps({"tool_result": content})
|
||||
except json.JSONDecodeError:
|
||||
return json.dumps({"tool_result": content})
|
||||
else:
|
||||
return json.dumps({"tool_result": content})
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
|
||||
content = _dict.get("result", "") or ""
|
||||
additional_kwargs: Mapping[str, Any] = {}
|
||||
|
@ -0,0 +1,39 @@
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from langchain_community.chat_models import QianfanChatEndpoint
|
||||
|
||||
|
||||
# Toy weather tool used by the integration test below.
# NOTE(review): the Chinese docstring doubles as the tool description that is
# sent to the model by the @tool decorator, so it is deliberately left
# untranslated — changing it would change runtime behavior.
@tool
def get_current_weather(location: str, unit: str = "摄氏度") -> str:
    """获取指定地点的天气"""
    # Canned answer ("<location> is sunny, around 25 <unit>") — no real
    # weather lookup is performed.
    return f"{location}是晴朗,25{unit}左右。"
|
||||
|
||||
|
||||
def test_chat_qianfan_tool_result_to_model() -> None:
    """Test QianfanChatEndpoint invoke with tool_calling result."""
    # Transcript: user question -> assistant tool call -> tool reply.
    weather_call = ToolCall(
        name="get_current_weather",
        args={"location": "上海", "unit": "摄氏度"},
        id="foo",
        type="tool_call",
    )
    history = [
        HumanMessage("上海天气怎么样?"),
        AIMessage(content=" ", tool_calls=[weather_call]),
        ToolMessage(
            content="上海是晴天,25度左右。",
            tool_call_id="foo",
            name="get_current_weather",
        ),
    ]
    model = QianfanChatEndpoint(model="ERNIE-3.5-8K")  # type: ignore[call-arg]
    response = model.bind_tools([get_current_weather]).invoke(history)
    assert isinstance(response, AIMessage)
    print(response.content)  # noqa: T201
|
Loading…
Reference in New Issue
Block a user