You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/libs/core/tests/unit_tests/test_messages.py

431 lines
13 KiB
Python

import unittest
from typing import List
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
ToolMessage,
get_buffer_string,
message_chunk_to_message,
messages_from_dict,
messages_to_dict,
)
def test_message_chunks() -> None:
    """Check how message chunks combine under the ``+`` operator.

    Covers content concatenation, the result taking the class of the left
    operand, merging of distinct ``additional_kwargs`` keys, and piecewise
    assembly of a ``function_call`` payload spread over several chunks.
    """
    # Two chunks of the same class: content is concatenated.
    joined = AIMessageChunk(content="I am") + AIMessageChunk(content=" indeed.")
    assert joined == AIMessageChunk(
        content="I am indeed."
    ), "MessageChunk + MessageChunk should be a MessageChunk"

    # Chunks of different classes: the expected result has the left class.
    mixed = AIMessageChunk(content="I am") + HumanMessageChunk(content=" indeed.")
    assert mixed == AIMessageChunk(
        content="I am indeed."
    ), "MessageChunk + MessageChunk should be a MessageChunk of same class as the left side"  # noqa: E501

    # Disjoint additional_kwargs keys are expected to end up side by side.
    with_foo = AIMessageChunk(content="", additional_kwargs={"foo": "bar"})
    with_baz = AIMessageChunk(content="", additional_kwargs={"baz": "foo"})
    assert with_foo + with_baz == AIMessageChunk(
        content="", additional_kwargs={"foo": "bar", "baz": "foo"}
    ), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs"  # noqa: E501

    # A function call delivered in pieces; one piece carries arguments=None,
    # which the expected value below shows must not appear in the joined text.
    function_call_deltas = [
        {"name": "web_search"},
        {"arguments": None},
        {"arguments": "{\n"},
        {"arguments": ' "query": "turtles"\n}'},
    ]
    pieces = [
        AIMessageChunk(content="", additional_kwargs={"function_call": delta})
        for delta in function_call_deltas
    ]
    # Sum strictly left to right, mirroring a chained ``a + b + c + d``.
    assembled = pieces[0]
    for piece in pieces[1:]:
        assembled = assembled + piece
    expected_call = AIMessageChunk(
        content="",
        additional_kwargs={
            "function_call": {
                "name": "web_search",
                "arguments": '{\n "query": "turtles"\n}',
            }
        },
    )
    assert (
        assembled == expected_call
    ), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs"  # noqa: E501
def test_chat_message_chunks() -> None:
    """Check ``+`` for ``ChatMessageChunk`` against itself and other chunks.

    Same-role chat chunks concatenate, different roles raise ``ValueError``,
    and when a chat chunk is mixed with another chunk class the expected
    result follows the left operand.
    """
    # Same role on both sides.
    same_role = ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
        role="User", content=" indeed."
    )
    assert same_role == ChatMessageChunk(
        role="User", content="I am indeed."
    ), "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"

    # Differing roles must be rejected. Both operands are built inside the
    # context manager so any ValueError from construction is also covered.
    with pytest.raises(ValueError):
        user_part = ChatMessageChunk(role="User", content="I am")
        assistant_part = ChatMessageChunk(role="Assistant", content=" indeed.")
        _ = user_part + assistant_part

    # Chat chunk on the left, AI chunk on the right.
    chat_then_ai = ChatMessageChunk(role="User", content="I am") + AIMessageChunk(
        content=" indeed."
    )
    assert chat_then_ai == ChatMessageChunk(
        role="User", content="I am indeed."
    ), "ChatMessageChunk + other MessageChunk should be a ChatMessageChunk with the left side's role"  # noqa: E501

    # AI chunk on the left, chat chunk on the right.
    ai_then_chat = AIMessageChunk(content="I am") + ChatMessageChunk(
        role="User", content=" indeed."
    )
    assert ai_then_chat == AIMessageChunk(
        content="I am indeed."
    ), "Other MessageChunk + ChatMessageChunk should be a MessageChunk as the left side"  # noqa: E501
