diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index ea4df08969..50308b6327 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
+import json
 import logging
 import os
 import sys
@@ -50,8 +51,10 @@ from langchain_core.messages import (
     FunctionMessageChunk,
     HumanMessage,
     HumanMessageChunk,
+    InvalidToolCall,
     SystemMessage,
     SystemMessageChunk,
+    ToolCall,
     ToolMessage,
     ToolMessageChunk,
 )
@@ -169,20 +172,25 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
         message_dict["role"] = "assistant"
         if "function_call" in message.additional_kwargs:
             message_dict["function_call"] = message.additional_kwargs["function_call"]
-            # If function call only, content is None not empty string
-            if message_dict["content"] == "":
-                message_dict["content"] = None
-        if "tool_calls" in message.additional_kwargs:
+        if message.tool_calls or message.invalid_tool_calls:
+            message_dict["tool_calls"] = [
+                _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
+            ] + [
+                _lc_invalid_tool_call_to_openai_tool_call(tc)
+                for tc in message.invalid_tool_calls
+            ]
+        elif "tool_calls" in message.additional_kwargs:
             message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
-            # If tool calls only, content is None not empty string
-            if message_dict["content"] == "":
-                message_dict["content"] = None
-
             tool_call_supported_props = {"id", "type", "function"}
             message_dict["tool_calls"] = [
                 {k: v for k, v in tool_call.items() if k in tool_call_supported_props}
                 for tool_call in message_dict["tool_calls"]
             ]
+        else:
+            pass
+        # If tool calls present, content null value should be None not empty string.
+ if "function_call" in message_dict or "tool_calls" in message_dict: + message_dict["content"] = message_dict["content"] or None elif isinstance(message, SystemMessage): message_dict["role"] = "system" elif isinstance(message, FunctionMessage): @@ -1067,3 +1075,27 @@ class ChatOpenAI(BaseChatModel): def _is_pydantic_class(obj: Any) -> bool: return isinstance(obj, type) and issubclass(obj, BaseModel) + + +def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict: + return { + "type": "function", + "id": tool_call["id"], + "function": { + "name": tool_call["name"], + "arguments": json.dumps(tool_call["args"]), + }, + } + + +def _lc_invalid_tool_call_to_openai_tool_call( + invalid_tool_call: InvalidToolCall, +) -> dict: + return { + "type": "function", + "id": invalid_tool_call["id"], + "function": { + "name": invalid_tool_call["name"], + "arguments": invalid_tool_call["args"], + }, + } diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 03cdcbee6c..e1f15ec9c5 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -10,6 +10,7 @@ from langchain_core.messages import ( BaseMessageChunk, HumanMessage, SystemMessage, + ToolCall, ToolMessage, ) from langchain_core.outputs import ( @@ -519,6 +520,49 @@ def test_tool_use() -> None: llm_with_tool.invoke(msgs) +def test_manual_tool_call_msg() -> None: + """Test passing in manually construct tool call message.""" + llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) + llm_with_tool = llm.bind_tools(tools=[GenerateUsername]) + msgs: List = [ + HumanMessage("Sally has green hair, what would her username be?"), + AIMessage( + content="", + tool_calls=[ + ToolCall( + name="GenerateUsername", + args={"name": "Sally", "hair_color": "green"}, + id="foo", + ) + ], + ), + ToolMessage("sally_green_hair", tool_call_id="foo"), + ] + output: AIMessage = cast(AIMessage, llm_with_tool.invoke(msgs)) + assert output.content + # Should not have called the tool again. 
+    assert not output.tool_calls and not output.invalid_tool_calls
+
+    # OpenAI should error when tool call id doesn't match across AIMessage and
+    # ToolMessage
+    msgs = [
+        HumanMessage("Sally has green hair, what would her username be?"),
+        AIMessage(
+            content="",
+            tool_calls=[
+                ToolCall(
+                    name="GenerateUsername",
+                    args={"name": "Sally", "hair_color": "green"},
+                    id="bar",
+                )
+            ],
+        ),
+        ToolMessage("sally_green_hair", tool_call_id="foo"),
+    ]
+    with pytest.raises(Exception):
+        llm_with_tool.invoke(msgs)
+
+
 def test_openai_structured_output() -> None:
     class MyModel(BaseModel):
         """A Person"""
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
index 1b8668c955..9665af8f64 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
@@ -104,7 +104,7 @@ def test__convert_dict_to_message_tool_call() -> None:
     raw_tool_call = {
         "id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
         "function": {
-            "arguments": '{"name":"Sally","hair_color":"green"}',
+            "arguments": '{"name": "Sally", "hair_color": "green"}',
             "name": "GenerateUsername",
         },
         "type": "function",
@@ -126,7 +126,7 @@ def test__convert_dict_to_message_tool_call() -> None:
     assert _convert_message_to_dict(expected_output) == message
 
     # Test malformed tool call
-    raw_tool_calls = [
+    raw_tool_calls: list = [
         {
             "id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
             "function": {
@@ -138,12 +138,13 @@ def test__convert_dict_to_message_tool_call() -> None:
         {
            "id": "call_abc123",
             "function": {
-                "arguments": '{"name":"Sally","hair_color":"green"}',
+                "arguments": '{"name": "Sally", "hair_color": "green"}',
                 "name": "GenerateUsername",
             },
             "type": "function",
         },
     ]
+    raw_tool_calls = list(sorted(raw_tool_calls, key=lambda x: x["id"]))
     message = {"role": "assistant", "content": None, "tool_calls": raw_tool_calls}
     result = _convert_dict_to_message(message)
     expected_output = AIMessage(
@@ -166,7 +167,11 @@ def test__convert_dict_to_message_tool_call() -> None:
         ],
     )
     assert result == expected_output
-    assert _convert_message_to_dict(expected_output) == message
+    reverted_message_dict = _convert_message_to_dict(expected_output)
+    reverted_message_dict["tool_calls"] = list(
+        sorted(reverted_message_dict["tool_calls"], key=lambda x: x["id"])
+    )
+    assert reverted_message_dict == message
 
 
 @pytest.fixture
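
Illustrative sketch (not part of the patch): the snippet below exercises the new AIMessage-to-dict path added in the first hunk, assuming the private helper `_convert_message_to_dict` keeps the name and behavior shown above. It is internal API, so this is a sketch rather than a supported usage pattern.

# Sketch only: round-trip a manually constructed AIMessage through the patched
# private helper. Assumes langchain_core exposes `tool_calls` on AIMessage as
# shown in the hunks above.
from langchain_core.messages import AIMessage, ToolCall

from langchain_openai.chat_models.base import _convert_message_to_dict

msg = AIMessage(
    content="",
    tool_calls=[
        ToolCall(
            name="GenerateUsername",
            args={"name": "Sally", "hair_color": "green"},
            id="foo",
        )
    ],
)
payload = _convert_message_to_dict(msg)

# LangChain ToolCalls are re-serialized to the OpenAI wire format, with the
# args dict JSON-encoded by json.dumps (default separators insert spaces).
assert payload["tool_calls"] == [
    {
        "type": "function",
        "id": "foo",
        "function": {
            "name": "GenerateUsername",
            "arguments": '{"name": "Sally", "hair_color": "green"}',
        },
    }
]
# Empty-string content is normalized to None whenever tool calls are present.
assert payload["content"] is None

This re-encoding via `json.dumps` is also why the unit-test fixtures above switch their expected `arguments` strings to the spaced form: arguments are decoded on the way in and re-encoded on the way out.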