diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index d7e63c8c04..ce0f3bb98a 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -104,7 +104,12 @@ def _merge_messages( for curr in messages: curr = curr.copy(deep=True) if isinstance(curr, ToolMessage): - if isinstance(curr.content, str): + if isinstance(curr.content, list) and all( + isinstance(block, dict) and block.get("type") == "tool_result" + for block in curr.content + ): + curr = HumanMessage(curr.content) # type: ignore[misc] + else: curr = HumanMessage( # type: ignore[misc] [ { @@ -114,8 +119,6 @@ def _merge_messages( } ] ) - else: - curr = HumanMessage(curr.content) # type: ignore[misc] last = merged[-1] if merged else None if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage): if isinstance(last.content, str): diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index 31d12db98c..d7e21edb9a 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -140,7 +140,19 @@ def test__merge_messages() -> None: ] ), ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc] - ToolMessage("blah output", tool_call_id="2"), # type: ignore[misc] + ToolMessage( + content=[ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": "fake_image_data", + }, + }, + ], + tool_call_id="2", + ), # type: ignore[misc] HumanMessage("next thing"), # type: ignore[misc] ] expected = [ @@ -169,7 +181,20 @@ def test__merge_messages() -> None: HumanMessage( # type: ignore[misc] [ {"type": "tool_result", "content": "buz output", "tool_use_id": "1"}, - {"type": "tool_result", "content": "blah output", "tool_use_id": "2"}, + { + "type": "tool_result", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": "fake_image_data", + }, + }, + ], + "tool_use_id": "2", + }, {"type": "text", "text": "next thing"}, ] ), @@ -177,6 +202,27 @@ def test__merge_messages() -> None: actual = _merge_messages(messages) assert expected == actual + # Test tool message case + messages = [ + ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc] + ToolMessage( # type: ignore[misc] + content=[ + {"type": "tool_result", "content": "blah output", "tool_use_id": "2"} + ], + tool_call_id="2", + ), + ] + expected = [ + HumanMessage( # type: ignore[misc] + [ + {"type": "tool_result", "content": "buz output", "tool_use_id": "1"}, + {"type": "tool_result", "content": "blah output", "tool_use_id": "2"}, + ] + ) + ] + actual = _merge_messages(messages) + assert expected == actual + def test__merge_messages_mutation() -> None: original_messages = [