fix ChatMessageChunk concat error (#10174)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. These live is docs/extras
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17, @rlancemartin.
 -->

- Description: fix `ChatMessageChunk` concat error 
- Issue: #10173 
- Dependencies: None
- Tag maintainer: @baskaryan, @eyurtsev, @rlancemartin
- Twitter handle: None

---------

Co-authored-by: wangshuai.scotty <wangshuai.scotty@bytedance.com>
Co-authored-by: Nuno Campos <nuno@boringbits.io>
pull/10956/head^2
Scotty 11 months ago committed by GitHub
parent 4322b246aa
commit 88a02076af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -117,6 +117,14 @@ class BaseMessageChunk(BaseMessage):
# If both are (subclasses of) BaseMessageChunk,
# concat into a single BaseMessageChunk
if isinstance(self, ChatMessageChunk):
return self.__class__(
role=self.role,
content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return self.__class__(
content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict(
@ -168,7 +176,22 @@ class AIMessage(BaseMessage):
class AIMessageChunk(AIMessage, BaseMessageChunk):
"""A Message chunk from an AI."""
pass
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
if self.example != other.example:
raise ValueError(
"Cannot concatenate AIMessageChunks with different example values."
)
return self.__class__(
example=self.example,
content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)
class SystemMessage(BaseMessage):
@ -203,7 +226,22 @@ class FunctionMessage(BaseMessage):
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""A Function Message chunk."""
pass
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, FunctionMessageChunk):
if self.name != other.name:
raise ValueError(
"Cannot concatenate FunctionMessageChunks with different names."
)
return self.__class__(
name=self.name,
content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)
class ChatMessage(BaseMessage):
@ -221,7 +259,22 @@ class ChatMessage(BaseMessage):
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
"""A Chat Message chunk."""
pass
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ChatMessageChunk):
if self.role != other.role:
raise ValueError(
"Cannot concatenate ChatMessageChunks with different roles."
)
return self.__class__(
role=self.role,
content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
return super().__add__(other)
def _message_to_dict(message: BaseMessage) -> dict:

@ -1,4 +1,11 @@
from langchain.schema.messages import AIMessageChunk, HumanMessageChunk
import pytest
from langchain.schema.messages import (
AIMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
)
def test_message_chunks() -> None:
@ -36,3 +43,54 @@ def test_message_chunks() -> None:
}
},
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
def test_chat_message_chunks() -> None:
assert ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
role="User", content=" indeed."
) == ChatMessageChunk(
role="User", content="I am indeed."
), "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"
with pytest.raises(ValueError):
ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
role="Assistant", content=" indeed."
)
assert ChatMessageChunk(role="User", content="I am") + AIMessageChunk(
content=" indeed."
) == ChatMessageChunk(
role="User", content="I am indeed."
), "ChatMessageChunk + other MessageChunk should be a ChatMessageChunk with the left side's role" # noqa: E501
assert AIMessageChunk(content="I am") + ChatMessageChunk(
role="User", content=" indeed."
) == AIMessageChunk(
content="I am indeed."
), "Other MessageChunk + ChatMessageChunk should be a MessageChunk as the left side" # noqa: E501
def test_function_message_chunks() -> None:
assert FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
name="hello", content=" indeed."
) == FunctionMessageChunk(
name="hello", content="I am indeed."
), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"
with pytest.raises(ValueError):
FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
name="bye", content=" indeed."
)
def test_ani_message_chunks() -> None:
assert AIMessageChunk(example=True, content="I am") + AIMessageChunk(
example=True, content=" indeed."
) == AIMessageChunk(
example=True, content="I am indeed."
), "AIMessageChunk + AIMessageChunk should be a AIMessageChunk"
with pytest.raises(ValueError):
AIMessageChunk(example=True, content="I am") + AIMessageChunk(
example=False, content=" indeed."
)

Loading…
Cancel
Save