Add type to message chunks (#11232)

pull/11233/head
Eugene Yurtsev 10 months ago committed by GitHub
parent fb66b392c6
commit 8b4cb4eb60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -149,6 +149,7 @@ class HumanMessage(BaseMessage):
"""
type: Literal["human"] = "human"
is_chunk: Literal[False] = False
HumanMessage.update_forward_refs()
@ -157,7 +158,10 @@ HumanMessage.update_forward_refs()
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
"""A Human Message chunk."""
pass
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
is_chunk: Literal[True] = True # type: ignore[assignment]
class AIMessage(BaseMessage):
@ -169,6 +173,7 @@ class AIMessage(BaseMessage):
"""
type: Literal["ai"] = "ai"
is_chunk: Literal[False] = False
AIMessage.update_forward_refs()
@ -177,6 +182,11 @@ AIMessage.update_forward_refs()
class AIMessageChunk(AIMessage, BaseMessageChunk):
"""A Message chunk from an AI."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
is_chunk: Literal[True] = True # type: ignore[assignment]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
if self.example != other.example:
@ -201,6 +211,7 @@ class SystemMessage(BaseMessage):
"""
type: Literal["system"] = "system"
is_chunk: Literal[False] = False
SystemMessage.update_forward_refs()
@ -209,7 +220,10 @@ SystemMessage.update_forward_refs()
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
"""A System Message chunk."""
pass
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
is_chunk: Literal[True] = True # type: ignore[assignment]
class FunctionMessage(BaseMessage):
@ -219,6 +233,7 @@ class FunctionMessage(BaseMessage):
"""The name of the function that was executed."""
type: Literal["function"] = "function"
is_chunk: Literal[False] = False
FunctionMessage.update_forward_refs()
@ -227,6 +242,11 @@ FunctionMessage.update_forward_refs()
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""A Function Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
is_chunk: Literal[True] = True # type: ignore[assignment]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, FunctionMessageChunk):
if self.name != other.name:
@ -252,6 +272,7 @@ class ChatMessage(BaseMessage):
"""The speaker / role of the Message."""
type: Literal["chat"] = "chat"
is_chunk: Literal[False] = False
ChatMessage.update_forward_refs()
@ -260,6 +281,11 @@ ChatMessage.update_forward_refs()
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
"""A Chat Message chunk."""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
is_chunk: Literal[True] = True # type: ignore[assignment]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ChatMessageChunk):
if self.role != other.role:

@ -1693,6 +1693,14 @@
'title': 'Example',
'type': 'boolean',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'ai',
'enum': list([
@ -1719,6 +1727,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'role': dict({
'title': 'Role',
'type': 'string',
@ -1786,6 +1802,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@ -1822,6 +1846,14 @@
'title': 'Example',
'type': 'boolean',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'human',
'enum': list([
@ -1865,6 +1897,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'system',
'enum': list([
@ -1936,6 +1976,14 @@
'title': 'Example',
'type': 'boolean',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'ai',
'enum': list([
@ -1962,6 +2010,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'role': dict({
'title': 'Role',
'type': 'string',
@ -2029,6 +2085,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@ -2065,6 +2129,14 @@
'title': 'Example',
'type': 'boolean',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'human',
'enum': list([
@ -2108,6 +2180,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'system',
'enum': list([
@ -2163,6 +2243,14 @@
'title': 'Example',
'type': 'boolean',
}),
'is_chunk': dict({
'default': True,
'enum': list([
True,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'ai',
'enum': list([
@ -2189,6 +2277,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': True,
'enum': list([
True,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'role': dict({
'title': 'Role',
'type': 'string',
@ -2220,6 +2316,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': True,
'enum': list([
True,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@ -2256,6 +2360,14 @@
'title': 'Example',
'type': 'boolean',
}),
'is_chunk': dict({
'default': True,
'enum': list([
True,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'human',
'enum': list([
@ -2282,6 +2394,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': True,
'enum': list([
True,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'system',
'enum': list([
@ -2328,6 +2448,14 @@
'title': 'Example',
'type': 'boolean',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'ai',
'enum': list([
@ -2354,6 +2482,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'role': dict({
'title': 'Role',
'type': 'string',
@ -2421,6 +2557,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@ -2457,6 +2601,14 @@
'title': 'Example',
'type': 'boolean',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'human',
'enum': list([
@ -2500,6 +2652,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'system',
'enum': list([
@ -2538,6 +2698,14 @@
'title': 'Example',
'type': 'boolean',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'ai',
'enum': list([
@ -2564,6 +2732,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'role': dict({
'title': 'Role',
'type': 'string',
@ -2631,6 +2807,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@ -2667,6 +2851,14 @@
'title': 'Example',
'type': 'boolean',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'human',
'enum': list([
@ -2721,6 +2913,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'system',
'enum': list([
@ -2783,6 +2983,14 @@
'title': 'Example',
'type': 'boolean',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'ai',
'enum': list([
@ -2809,6 +3017,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'role': dict({
'title': 'Role',
'type': 'string',
@ -2840,6 +3056,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'name': dict({
'title': 'Name',
'type': 'string',
@ -2876,6 +3100,14 @@
'title': 'Example',
'type': 'boolean',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'human',
'enum': list([
@ -2905,6 +3137,14 @@
'title': 'Content',
'type': 'string',
}),
'is_chunk': dict({
'default': False,
'enum': list([
False,
]),
'title': 'Is Chunk',
'type': 'boolean',
}),
'type': dict({
'default': 'system',
'enum': list([

@ -960,7 +960,11 @@ async def test_prompt_with_chat_model(
tracer = FakeTracer()
assert [
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
] == [AIMessage(content="f"), AIMessage(content="o"), AIMessage(content="o")]
] == [
AIMessageChunk(content="f"),
AIMessageChunk(content="o"),
AIMessageChunk(content="o"),
]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[

@ -1,11 +1,20 @@
"""Test formatting functionality."""
import unittest
from typing import Union
from langchain.pydantic_v1 import BaseModel
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
get_buffer_string,
messages_from_dict,
messages_to_dict,
@ -70,3 +79,50 @@ def test_multiple_msg() -> None:
sys_msg,
]
assert messages_from_dict(messages_to_dict(msgs)) == msgs
def test_distinguish_messages() -> None:
"""Test that pydantic is able to discriminate between similar looking messages."""
class WellKnownTypes(BaseModel):
__root__: Union[
HumanMessage,
AIMessage,
SystemMessage,
FunctionMessage,
HumanMessageChunk,
AIMessageChunk,
SystemMessageChunk,
FunctionMessageChunk,
ChatMessageChunk,
ChatMessage,
]
messages = [
HumanMessage(content="human"),
HumanMessageChunk(content="human"),
AIMessage(content="ai"),
AIMessageChunk(content="ai"),
SystemMessage(content="sys"),
SystemMessageChunk(content="sys"),
FunctionMessage(
name="func",
content="func",
),
FunctionMessageChunk(
name="func",
content="func",
),
ChatMessage(
role="human",
content="human",
),
ChatMessageChunk(
role="human",
content="human",
),
]
for msg in messages:
obj1 = WellKnownTypes.parse_obj(msg.dict())
assert type(obj1.__root__) == type(msg), f"failed for {type(msg)}"

Loading…
Cancel
Save