Mirror of https://github.com/hwchase17/langchain (synced 2024-11-02 09:40:22 +00:00)

openai[patch]: use tool_calls in request (#20272)

parent: e936fba428
commit: c706689413
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
+import json
 import logging
 import os
 import sys
@@ -50,8 +51,10 @@ from langchain_core.messages import (
     FunctionMessageChunk,
     HumanMessage,
     HumanMessageChunk,
+    InvalidToolCall,
     SystemMessage,
     SystemMessageChunk,
+    ToolCall,
     ToolMessage,
     ToolMessageChunk,
 )
@@ -169,20 +172,25 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
         message_dict["role"] = "assistant"
         if "function_call" in message.additional_kwargs:
             message_dict["function_call"] = message.additional_kwargs["function_call"]
-            # If function call only, content is None not empty string
-            if message_dict["content"] == "":
-                message_dict["content"] = None
-        if "tool_calls" in message.additional_kwargs:
+        if message.tool_calls or message.invalid_tool_calls:
+            message_dict["tool_calls"] = [
+                _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
+            ] + [
+                _lc_invalid_tool_call_to_openai_tool_call(tc)
+                for tc in message.invalid_tool_calls
+            ]
+        elif "tool_calls" in message.additional_kwargs:
             message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
-            # If tool calls only, content is None not empty string
-            if message_dict["content"] == "":
-                message_dict["content"] = None
             tool_call_supported_props = {"id", "type", "function"}
             message_dict["tool_calls"] = [
                 {k: v for k, v in tool_call.items() if k in tool_call_supported_props}
                 for tool_call in message_dict["tool_calls"]
             ]
+        else:
+            pass
+        # If tool calls present, content null value should be None not empty string.
+        if "function_call" in message_dict or "tool_calls" in message_dict:
+            message_dict["content"] = message_dict["content"] or None
     elif isinstance(message, SystemMessage):
         message_dict["role"] = "system"
     elif isinstance(message, FunctionMessage):
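
For context (not part of the patch): a minimal sketch of what the new branch is expected to emit for an AIMessage that carries structured tool_calls. The import of the private converter mirrors the unit tests below; treat the exact module path as an assumption.

    from langchain_core.messages import AIMessage, ToolCall

    # Private helper; assumed import path, the same one the unit tests use.
    from langchain_openai.chat_models.base import _convert_message_to_dict

    msg = AIMessage(
        content="",
        tool_calls=[ToolCall(name="GenerateUsername", args={"name": "Sally"}, id="call_abc")],
    )

    # Expected shape after this change: content collapses to None and args are
    # JSON-encoded by _lc_tool_call_to_openai_tool_call:
    # {
    #     "role": "assistant",
    #     "content": None,
    #     "tool_calls": [
    #         {
    #             "type": "function",
    #             "id": "call_abc",
    #             "function": {"name": "GenerateUsername", "arguments": '{"name": "Sally"}'},
    #         }
    #     ],
    # }
    print(_convert_message_to_dict(msg))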
@@ -1067,3 +1075,27 @@ class ChatOpenAI(BaseChatModel):
 
 def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and issubclass(obj, BaseModel)
+
+
+def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
+    return {
+        "type": "function",
+        "id": tool_call["id"],
+        "function": {
+            "name": tool_call["name"],
+            "arguments": json.dumps(tool_call["args"]),
+        },
+    }
+
+
+def _lc_invalid_tool_call_to_openai_tool_call(
+    invalid_tool_call: InvalidToolCall,
+) -> dict:
+    return {
+        "type": "function",
+        "id": invalid_tool_call["id"],
+        "function": {
+            "name": invalid_tool_call["name"],
+            "arguments": invalid_tool_call["args"],
+        },
+    }
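
For context (not part of the patch): the design difference between the two helpers is in "arguments". A valid ToolCall has args as a parsed dict, so it is re-encoded with json.dumps; an InvalidToolCall keeps args as the raw string the model returned, so it is passed through untouched. A small sketch using only the TypedDict inputs:

    import json

    from langchain_core.messages import InvalidToolCall, ToolCall

    # Both are TypedDicts in langchain_core, so plain keyword construction works.
    valid = ToolCall(name="GenerateUsername", args={"name": "Sally"}, id="call_1")
    broken = InvalidToolCall(
        name="GenerateUsername",
        args='{"name": "Sally"',  # raw, malformed JSON exactly as returned
        id="call_2",
        error="Malformed args",
    )

    # Mirrors what the two new helpers do with the "arguments" field:
    assert json.dumps(valid["args"]) == '{"name": "Sally"}'  # dict -> JSON string
    assert broken["args"] == '{"name": "Sally"'              # passed through as-is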
@@ -10,6 +10,7 @@ from langchain_core.messages import (
     BaseMessageChunk,
     HumanMessage,
     SystemMessage,
+    ToolCall,
     ToolMessage,
 )
 from langchain_core.outputs import (
@@ -519,6 +520,49 @@ def test_tool_use() -> None:
     llm_with_tool.invoke(msgs)
 
 
+def test_manual_tool_call_msg() -> None:
+    """Test passing in a manually constructed tool call message."""
+    llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
+    llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
+    msgs: List = [
+        HumanMessage("Sally has green hair, what would her username be?"),
+        AIMessage(
+            content="",
+            tool_calls=[
+                ToolCall(
+                    name="GenerateUsername",
+                    args={"name": "Sally", "hair_color": "green"},
+                    id="foo",
+                )
+            ],
+        ),
+        ToolMessage("sally_green_hair", tool_call_id="foo"),
+    ]
+    output: AIMessage = cast(AIMessage, llm_with_tool.invoke(msgs))
+    assert output.content
+    # Should not have called the tool again.
+    assert not output.tool_calls and not output.invalid_tool_calls
+
+    # OpenAI should error when the tool call id doesn't match across the AIMessage
+    # and the ToolMessage.
+    msgs = [
+        HumanMessage("Sally has green hair, what would her username be?"),
+        AIMessage(
+            content="",
+            tool_calls=[
+                ToolCall(
+                    name="GenerateUsername",
+                    args={"name": "Sally", "hair_color": "green"},
+                    id="bar",
+                )
+            ],
+        ),
+        ToolMessage("sally_green_hair", tool_call_id="foo"),
+    ]
+    with pytest.raises(Exception):
+        llm_with_tool.invoke(msgs)
+
+
 def test_openai_structured_output() -> None:
     class MyModel(BaseModel):
         """A Person"""
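
For context (not part of the patch): roughly the request payload the mismatched-id case serializes to, which is why the call is expected to fail — the tool result references an id that no preceding assistant tool call declared. Field names follow the converter above; the exact error raised is left to the OpenAI API.

    # Sketch of the serialized messages for the failing case (ids taken from the test):
    payload_messages = [
        {"role": "user", "content": "Sally has green hair, what would her username be?"},
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "type": "function",
                    "id": "bar",
                    "function": {
                        "name": "GenerateUsername",
                        "arguments": '{"name": "Sally", "hair_color": "green"}',
                    },
                }
            ],
        },
        # References "foo", but the assistant message above only declared id "bar".
        {"role": "tool", "content": "sally_green_hair", "tool_call_id": "foo"},
    ]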
@@ -126,7 +126,7 @@ def test__convert_dict_to_message_tool_call() -> None:
     assert _convert_message_to_dict(expected_output) == message
 
     # Test malformed tool call
-    raw_tool_calls = [
+    raw_tool_calls: list = [
         {
             "id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
             "function": {
@@ -144,6 +144,7 @@ def test__convert_dict_to_message_tool_call() -> None:
             "type": "function",
         },
     ]
+    raw_tool_calls = list(sorted(raw_tool_calls, key=lambda x: x["id"]))
     message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls}
     result = _convert_dict_to_message(message)
     expected_output = AIMessage(
@@ -166,7 +167,11 @@ def test__convert_dict_to_message_tool_call() -> None:
         ],
     )
     assert result == expected_output
-    assert _convert_message_to_dict(expected_output) == message
+    reverted_message_dict = _convert_message_to_dict(expected_output)
+    reverted_message_dict["tool_calls"] = list(
+        sorted(reverted_message_dict["tool_calls"], key=lambda x: x["id"])
+    )
+    assert reverted_message_dict == message
 
 
 @pytest.fixture
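
For context (not part of the patch): the sorting makes the round-trip comparison order-insensitive. _convert_message_to_dict now emits all valid tool calls before the invalid ones (see the list concatenation above), so a malformed call that appeared first in the raw payload can come back in a different position. A minimal sketch of that comparison, with a hypothetical helper name:

    from operator import itemgetter
    from typing import Any, Dict, List


    def same_tool_calls(a: List[Dict[str, Any]], b: List[Dict[str, Any]]) -> bool:
        """Compare two lists of OpenAI-style tool-call dicts, ignoring order (keyed by id)."""
        return sorted(a, key=itemgetter("id")) == sorted(b, key=itemgetter("id"))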