mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
core[patch]: include tool_calls in ai msg chunk serialization (#20291)
This commit is contained in:
parent
0fa551c278
commit
03b247cca1
@ -1,5 +1,5 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Any, List, Literal
|
from typing import Any, Dict, List, Literal
|
||||||
|
|
||||||
from langchain_core.messages.base import (
|
from langchain_core.messages.base import (
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
@ -40,7 +40,15 @@ class AIMessage(BaseMessage):
|
|||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "schema", "messages"]
|
return ["langchain", "schema", "messages"]
|
||||||
|
|
||||||
@root_validator
|
@property
|
||||||
|
def lc_attributes(self) -> Dict:
|
||||||
|
"""Attrs to be serialized even if they are derived from other init args."""
|
||||||
|
return {
|
||||||
|
"tool_calls": self.tool_calls,
|
||||||
|
"invalid_tool_calls": self.invalid_tool_calls,
|
||||||
|
}
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
|
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
|
||||||
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
|
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
|
||||||
tool_calls = (
|
tool_calls = (
|
||||||
@ -88,6 +96,14 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
|||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "schema", "messages"]
|
return ["langchain", "schema", "messages"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_attributes(self) -> Dict:
|
||||||
|
"""Attrs to be serialized even if they are derived from other init args."""
|
||||||
|
return {
|
||||||
|
"tool_calls": self.tool_calls,
|
||||||
|
"invalid_tool_calls": self.invalid_tool_calls,
|
||||||
|
}
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def init_tool_calls(cls, values: dict) -> dict:
|
def init_tool_calls(cls, values: dict) -> dict:
|
||||||
if not values["tool_call_chunks"]:
|
if not values["tool_call_chunks"]:
|
||||||
|
67
libs/core/tests/unit_tests/messages/test_ai.py
Normal file
67
libs/core/tests/unit_tests/messages/test_ai.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from langchain_core.load import dumpd, load
|
||||||
|
from langchain_core.messages import (
|
||||||
|
AIMessage,
|
||||||
|
AIMessageChunk,
|
||||||
|
InvalidToolCall,
|
||||||
|
ToolCall,
|
||||||
|
ToolCallChunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_serdes_message() -> None:
|
||||||
|
msg = AIMessage(
|
||||||
|
content=[{"text": "blah", "type": "text"}],
|
||||||
|
tool_calls=[ToolCall(name="foo", args={"bar": 1}, id="baz")],
|
||||||
|
invalid_tool_calls=[
|
||||||
|
InvalidToolCall(name="foobad", args="blah", id="booz", error="bad")
|
||||||
|
],
|
||||||
|
)
|
||||||
|
expected = {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": ["langchain", "schema", "messages", "AIMessage"],
|
||||||
|
"kwargs": {
|
||||||
|
"content": [{"text": "blah", "type": "text"}],
|
||||||
|
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
|
||||||
|
"invalid_tool_calls": [
|
||||||
|
{"name": "foobad", "args": "blah", "id": "booz", "error": "bad"}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
actual = dumpd(msg)
|
||||||
|
assert actual == expected
|
||||||
|
assert load(actual) == msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_serdes_message_chunk() -> None:
|
||||||
|
chunk = AIMessageChunk(
|
||||||
|
content=[{"text": "blah", "type": "text"}],
|
||||||
|
tool_call_chunks=[
|
||||||
|
ToolCallChunk(name="foo", args='{"bar": 1}', id="baz", index=0),
|
||||||
|
ToolCallChunk(name="foobad", args="blah", id="booz", index=1),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
expected = {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": ["langchain", "schema", "messages", "AIMessageChunk"],
|
||||||
|
"kwargs": {
|
||||||
|
"content": [{"text": "blah", "type": "text"}],
|
||||||
|
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
|
||||||
|
"invalid_tool_calls": [
|
||||||
|
{
|
||||||
|
"name": "foobad",
|
||||||
|
"args": "blah",
|
||||||
|
"id": "booz",
|
||||||
|
"error": "Malformed args.",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"tool_call_chunks": [
|
||||||
|
{"name": "foo", "args": '{"bar": 1}', "id": "baz", "index": 0},
|
||||||
|
{"name": "foobad", "args": "blah", "id": "booz", "index": 1},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
actual = dumpd(chunk)
|
||||||
|
assert actual == expected
|
||||||
|
assert load(actual) == chunk
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user