def test_function_message_chunks() -> None:
    """Check ``+`` for ``FunctionMessageChunk``.

    Chunks sharing a name concatenate their content; chunks with different
    names are expected to raise ``ValueError``.
    """
    head = FunctionMessageChunk(name="hello", content="I am")
    tail = FunctionMessageChunk(name="hello", content=" indeed.")
    assert head + tail == FunctionMessageChunk(
        name="hello", content="I am indeed."
    ), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"

    # Operands are created inside the context manager so that a ValueError
    # raised at construction time would be caught here as well.
    with pytest.raises(ValueError):
        hello_part = FunctionMessageChunk(name="hello", content="I am")
        bye_part = FunctionMessageChunk(name="bye", content=" indeed.")
        _ = hello_part + bye_part
def test_ani_message_chunks() -> None:
    """Check ``+`` for ``AIMessageChunk`` with respect to the ``example`` flag.

    Chunks with the same ``example`` value concatenate; chunks whose
    ``example`` values differ are expected to raise ``ValueError``.
    """
    head = AIMessageChunk(example=True, content="I am")
    tail = AIMessageChunk(example=True, content=" indeed.")
    assert head + tail == AIMessageChunk(
        example=True, content="I am indeed."
    ), "AIMessageChunk + AIMessageChunk should be a AIMessageChunk"

    # Operands are created inside the context manager so that a ValueError
    # raised at construction time would be caught here as well.
    with pytest.raises(ValueError):
        example_part = AIMessageChunk(example=True, content="I am")
        plain_part = AIMessageChunk(example=False, content=" indeed.")
        _ = example_part + plain_part
class TestGetBufferString(unittest.TestCase):
    """Tests for ``get_buffer_string`` across the message types built in setUp."""

    def setUp(self) -> None:
        """Create one message of each kind for the tests to share."""
        self.human_msg = HumanMessage(content="human")
        self.ai_msg = AIMessage(content="ai")
        self.sys_msg = SystemMessage(content="system")
        self.func_msg = FunctionMessage(name="func", content="function")
        self.tool_msg = ToolMessage(tool_call_id="tool_id", content="tool")
        self.chat_msg = ChatMessage(role="Chat", content="chat")

    def test_empty_input(self) -> None:
        """An empty message list yields an empty string."""
        rendered = get_buffer_string([])
        self.assertEqual(rendered, "")

    def test_valid_single_message(self) -> None:
        """A single human message is rendered with the ``Human`` prefix."""
        rendered = get_buffer_string([self.human_msg])
        wanted = f"Human: {self.human_msg.content}"
        self.assertEqual(rendered, wanted)

    def test_custom_human_prefix(self) -> None:
        """``human_prefix`` replaces the prefix used for human messages."""
        prefix = "H"
        rendered = get_buffer_string([self.human_msg], human_prefix=prefix)
        wanted = f"{prefix}: {self.human_msg.content}"
        self.assertEqual(rendered, wanted)

    def test_custom_ai_prefix(self) -> None:
        """``ai_prefix`` replaces the prefix used for AI messages."""
        prefix = "A"
        rendered = get_buffer_string([self.ai_msg], ai_prefix=prefix)
        wanted = f"{prefix}: {self.ai_msg.content}"
        self.assertEqual(rendered, wanted)

    def test_multiple_msg(self) -> None:
        """Several messages are rendered one per line, in input order."""
        conversation = [
            self.human_msg,
            self.ai_msg,
            self.sys_msg,
            self.func_msg,
            self.tool_msg,
            self.chat_msg,
        ]
        wanted_lines = (
            "Human: human",
            "AI: ai",
            "System: system",
            "Function: function",
            "Tool: tool",
            "Chat: chat",
        )
        rendered = get_buffer_string(conversation)
        self.assertEqual(rendered, "\n".join(wanted_lines))
