anthropic[patch]: always add tool_result type to ToolMessage content (#22721)

Anthropic tool results can contain image data, which is typically
represented with content blocks having `"type": "image"`. Currently,
these content blocks are passed as-is as human/user messages to
Anthropic, which raises a BadRequestError because it expects a
tool_result block to follow a tool_use.

Here we update ChatAnthropic to nest the content blocks inside a
tool_result content block.

Example:
```python
import base64

import httpx
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field


# Download the sample photo, then base64-encode its bytes as UTF-8 text
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
image_bytes = httpx.get(image_url).content
image_data = base64.b64encode(image_bytes).decode("utf-8")


# Tool schema passed to `bind_tools` below; its only input is one boolean field.
# NOTE(review): a `#` comment is used instead of a docstring because tool-binding
# code may read the class docstring as the tool description — confirm before adding one.
class FetchImage(BaseModel):
    should_fetch: bool = Field(..., description="Whether an image is requested.")


# Chat model with the FetchImage schema bound as an available tool.
llm = ChatAnthropic(model="claude-3-sonnet-20240229").bind_tools([FetchImage])

# One id ties together the tool_use block, the tool call, and the ToolMessage.
tool_call_id = "toolu_01Rn6Qvj5m7955x9m9Pfxbcx"

# Image content block (base64-encoded JPEG) returned as the tool's output.
image_block = {
    "type": "image",
    "source": {
        "type": "base64",
        "media_type": "image/jpeg",
        "data": image_data,
    },
}

user_request = HumanMessage(content="Could you summon a beautiful image please?")

# The assistant turn that asks for the tool to be run.
tool_request = AIMessage(
    content=[
        {
            "type": "tool_use",
            "id": tool_call_id,
            "name": "FetchImage",
            "input": {"should_fetch": True},
        },
    ],
    tool_calls=[
        {
            "name": "FetchImage",
            "args": {"should_fetch": True},
            "id": tool_call_id,
        },
    ],
)

# The tool's answer: a list of content blocks rather than a plain string.
tool_reply = ToolMessage(
    name="FetchImage",
    content=[image_block],
    tool_call_id=tool_call_id,
)

messages = [user_request, tool_request, tool_reply]

llm.invoke(messages)
```

Trace:
https://smith.langchain.com/public/d27e4fc1-a96d-41e1-9f52-54f5004122db/r
This commit is contained in:
ccurme 2024-06-13 23:14:23 -04:00 committed by GitHub
parent 7114aed78f
commit 73c76b9628
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 54 additions and 5 deletions

View File

@ -104,7 +104,12 @@ def _merge_messages(
for curr in messages: for curr in messages:
curr = curr.copy(deep=True) curr = curr.copy(deep=True)
if isinstance(curr, ToolMessage): if isinstance(curr, ToolMessage):
if isinstance(curr.content, str): if isinstance(curr.content, list) and all(
isinstance(block, dict) and block.get("type") == "tool_result"
for block in curr.content
):
curr = HumanMessage(curr.content) # type: ignore[misc]
else:
curr = HumanMessage( # type: ignore[misc] curr = HumanMessage( # type: ignore[misc]
[ [
{ {
@ -114,8 +119,6 @@ def _merge_messages(
} }
] ]
) )
else:
curr = HumanMessage(curr.content) # type: ignore[misc]
last = merged[-1] if merged else None last = merged[-1] if merged else None
if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage): if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
if isinstance(last.content, str): if isinstance(last.content, str):

View File

@ -140,7 +140,19 @@ def test__merge_messages() -> None:
] ]
), ),
ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc] ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc]
ToolMessage("blah output", tool_call_id="2"), # type: ignore[misc] ToolMessage(
content=[
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "fake_image_data",
},
},
],
tool_call_id="2",
), # type: ignore[misc]
HumanMessage("next thing"), # type: ignore[misc] HumanMessage("next thing"), # type: ignore[misc]
] ]
expected = [ expected = [
@ -169,7 +181,20 @@ def test__merge_messages() -> None:
HumanMessage( # type: ignore[misc] HumanMessage( # type: ignore[misc]
[ [
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"}, {"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"}, {
"type": "tool_result",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "fake_image_data",
},
},
],
"tool_use_id": "2",
},
{"type": "text", "text": "next thing"}, {"type": "text", "text": "next thing"},
] ]
), ),
@ -177,6 +202,27 @@ def test__merge_messages() -> None:
actual = _merge_messages(messages) actual = _merge_messages(messages)
assert expected == actual assert expected == actual
# Test tool message case
messages = [
ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc]
ToolMessage( # type: ignore[misc]
content=[
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"}
],
tool_call_id="2",
),
]
expected = [
HumanMessage( # type: ignore[misc]
[
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
]
)
]
actual = _merge_messages(messages)
assert expected == actual
def test__merge_messages_mutation() -> None: def test__merge_messages_mutation() -> None:
original_messages = [ original_messages = [