mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
anthropic[patch]: always add tool_result type to ToolMessage content (#22721)
Anthropic tool results can contain image data, which are typically represented with content blocks having `"type": "image"`. Currently, these content blocks are passed as-is as human/user messages to Anthropic, which raises BadRequestError as it expects a tool_result block to follow a tool_use. Here we update ChatAnthropic to nest the content blocks inside a tool_result content block. Example: ```python import base64 import httpx from langchain_anthropic import ChatAnthropic from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.pydantic_v1 import BaseModel, Field # Fetch image image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8") class FetchImage(BaseModel): should_fetch: bool = Field(..., description="Whether an image is requested.") llm = ChatAnthropic(model="claude-3-sonnet-20240229").bind_tools([FetchImage]) messages = [ HumanMessage(content="Could you summon a beautiful image please?"), AIMessage( content=[ { "type": "tool_use", "id": "toolu_01Rn6Qvj5m7955x9m9Pfxbcx", "name": "FetchImage", "input": {"should_fetch": True}, }, ], tool_calls=[ { "name": "FetchImage", "args": {"should_fetch": True}, "id": "toolu_01Rn6Qvj5m7955x9m9Pfxbcx", }, ], ), ToolMessage( name="FetchImage", content=[ { "type": "image", "source": { "type": "base64", "media_type": "image/jpeg", "data": image_data, }, }, ], tool_call_id="toolu_01Rn6Qvj5m7955x9m9Pfxbcx", ), ] llm.invoke(messages) ``` Trace: https://smith.langchain.com/public/d27e4fc1-a96d-41e1-9f52-54f5004122db/r
This commit is contained in:
parent
7114aed78f
commit
73c76b9628
@ -104,7 +104,12 @@ def _merge_messages(
|
||||
for curr in messages:
|
||||
curr = curr.copy(deep=True)
|
||||
if isinstance(curr, ToolMessage):
|
||||
if isinstance(curr.content, str):
|
||||
if isinstance(curr.content, list) and all(
|
||||
isinstance(block, dict) and block.get("type") == "tool_result"
|
||||
for block in curr.content
|
||||
):
|
||||
curr = HumanMessage(curr.content) # type: ignore[misc]
|
||||
else:
|
||||
curr = HumanMessage( # type: ignore[misc]
|
||||
[
|
||||
{
|
||||
@ -114,8 +119,6 @@ def _merge_messages(
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
curr = HumanMessage(curr.content) # type: ignore[misc]
|
||||
last = merged[-1] if merged else None
|
||||
if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
|
||||
if isinstance(last.content, str):
|
||||
|
@ -140,7 +140,19 @@ def test__merge_messages() -> None:
|
||||
]
|
||||
),
|
||||
ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc]
|
||||
ToolMessage("blah output", tool_call_id="2"), # type: ignore[misc]
|
||||
ToolMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": "fake_image_data",
|
||||
},
|
||||
},
|
||||
],
|
||||
tool_call_id="2",
|
||||
), # type: ignore[misc]
|
||||
HumanMessage("next thing"), # type: ignore[misc]
|
||||
]
|
||||
expected = [
|
||||
@ -169,7 +181,20 @@ def test__merge_messages() -> None:
|
||||
HumanMessage( # type: ignore[misc]
|
||||
[
|
||||
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
|
||||
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": "fake_image_data",
|
||||
},
|
||||
},
|
||||
],
|
||||
"tool_use_id": "2",
|
||||
},
|
||||
{"type": "text", "text": "next thing"},
|
||||
]
|
||||
),
|
||||
@ -177,6 +202,27 @@ def test__merge_messages() -> None:
|
||||
actual = _merge_messages(messages)
|
||||
assert expected == actual
|
||||
|
||||
# Test tool message case
|
||||
messages = [
|
||||
ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc]
|
||||
ToolMessage( # type: ignore[misc]
|
||||
content=[
|
||||
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"}
|
||||
],
|
||||
tool_call_id="2",
|
||||
),
|
||||
]
|
||||
expected = [
|
||||
HumanMessage( # type: ignore[misc]
|
||||
[
|
||||
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
|
||||
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
|
||||
]
|
||||
)
|
||||
]
|
||||
actual = _merge_messages(messages)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test__merge_messages_mutation() -> None:
|
||||
original_messages = [
|
||||
|
Loading…
Reference in New Issue
Block a user