diff --git a/libs/community/langchain_community/adapters/openai.py b/libs/community/langchain_community/adapters/openai.py index 0af759ebf5..fbdf84be65 100644 --- a/libs/community/langchain_community/adapters/openai.py +++ b/libs/community/langchain_community/adapters/openai.py @@ -89,7 +89,14 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: elif role == "function": return FunctionMessage(content=_dict["content"], name=_dict["name"]) elif role == "tool": - return ToolMessage(content=_dict["content"], tool_call_id=_dict["tool_call_id"]) + additional_kwargs = {} + if "name" in _dict: + additional_kwargs["name"] = _dict["name"] + return ToolMessage( + content=_dict["content"], + tool_call_id=_dict["tool_call_id"], + additional_kwargs=additional_kwargs, + ) else: return ChatMessage(content=_dict["content"], role=role) diff --git a/libs/langchain/langchain/agents/format_scratchpad/openai_tools.py b/libs/langchain/langchain/agents/format_scratchpad/openai_tools.py index d9717dc803..123a92e5b0 100644 --- a/libs/langchain/langchain/agents/format_scratchpad/openai_tools.py +++ b/libs/langchain/langchain/agents/format_scratchpad/openai_tools.py @@ -31,6 +31,7 @@ def _create_tool_message( return ToolMessage( tool_call_id=agent_action.tool_call_id, content=content, + additional_kwargs={"name": agent_action.tool}, ) diff --git a/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py new file mode 100644 index 0000000000..96b4c7b88d --- /dev/null +++ b/libs/langchain/tests/unit_tests/agents/format_scratchpad/test_openai_tools.py @@ -0,0 +1,94 @@ +from langchain_core.messages import AIMessage, ToolMessage + +from langchain.agents.format_scratchpad.openai_tools import ( + format_to_openai_tool_messages, +) +from langchain.agents.output_parsers.openai_tools import ( + parse_ai_message_to_openai_tool_action, +) + + +def test_calls_convert_agent_action_to_messages() -> None: + additional_kwargs1 = { + "tool_calls": [ + { + "id": "call_abcd12345", + "function": {"arguments": '{"a": 3, "b": 5}', "name": "add"}, + "type": "function", + } + ], + } + message1 = AIMessage(content="", additional_kwargs=additional_kwargs1) + + actions1 = parse_ai_message_to_openai_tool_action(message1) + additional_kwargs2 = { + "tool_calls": [ + { + "id": "call_abcd54321", + "function": {"arguments": '{"a": 3, "b": 5}', "name": "subtract"}, + "type": "function", + } + ], + } + message2 = AIMessage(content="", additional_kwargs=additional_kwargs2) + actions2 = parse_ai_message_to_openai_tool_action(message2) + + additional_kwargs3 = { + "tool_calls": [ + { + "id": "call_abcd67890", + "function": {"arguments": '{"a": 3, "b": 5}', "name": "multiply"}, + "type": "function", + }, + { + "id": "call_abcd09876", + "function": {"arguments": '{"a": 3, "b": 5}', "name": "divide"}, + "type": "function", + }, + ], + } + message3 = AIMessage(content="", additional_kwargs=additional_kwargs3) + actions3 = parse_ai_message_to_openai_tool_action(message3) + # for mypy + assert isinstance(actions1, list) + assert isinstance(actions2, list) + assert isinstance(actions3, list) + + intermediate_steps = [ + (actions1[0], "observation1"), + (actions2[0], "observation2"), + (actions3[0], "observation3"), + (actions3[1], "observation4"), + ] + expected_messages = [ + message1, + ToolMessage( + tool_call_id="call_abcd12345", + content="observation1", + additional_kwargs={"name": "add"}, + ), + message2, + ToolMessage( + tool_call_id="call_abcd54321", + content="observation2", + additional_kwargs={"name": "subtract"}, + ), + message3, + ToolMessage( + tool_call_id="call_abcd67890", + content="observation3", + additional_kwargs={"name": "multiply"}, + ), + ToolMessage( + tool_call_id="call_abcd09876", + content="observation4", + additional_kwargs={"name": "divide"}, + ), + ] + output = format_to_openai_tool_messages(intermediate_steps) + assert output == expected_messages + + +def test_handles_empty_input_list() -> None: + output = format_to_openai_tool_messages([]) + assert output == []