def test_multiple_msg() -> None:
    """Check that messages survive a ``messages_to_dict`` round trip.

    A human message (with ``additional_kwargs``), an AI message and a system
    message are serialised and restored; the result must equal the input.
    """
    originals = [
        HumanMessage(content="human", additional_kwargs={"key": "value"}),
        AIMessage(content="ai"),
        SystemMessage(content="sys"),
    ]
    serialised = messages_to_dict(originals)
    restored = messages_from_dict(serialised)
    assert restored == originals
def test_message_chunk_to_message() -> None:
    """Check that each chunk class converts to its non-chunk message class.

    Every case pairs a chunk with the message ``message_chunk_to_message``
    is expected to return for it.
    """
    cases = [
        (
            AIMessageChunk(content="I am", additional_kwargs={"foo": "bar"}),
            AIMessage(content="I am", additional_kwargs={"foo": "bar"}),
        ),
        (
            HumanMessageChunk(content="I am"),
            HumanMessage(content="I am"),
        ),
        (
            ChatMessageChunk(role="User", content="I am"),
            ChatMessage(role="User", content="I am"),
        ),
        (
            FunctionMessageChunk(name="hello", content="I am"),
            FunctionMessage(name="hello", content="I am"),
        ),
    ]
    for chunk, expected_message in cases:
        assert message_chunk_to_message(chunk) == expected_message
def test_tool_calls_merge() -> None:
    """Check that tool-call deltas spread over many chunks sum correctly.

    The chunk sequence starts and ends with an empty-content chunk. In
    between, two tool calls (index 0 and index 1) each arrive as an opening
    delta carrying id, name and type, followed by deltas that carry only a
    fragment of the JSON arguments. Summing every chunk must give a single
    chunk whose ``tool_calls`` list holds both calls with their argument
    fragments joined in order.
    """

    def tool_call_delta(
        index: int, arguments: str, call_id=None, name=None, call_type=None
    ) -> dict:
        """Build ``AIMessageChunk`` kwargs carrying one tool-call delta."""
        return {
            "content": "",
            "additional_kwargs": {
                "tool_calls": [
                    {
                        "index": index,
                        "id": call_id,
                        "function": {"arguments": arguments, "name": name},
                        "type": call_type,
                    }
                ]
            },
        }

    # Argument fragments in arrival order for each of the two calls.
    first_call_fragments = ('{"na', 'me": ', '"jane"', ', "a', 'ge": ', "2}")
    second_call_fragments = ('{"na', 'me": ', '"bob",', ' "ag', 'e": 3', "}")

    chunks: List[dict] = [
        {"content": ""},
        tool_call_delta(
            0,
            "",
            call_id="call_CwGAsESnXehQEjiAIWzinlva",
            name="person",
            call_type="function",
        ),
        *(tool_call_delta(0, fragment) for fragment in first_call_fragments),
        tool_call_delta(
            1,
            "",
            call_id="call_zXSIylHvc5x3JUAPcHZR5GZI",
            name="person",
            call_type="function",
        ),
        *(tool_call_delta(1, fragment) for fragment in second_call_fragments),
        {"content": ""},
    ]

    # Seed with the first chunk, then add the rest strictly left to right.
    # Plain ``+`` is used (not ``+=``) so only the addition operator runs.
    merged = AIMessageChunk(**chunks[0])
    for chunk_kwargs in chunks[1:]:
        merged = merged + AIMessageChunk(**chunk_kwargs)

    expected = AIMessageChunk(
        content="",
        additional_kwargs={
            "tool_calls": [
                {
                    "index": 0,
                    "id": "call_CwGAsESnXehQEjiAIWzinlva",
                    "function": {
                        "arguments": '{"name": "jane", "age": 2}',
                        "name": "person",
                    },
                    "type": "function",
                },
                {
                    "index": 1,
                    "id": "call_zXSIylHvc5x3JUAPcHZR5GZI",
                    "function": {
                        "arguments": '{"name": "bob", "age": 3}',
                        "name": "person",
                    },
                    "type": "function",
                },
            ]
        },
    )
    assert merged == expected