From 73c76b962894d6ba65e541339c593c45adc4737b Mon Sep 17 00:00:00 2001 From: ccurme Date: Thu, 13 Jun 2024 23:14:23 -0400 Subject: [PATCH] anthropic[patch]: always add tool_result type to ToolMessage content (#22721) Anthropic tool results can contain image data, which is typically represented as content blocks with `"type": "image"`. Currently, these content blocks are passed as-is in human/user messages to Anthropic, and the API raises BadRequestError because it expects a tool_result block to follow a tool_use. Here we update ChatAnthropic to nest the content blocks inside a tool_result content block. Example: ```python import base64 import httpx from langchain_anthropic import ChatAnthropic from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.pydantic_v1 import BaseModel, Field # Fetch image image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8") class FetchImage(BaseModel): should_fetch: bool = Field(..., description="Whether an image is requested.") llm = ChatAnthropic(model="claude-3-sonnet-20240229").bind_tools([FetchImage]) messages = [ HumanMessage(content="Could you summon a beautiful image please?"), AIMessage( content=[ { "type": "tool_use", "id": "toolu_01Rn6Qvj5m7955x9m9Pfxbcx", "name": "FetchImage", "input": {"should_fetch": True}, }, ], tool_calls=[ { "name": "FetchImage", "args": {"should_fetch": True}, "id": "toolu_01Rn6Qvj5m7955x9m9Pfxbcx", }, ], ), ToolMessage( name="FetchImage", content=[ { "type": "image", "source": { "type": "base64", "media_type": "image/jpeg", "data": image_data, }, }, ], tool_call_id="toolu_01Rn6Qvj5m7955x9m9Pfxbcx", ), ] llm.invoke(messages) ``` Trace: https://smith.langchain.com/public/d27e4fc1-a96d-41e1-9f52-54f5004122db/r --- .../langchain_anthropic/chat_models.py | 9 ++-- 
.../tests/unit_tests/test_chat_models.py | 50 ++++++++++++++++++- 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index d7e63c8c04..ce0f3bb98a 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -104,7 +104,12 @@ def _merge_messages( for curr in messages: curr = curr.copy(deep=True) if isinstance(curr, ToolMessage): - if isinstance(curr.content, str): + if isinstance(curr.content, list) and all( + isinstance(block, dict) and block.get("type") == "tool_result" + for block in curr.content + ): + curr = HumanMessage(curr.content) # type: ignore[misc] + else: curr = HumanMessage( # type: ignore[misc] [ { @@ -114,8 +119,6 @@ def _merge_messages( } ] ) - else: - curr = HumanMessage(curr.content) # type: ignore[misc] last = merged[-1] if merged else None if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage): if isinstance(last.content, str): diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index 31d12db98c..d7e21edb9a 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -140,7 +140,19 @@ def test__merge_messages() -> None: ] ), ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc] - ToolMessage("blah output", tool_call_id="2"), # type: ignore[misc] + ToolMessage( + content=[ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": "fake_image_data", + }, + }, + ], + tool_call_id="2", + ), # type: ignore[misc] HumanMessage("next thing"), # type: ignore[misc] ] expected = [ @@ -169,7 +181,20 @@ def test__merge_messages() -> None: HumanMessage( # type: ignore[misc] [ {"type": "tool_result", "content": "buz output", 
"tool_use_id": "1"}, - {"type": "tool_result", "content": "blah output", "tool_use_id": "2"}, + { + "type": "tool_result", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": "fake_image_data", + }, + }, + ], + "tool_use_id": "2", + }, {"type": "text", "text": "next thing"}, ] ), @@ -177,6 +202,27 @@ def test__merge_messages() -> None: actual = _merge_messages(messages) assert expected == actual + # Test tool message case + messages = [ + ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc] + ToolMessage( # type: ignore[misc] + content=[ + {"type": "tool_result", "content": "blah output", "tool_use_id": "2"} + ], + tool_call_id="2", + ), + ] + expected = [ + HumanMessage( # type: ignore[misc] + [ + {"type": "tool_result", "content": "buz output", "tool_use_id": "1"}, + {"type": "tool_result", "content": "blah output", "tool_use_id": "2"}, + ] + ) + ] + actual = _merge_messages(messages) + assert expected == actual + def test__merge_messages_mutation() -> None: original_messages = [