core[patch]: include tool_calls in ai msg chunk serialization (#20291)

This commit is contained in:
Bagatur 2024-04-10 17:27:40 -05:00 committed by GitHub
parent 0fa551c278
commit 03b247cca1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 89 additions and 6 deletions

View File

@ -1,5 +1,5 @@
import warnings import warnings
from typing import Any, List, Literal from typing import Any, Dict, List, Literal
from langchain_core.messages.base import ( from langchain_core.messages.base import (
BaseMessage, BaseMessage,
@ -40,7 +40,15 @@ class AIMessage(BaseMessage):
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]
@root_validator @property
def lc_attributes(self) -> Dict:
"""Attrs to be serialized even if they are derived from other init args."""
return {
"tool_calls": self.tool_calls,
"invalid_tool_calls": self.invalid_tool_calls,
}
@root_validator()
def _backwards_compat_tool_calls(cls, values: dict) -> dict: def _backwards_compat_tool_calls(cls, values: dict) -> dict:
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls") raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
tool_calls = ( tool_calls = (
@ -88,6 +96,14 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]
@property
def lc_attributes(self) -> Dict:
"""Attrs to be serialized even if they are derived from other init args."""
return {
"tool_calls": self.tool_calls,
"invalid_tool_calls": self.invalid_tool_calls,
}
@root_validator() @root_validator()
def init_tool_calls(cls, values: dict) -> dict: def init_tool_calls(cls, values: dict) -> dict:
if not values["tool_call_chunks"]: if not values["tool_call_chunks"]:

View File

@ -0,0 +1,67 @@
from langchain_core.load import dumpd, load
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
InvalidToolCall,
ToolCall,
ToolCallChunk,
)
def test_serdes_message() -> None:
msg = AIMessage(
content=[{"text": "blah", "type": "text"}],
tool_calls=[ToolCall(name="foo", args={"bar": 1}, id="baz")],
invalid_tool_calls=[
InvalidToolCall(name="foobad", args="blah", id="booz", error="bad")
],
)
expected = {
"lc": 1,
"type": "constructor",
"id": ["langchain", "schema", "messages", "AIMessage"],
"kwargs": {
"content": [{"text": "blah", "type": "text"}],
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
"invalid_tool_calls": [
{"name": "foobad", "args": "blah", "id": "booz", "error": "bad"}
],
},
}
actual = dumpd(msg)
assert actual == expected
assert load(actual) == msg
def test_serdes_message_chunk() -> None:
chunk = AIMessageChunk(
content=[{"text": "blah", "type": "text"}],
tool_call_chunks=[
ToolCallChunk(name="foo", args='{"bar": 1}', id="baz", index=0),
ToolCallChunk(name="foobad", args="blah", id="booz", index=1),
],
)
expected = {
"lc": 1,
"type": "constructor",
"id": ["langchain", "schema", "messages", "AIMessageChunk"],
"kwargs": {
"content": [{"text": "blah", "type": "text"}],
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
"invalid_tool_calls": [
{
"name": "foobad",
"args": "blah",
"id": "booz",
"error": "Malformed args.",
}
],
"tool_call_chunks": [
{"name": "foo", "args": '{"bar": 1}', "id": "baz", "index": 0},
{"name": "foobad", "args": "blah", "id": "booz", "index": 1},
],
},
}
actual = dumpd(chunk)
assert actual == expected
assert load(actual) == chunk

File diff suppressed because one or more lines are too long