From 8b4cb4eb60e3935eea895aa955d68ca0afce788c Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 29 Sep 2023 15:14:52 -0400 Subject: [PATCH] Add type to message chunks (#11232) --- libs/langchain/langchain/schema/messages.py | 30 ++- .../runnable/__snapshots__/test_runnable.ambr | 240 ++++++++++++++++++ .../schema/runnable/test_runnable.py | 6 +- .../langchain/tests/unit_tests/test_schema.py | 56 ++++ 4 files changed, 329 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/schema/messages.py b/libs/langchain/langchain/schema/messages.py index 003d133b9d..f2e33b0248 100644 --- a/libs/langchain/langchain/schema/messages.py +++ b/libs/langchain/langchain/schema/messages.py @@ -149,6 +149,7 @@ class HumanMessage(BaseMessage): """ type: Literal["human"] = "human" + is_chunk: Literal[False] = False HumanMessage.update_forward_refs() @@ -157,7 +158,10 @@ HumanMessage.update_forward_refs() class HumanMessageChunk(HumanMessage, BaseMessageChunk): """A Human Message chunk.""" - pass + # Ignoring mypy re-assignment here since we're overriding the value + # to make sure that the chunk variant can be discriminated from the + # non-chunk variant. + is_chunk: Literal[True] = True # type: ignore[assignment] class AIMessage(BaseMessage): @@ -169,6 +173,7 @@ class AIMessage(BaseMessage): """ type: Literal["ai"] = "ai" + is_chunk: Literal[False] = False AIMessage.update_forward_refs() @@ -177,6 +182,11 @@ AIMessage.update_forward_refs() class AIMessageChunk(AIMessage, BaseMessageChunk): """A Message chunk from an AI.""" + # Ignoring mypy re-assignment here since we're overriding the value + # to make sure that the chunk variant can be discriminated from the + # non-chunk variant. + is_chunk: Literal[True] = True # type: ignore[assignment] + def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore if isinstance(other, AIMessageChunk): if self.example != other.example: @@ -201,6 +211,7 @@ class SystemMessage(BaseMessage): """ type: Literal["system"] = "system" + is_chunk: Literal[False] = False SystemMessage.update_forward_refs() @@ -209,7 +220,10 @@ SystemMessage.update_forward_refs() class SystemMessageChunk(SystemMessage, BaseMessageChunk): """A System Message chunk.""" - pass + # Ignoring mypy re-assignment here since we're overriding the value + # to make sure that the chunk variant can be discriminated from the + # non-chunk variant. + is_chunk: Literal[True] = True # type: ignore[assignment] class FunctionMessage(BaseMessage): @@ -219,6 +233,7 @@ class FunctionMessage(BaseMessage): """The name of the function that was executed.""" type: Literal["function"] = "function" + is_chunk: Literal[False] = False FunctionMessage.update_forward_refs() @@ -227,6 +242,11 @@ FunctionMessage.update_forward_refs() class FunctionMessageChunk(FunctionMessage, BaseMessageChunk): """A Function Message chunk.""" + # Ignoring mypy re-assignment here since we're overriding the value + # to make sure that the chunk variant can be discriminated from the + # non-chunk variant. + is_chunk: Literal[True] = True # type: ignore[assignment] + def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore if isinstance(other, FunctionMessageChunk): if self.name != other.name: @@ -252,6 +272,7 @@ class ChatMessage(BaseMessage): """The speaker / role of the Message.""" type: Literal["chat"] = "chat" + is_chunk: Literal[False] = False ChatMessage.update_forward_refs() @@ -260,6 +281,11 @@ ChatMessage.update_forward_refs() class ChatMessageChunk(ChatMessage, BaseMessageChunk): """A Chat Message chunk.""" + # Ignoring mypy re-assignment here since we're overriding the value + # to make sure that the chunk variant can be discriminated from the + # non-chunk variant. + is_chunk: Literal[True] = True # type: ignore[assignment] + def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore if isinstance(other, ChatMessageChunk): if self.role != other.role: diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index 4637c17e65..4e7b4c53de 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -1693,6 +1693,14 @@ 'title': 'Example', 'type': 'boolean', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'ai', 'enum': list([ @@ -1719,6 +1727,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'role': dict({ 'title': 'Role', 'type': 'string', @@ -1786,6 +1802,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -1822,6 +1846,14 @@ 'title': 'Example', 'type': 'boolean', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'human', 'enum': list([ @@ -1865,6 +1897,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'system', 'enum': list([ @@ -1936,6 +1976,14 @@ 'title': 'Example', 'type': 'boolean', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'ai', 'enum': list([ @@ -1962,6 +2010,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'role': dict({ 'title': 'Role', 'type': 'string', @@ -2029,6 +2085,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -2065,6 +2129,14 @@ 'title': 'Example', 'type': 'boolean', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'human', 'enum': list([ @@ -2108,6 +2180,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'system', 'enum': list([ @@ -2163,6 +2243,14 @@ 'title': 'Example', 'type': 'boolean', }), + 'is_chunk': dict({ + 'default': True, + 'enum': list([ + True, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'ai', 'enum': list([ @@ -2189,6 +2277,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': True, + 'enum': list([ + True, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'role': dict({ 'title': 'Role', 'type': 'string', @@ -2220,6 +2316,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': True, + 'enum': list([ + True, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -2256,6 +2360,14 @@ 'title': 'Example', 'type': 'boolean', }), + 'is_chunk': dict({ + 'default': True, + 'enum': list([ + True, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'human', 'enum': list([ @@ -2282,6 +2394,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': True, + 'enum': list([ + True, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'system', 'enum': list([ @@ -2328,6 +2448,14 @@ 'title': 'Example', 'type': 'boolean', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'ai', 'enum': list([ @@ -2354,6 +2482,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'role': dict({ 'title': 'Role', 'type': 'string', @@ -2421,6 +2557,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -2457,6 +2601,14 @@ 'title': 'Example', 'type': 'boolean', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'human', 'enum': list([ @@ -2500,6 +2652,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'system', 'enum': list([ @@ -2538,6 +2698,14 @@ 'title': 'Example', 'type': 'boolean', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'ai', 'enum': list([ @@ -2564,6 +2732,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'role': dict({ 'title': 'Role', 'type': 'string', @@ -2631,6 +2807,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -2667,6 +2851,14 @@ 'title': 'Example', 'type': 'boolean', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'human', 'enum': list([ @@ -2721,6 +2913,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'system', 'enum': list([ @@ -2783,6 +2983,14 @@ 'title': 'Example', 'type': 'boolean', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'ai', 'enum': list([ @@ -2809,6 +3017,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'role': dict({ 'title': 'Role', 'type': 'string', @@ -2840,6 +3056,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'name': dict({ 'title': 'Name', 'type': 'string', @@ -2876,6 +3100,14 @@ 'title': 'Example', 'type': 'boolean', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'human', 'enum': list([ @@ -2905,6 +3137,14 @@ 'title': 'Content', 'type': 'string', }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), 'type': dict({ 'default': 'system', 'enum': list([ diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index b3484a8d8a..2031ec21cb 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -960,7 +960,11 @@ async def test_prompt_with_chat_model( tracer = FakeTracer() assert [ *chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer])) - ] == [AIMessage(content="f"), AIMessage(content="o"), AIMessage(content="o")] + ] == [ + AIMessageChunk(content="f"), + AIMessageChunk(content="o"), + AIMessageChunk(content="o"), + ] assert prompt_spy.call_args.args[1] == {"question": "What is your name?"} assert chat_spy.call_args.args[1] == ChatPromptValue( messages=[ diff --git a/libs/langchain/tests/unit_tests/test_schema.py b/libs/langchain/tests/unit_tests/test_schema.py index 3fa2be84e1..4b72ddde1c 100644 --- a/libs/langchain/tests/unit_tests/test_schema.py +++ b/libs/langchain/tests/unit_tests/test_schema.py @@ -1,11 +1,20 @@ """Test formatting functionality.""" import unittest +from typing import Union +from langchain.pydantic_v1 import BaseModel from langchain.schema.messages import ( AIMessage, + AIMessageChunk, + ChatMessage, + ChatMessageChunk, + FunctionMessage, + FunctionMessageChunk, HumanMessage, + HumanMessageChunk, SystemMessage, + SystemMessageChunk, get_buffer_string, messages_from_dict, messages_to_dict, @@ -70,3 +79,50 @@ def test_multiple_msg() -> None: sys_msg, ] assert messages_from_dict(messages_to_dict(msgs)) == msgs + + +def test_distinguish_messages() -> None: + """Test that pydantic is able to discriminate between similar looking messages.""" + + class WellKnownTypes(BaseModel): + __root__: Union[ + HumanMessage, + AIMessage, + SystemMessage, + FunctionMessage, + HumanMessageChunk, + AIMessageChunk, + SystemMessageChunk, + FunctionMessageChunk, + ChatMessageChunk, + ChatMessage, + ] + + messages = [ + HumanMessage(content="human"), + HumanMessageChunk(content="human"), + AIMessage(content="ai"), + AIMessageChunk(content="ai"), + SystemMessage(content="sys"), + SystemMessageChunk(content="sys"), + FunctionMessage( + name="func", + content="func", + ), + FunctionMessageChunk( + name="func", + content="func", + ), + ChatMessage( + role="human", + content="human", + ), + ChatMessageChunk( + role="human", + content="human", + ), + ] + + for msg in messages: + obj1 = WellKnownTypes.parse_obj(msg.dict()) + assert type(obj1.__root__) == type(msg), f"failed for {type(msg)}"