mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
anthropic[patch]: always add tool_result type to ToolMessage content (#22721)
Anthropic tool results can contain image data, which is typically represented with content blocks having `"type": "image"`. Currently, these content blocks are passed as-is as human/user messages to Anthropic, which raises a BadRequestError because the API expects a tool_result block to follow a tool_use block.

Here we update ChatAnthropic to nest the content blocks inside a tool_result content block.

Example:

```python
import base64

import httpx
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field

# Fetch image
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
image_data = base64.b64encode(httpx.get(image_url).content).decode("utf-8")


class FetchImage(BaseModel):
    should_fetch: bool = Field(..., description="Whether an image is requested.")


llm = ChatAnthropic(model="claude-3-sonnet-20240229").bind_tools([FetchImage])

messages = [
    HumanMessage(content="Could you summon a beautiful image please?"),
    AIMessage(
        content=[
            {
                "type": "tool_use",
                "id": "toolu_01Rn6Qvj5m7955x9m9Pfxbcx",
                "name": "FetchImage",
                "input": {"should_fetch": True},
            },
        ],
        tool_calls=[
            {
                "name": "FetchImage",
                "args": {"should_fetch": True},
                "id": "toolu_01Rn6Qvj5m7955x9m9Pfxbcx",
            },
        ],
    ),
    ToolMessage(
        name="FetchImage",
        content=[
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/jpeg",
                    "data": image_data,
                },
            },
        ],
        tool_call_id="toolu_01Rn6Qvj5m7955x9m9Pfxbcx",
    ),
]

llm.invoke(messages)
```

Trace: https://smith.langchain.com/public/d27e4fc1-a96d-41e1-9f52-54f5004122db/r
This commit is contained in:
parent
7114aed78f
commit
73c76b9628
@ -104,7 +104,12 @@ def _merge_messages(
|
|||||||
for curr in messages:
|
for curr in messages:
|
||||||
curr = curr.copy(deep=True)
|
curr = curr.copy(deep=True)
|
||||||
if isinstance(curr, ToolMessage):
|
if isinstance(curr, ToolMessage):
|
||||||
if isinstance(curr.content, str):
|
if isinstance(curr.content, list) and all(
|
||||||
|
isinstance(block, dict) and block.get("type") == "tool_result"
|
||||||
|
for block in curr.content
|
||||||
|
):
|
||||||
|
curr = HumanMessage(curr.content) # type: ignore[misc]
|
||||||
|
else:
|
||||||
curr = HumanMessage( # type: ignore[misc]
|
curr = HumanMessage( # type: ignore[misc]
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
@ -114,8 +119,6 @@ def _merge_messages(
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
curr = HumanMessage(curr.content) # type: ignore[misc]
|
|
||||||
last = merged[-1] if merged else None
|
last = merged[-1] if merged else None
|
||||||
if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
|
if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
|
||||||
if isinstance(last.content, str):
|
if isinstance(last.content, str):
|
||||||
|
@ -140,7 +140,19 @@ def test__merge_messages() -> None:
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc]
|
ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc]
|
||||||
ToolMessage("blah output", tool_call_id="2"), # type: ignore[misc]
|
ToolMessage(
|
||||||
|
content=[
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/jpeg",
|
||||||
|
"data": "fake_image_data",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
tool_call_id="2",
|
||||||
|
), # type: ignore[misc]
|
||||||
HumanMessage("next thing"), # type: ignore[misc]
|
HumanMessage("next thing"), # type: ignore[misc]
|
||||||
]
|
]
|
||||||
expected = [
|
expected = [
|
||||||
@ -169,7 +181,20 @@ def test__merge_messages() -> None:
|
|||||||
HumanMessage( # type: ignore[misc]
|
HumanMessage( # type: ignore[misc]
|
||||||
[
|
[
|
||||||
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
|
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
|
||||||
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/jpeg",
|
||||||
|
"data": "fake_image_data",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"tool_use_id": "2",
|
||||||
|
},
|
||||||
{"type": "text", "text": "next thing"},
|
{"type": "text", "text": "next thing"},
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
@ -177,6 +202,27 @@ def test__merge_messages() -> None:
|
|||||||
actual = _merge_messages(messages)
|
actual = _merge_messages(messages)
|
||||||
assert expected == actual
|
assert expected == actual
|
||||||
|
|
||||||
|
# Test tool message case
|
||||||
|
messages = [
|
||||||
|
ToolMessage("buz output", tool_call_id="1"), # type: ignore[misc]
|
||||||
|
ToolMessage( # type: ignore[misc]
|
||||||
|
content=[
|
||||||
|
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"}
|
||||||
|
],
|
||||||
|
tool_call_id="2",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
expected = [
|
||||||
|
HumanMessage( # type: ignore[misc]
|
||||||
|
[
|
||||||
|
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
|
||||||
|
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
actual = _merge_messages(messages)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
def test__merge_messages_mutation() -> None:
|
def test__merge_messages_mutation() -> None:
|
||||||
original_messages = [
|
original_messages = [
|
||||||
|
Loading…
Reference in New Issue
Block a user