diff --git a/libs/langchain/langchain/schema/messages.py b/libs/langchain/langchain/schema/messages.py index 0fbef0ba49..e0134da0d8 100644 --- a/libs/langchain/langchain/schema/messages.py +++ b/libs/langchain/langchain/schema/messages.py @@ -117,6 +117,14 @@ class BaseMessageChunk(BaseMessage): # If both are (subclasses of) BaseMessageChunk, # concat into a single BaseMessageChunk + if isinstance(self, ChatMessageChunk): + return self.__class__( + role=self.role, + content=self.content + other.content, + additional_kwargs=self._merge_kwargs_dict( + self.additional_kwargs, other.additional_kwargs + ), + ) return self.__class__( content=self.content + other.content, additional_kwargs=self._merge_kwargs_dict( @@ -168,7 +176,22 @@ class AIMessage(BaseMessage): class AIMessageChunk(AIMessage, BaseMessageChunk): """A Message chunk from an AI.""" - pass + def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore + if isinstance(other, AIMessageChunk): + if self.example != other.example: + raise ValueError( + "Cannot concatenate AIMessageChunks with different example values." + ) + + return self.__class__( + example=self.example, + content=self.content + other.content, + additional_kwargs=self._merge_kwargs_dict( + self.additional_kwargs, other.additional_kwargs + ), + ) + + return super().__add__(other) class SystemMessage(BaseMessage): @@ -203,7 +226,22 @@ class FunctionMessage(BaseMessage): class FunctionMessageChunk(FunctionMessage, BaseMessageChunk): """A Function Message chunk.""" - pass + def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore + if isinstance(other, FunctionMessageChunk): + if self.name != other.name: + raise ValueError( + "Cannot concatenate FunctionMessageChunks with different names." + ) + + return self.__class__( + name=self.name, + content=self.content + other.content, + additional_kwargs=self._merge_kwargs_dict( + self.additional_kwargs, other.additional_kwargs + ), + ) + + return super().__add__(other) class ChatMessage(BaseMessage): @@ -221,7 +259,22 @@ class ChatMessage(BaseMessage): class ChatMessageChunk(ChatMessage, BaseMessageChunk): """A Chat Message chunk.""" - pass + def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore + if isinstance(other, ChatMessageChunk): + if self.role != other.role: + raise ValueError( + "Cannot concatenate ChatMessageChunks with different roles." + ) + + return self.__class__( + role=self.role, + content=self.content + other.content, + additional_kwargs=self._merge_kwargs_dict( + self.additional_kwargs, other.additional_kwargs + ), + ) + + return super().__add__(other) def _message_to_dict(message: BaseMessage) -> dict: diff --git a/libs/langchain/tests/unit_tests/schema/test_messages.py b/libs/langchain/tests/unit_tests/schema/test_messages.py index 25c1a2b072..36c22730e8 100644 --- a/libs/langchain/tests/unit_tests/schema/test_messages.py +++ b/libs/langchain/tests/unit_tests/schema/test_messages.py @@ -1,4 +1,11 @@ -from langchain.schema.messages import AIMessageChunk, HumanMessageChunk +import pytest + +from langchain.schema.messages import ( + AIMessageChunk, + ChatMessageChunk, + FunctionMessageChunk, + HumanMessageChunk, +) def test_message_chunks() -> None: @@ -36,3 +43,54 @@ def test_message_chunks() -> None: } }, ), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501 + + +def test_chat_message_chunks() -> None: + assert ChatMessageChunk(role="User", content="I am") + ChatMessageChunk( + role="User", content=" indeed." + ) == ChatMessageChunk( + role="User", content="I am indeed." + ), "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk" + + with pytest.raises(ValueError): + ChatMessageChunk(role="User", content="I am") + ChatMessageChunk( + role="Assistant", content=" indeed." + ) + + assert ChatMessageChunk(role="User", content="I am") + AIMessageChunk( + content=" indeed." + ) == ChatMessageChunk( + role="User", content="I am indeed." + ), "ChatMessageChunk + other MessageChunk should be a ChatMessageChunk with the left side's role" # noqa: E501 + + assert AIMessageChunk(content="I am") + ChatMessageChunk( + role="User", content=" indeed." + ) == AIMessageChunk( + content="I am indeed." + ), "Other MessageChunk + ChatMessageChunk should be a MessageChunk as the left side" # noqa: E501 + + +def test_function_message_chunks() -> None: + assert FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk( + name="hello", content=" indeed." + ) == FunctionMessageChunk( + name="hello", content="I am indeed." + ), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk" + + with pytest.raises(ValueError): + FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk( + name="bye", content=" indeed." + ) + + +def test_ani_message_chunks() -> None: + assert AIMessageChunk(example=True, content="I am") + AIMessageChunk( + example=True, content=" indeed." + ) == AIMessageChunk( + example=True, content="I am indeed." + ), "AIMessageChunk + AIMessageChunk should be a AIMessageChunk" + + with pytest.raises(ValueError): + AIMessageChunk(example=True, content="I am") + AIMessageChunk( + example=False, content=" indeed." + )