openai[patch]: use tool_calls in request (#20272)

This commit is contained in:
Bagatur 2024-04-11 03:55:52 -07:00 committed by GitHub
parent e936fba428
commit c706689413
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 93 additions and 12 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
import os import os
import sys import sys
@ -50,8 +51,10 @@ from langchain_core.messages import (
FunctionMessageChunk, FunctionMessageChunk,
HumanMessage, HumanMessage,
HumanMessageChunk, HumanMessageChunk,
InvalidToolCall,
SystemMessage, SystemMessage,
SystemMessageChunk, SystemMessageChunk,
ToolCall,
ToolMessage, ToolMessage,
ToolMessageChunk, ToolMessageChunk,
) )
@ -169,20 +172,25 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict["role"] = "assistant" message_dict["role"] = "assistant"
if "function_call" in message.additional_kwargs: if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"] message_dict["function_call"] = message.additional_kwargs["function_call"]
# If function call only, content is None not empty string if message.tool_calls or message.invalid_tool_calls:
if message_dict["content"] == "": message_dict["tool_calls"] = [
message_dict["content"] = None _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
if "tool_calls" in message.additional_kwargs: ] + [
_lc_invalid_tool_call_to_openai_tool_call(tc)
for tc in message.invalid_tool_calls
]
elif "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"] message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
# If tool calls only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
tool_call_supported_props = {"id", "type", "function"} tool_call_supported_props = {"id", "type", "function"}
message_dict["tool_calls"] = [ message_dict["tool_calls"] = [
{k: v for k, v in tool_call.items() if k in tool_call_supported_props} {k: v for k, v in tool_call.items() if k in tool_call_supported_props}
for tool_call in message_dict["tool_calls"] for tool_call in message_dict["tool_calls"]
] ]
else:
pass
# If tool calls present, content null value should be None not empty string.
if "function_call" in message_dict or "tool_calls" in message_dict:
message_dict["content"] = message_dict["content"] or None
elif isinstance(message, SystemMessage): elif isinstance(message, SystemMessage):
message_dict["role"] = "system" message_dict["role"] = "system"
elif isinstance(message, FunctionMessage): elif isinstance(message, FunctionMessage):
@ -1067,3 +1075,27 @@ class ChatOpenAI(BaseChatModel):
def _is_pydantic_class(obj: Any) -> bool: def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel) return isinstance(obj, type) and issubclass(obj, BaseModel)
def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
return {
"type": "function",
"id": tool_call["id"],
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
},
}
def _lc_invalid_tool_call_to_openai_tool_call(
invalid_tool_call: InvalidToolCall,
) -> dict:
return {
"type": "function",
"id": invalid_tool_call["id"],
"function": {
"name": invalid_tool_call["name"],
"arguments": invalid_tool_call["args"],
},
}

View File

@ -10,6 +10,7 @@ from langchain_core.messages import (
BaseMessageChunk, BaseMessageChunk,
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
ToolCall,
ToolMessage, ToolMessage,
) )
from langchain_core.outputs import ( from langchain_core.outputs import (
@ -519,6 +520,49 @@ def test_tool_use() -> None:
llm_with_tool.invoke(msgs) llm_with_tool.invoke(msgs)
def test_manual_tool_call_msg() -> None:
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm_with_tool = llm.bind_tools(tools=[GenerateUsername])
msgs: List = [
HumanMessage("Sally has green hair, what would her username be?"),
AIMessage(
content="",
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="foo",
)
],
),
ToolMessage("sally_green_hair", tool_call_id="foo"),
]
output: AIMessage = cast(AIMessage, llm_with_tool.invoke(msgs))
assert output.content
# Should not have called the tool again.
assert not output.tool_calls and not output.invalid_tool_calls
# OpenAI should error when tool call id doesn't match across AIMessage and
# ToolMessage
msgs = [
HumanMessage("Sally has green hair, what would her username be?"),
AIMessage(
content="",
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="bar",
)
],
),
ToolMessage("sally_green_hair", tool_call_id="foo"),
]
with pytest.raises(Exception):
llm_with_tool.invoke(msgs)
def test_openai_structured_output() -> None: def test_openai_structured_output() -> None:
class MyModel(BaseModel): class MyModel(BaseModel):
"""A Person""" """A Person"""

View File

@ -126,7 +126,7 @@ def test__convert_dict_to_message_tool_call() -> None:
assert _convert_message_to_dict(expected_output) == message assert _convert_message_to_dict(expected_output) == message
# Test malformed tool call # Test malformed tool call
raw_tool_calls = [ raw_tool_calls: list = [
{ {
"id": "call_wm0JY6CdwOMZ4eTxHWUThDNz", "id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
"function": { "function": {
@ -144,6 +144,7 @@ def test__convert_dict_to_message_tool_call() -> None:
"type": "function", "type": "function",
}, },
] ]
raw_tool_calls = list(sorted(raw_tool_calls, key=lambda x: x["id"]))
message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls} message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls}
result = _convert_dict_to_message(message) result = _convert_dict_to_message(message)
expected_output = AIMessage( expected_output = AIMessage(
@ -166,7 +167,11 @@ def test__convert_dict_to_message_tool_call() -> None:
], ],
) )
assert result == expected_output assert result == expected_output
assert _convert_message_to_dict(expected_output) == message reverted_message_dict = _convert_message_to_dict(expected_output)
reverted_message_dict["tool_calls"] = list(
sorted(reverted_message_dict["tool_calls"], key=lambda x: x["id"])
)
assert reverted_message_dict == message
@pytest.fixture @pytest.fixture