From 539941281ddb6cca5b9949eab193035711c4cee1 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 11 Oct 2023 16:02:03 -0400 Subject: [PATCH] Fix output types for BaseChatModel (#11670) * Should use non chunked messages for Invoke/Batch * After this PR, stream output type is not represented, do we want to use the union? --------- Co-authored-by: Erick Friis --- libs/langchain/langchain/chat_models/base.py | 16 ++--- .../runnable/__snapshots__/test_runnable.ambr | 63 ++++++++++--------- 2 files changed, 38 insertions(+), 41 deletions(-) diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index afb6926742..08d96e6bbf 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -38,14 +38,12 @@ from langchain.schema import ( from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput from langchain.schema.messages import ( AIMessage, - AIMessageChunk, BaseMessage, BaseMessageChunk, - ChatMessageChunk, - FunctionMessageChunk, + ChatMessage, + FunctionMessage, HumanMessage, - HumanMessageChunk, - SystemMessageChunk, + SystemMessage, ) from langchain.schema.output import ChatGenerationChunk from langchain.schema.runnable import RunnableConfig @@ -115,13 +113,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): @property def OutputType(self) -> Any: - """Get the input type for this runnable.""" + """Get the output type for this runnable.""" return Union[ - HumanMessageChunk, - AIMessageChunk, - ChatMessageChunk, - FunctionMessageChunk, - SystemMessageChunk, + HumanMessage, AIMessage, ChatMessage, FunctionMessage, SystemMessage ] def _convert_input(self, input: LanguageModelInput) -> PromptValue: diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index fb9d63ee91..4a7c638359 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -2163,24 +2163,24 @@ dict({ 'anyOf': list([ dict({ - '$ref': '#/definitions/HumanMessageChunk', + '$ref': '#/definitions/HumanMessage', }), dict({ - '$ref': '#/definitions/AIMessageChunk', + '$ref': '#/definitions/AIMessage', }), dict({ - '$ref': '#/definitions/ChatMessageChunk', + '$ref': '#/definitions/ChatMessage', }), dict({ - '$ref': '#/definitions/FunctionMessageChunk', + '$ref': '#/definitions/FunctionMessage', }), dict({ - '$ref': '#/definitions/SystemMessageChunk', + '$ref': '#/definitions/SystemMessage', }), ]), 'definitions': dict({ - 'AIMessageChunk': dict({ - 'description': 'A Message chunk from an AI.', + 'AIMessage': dict({ + 'description': 'A Message from an AI.', 'properties': dict({ 'additional_kwargs': dict({ 'title': 'Additional Kwargs', @@ -2196,9 +2196,9 @@ 'type': 'boolean', }), 'type': dict({ - 'default': 'AIMessageChunk', + 'default': 'ai', 'enum': list([ - 'AIMessageChunk', + 'ai', ]), 'title': 'Type', 'type': 'string', @@ -2207,11 +2207,11 @@ 'required': list([ 'content', ]), - 'title': 'AIMessageChunk', + 'title': 'AIMessage', 'type': 'object', }), - 'ChatMessageChunk': dict({ - 'description': 'A Chat Message chunk.', + 'ChatMessage': dict({ + 'description': 'A Message that can be assigned an arbitrary speaker (i.e. role).', 'properties': dict({ 'additional_kwargs': dict({ 'title': 'Additional Kwargs', @@ -2226,9 +2226,9 @@ 'type': 'string', }), 'type': dict({ - 'default': 'ChatMessageChunk', + 'default': 'chat', 'enum': list([ - 'ChatMessageChunk', + 'chat', ]), 'title': 'Type', 'type': 'string', @@ -2238,11 +2238,11 @@ 'content', 'role', ]), - 'title': 'ChatMessageChunk', + 'title': 'ChatMessage', 'type': 'object', }), - 'FunctionMessageChunk': dict({ - 'description': 'A Function Message chunk.', + 'FunctionMessage': dict({ + 'description': 'A Message for passing the result of executing a function back to a model.', 'properties': dict({ 'additional_kwargs': dict({ 'title': 'Additional Kwargs', @@ -2257,9 +2257,9 @@ 'type': 'string', }), 'type': dict({ - 'default': 'FunctionMessageChunk', + 'default': 'function', 'enum': list([ - 'FunctionMessageChunk', + 'function', ]), 'title': 'Type', 'type': 'string', @@ -2269,11 +2269,11 @@ 'content', 'name', ]), - 'title': 'FunctionMessageChunk', + 'title': 'FunctionMessage', 'type': 'object', }), - 'HumanMessageChunk': dict({ - 'description': 'A Human Message chunk.', + 'HumanMessage': dict({ + 'description': 'A Message from a human.', 'properties': dict({ 'additional_kwargs': dict({ 'title': 'Additional Kwargs', @@ -2289,9 +2289,9 @@ 'type': 'boolean', }), 'type': dict({ - 'default': 'HumanMessageChunk', + 'default': 'human', 'enum': list([ - 'HumanMessageChunk', + 'human', ]), 'title': 'Type', 'type': 'string', @@ -2300,11 +2300,14 @@ 'required': list([ 'content', ]), - 'title': 'HumanMessageChunk', + 'title': 'HumanMessage', 'type': 'object', }), - 'SystemMessageChunk': dict({ - 'description': 'A System Message chunk.', + 'SystemMessage': dict({ + 'description': ''' + A Message for priming AI behavior, usually passed in as the first of a sequence + of input messages. + ''', 'properties': dict({ 'additional_kwargs': dict({ 'title': 'Additional Kwargs', @@ -2315,9 +2318,9 @@ 'type': 'string', }), 'type': dict({ - 'default': 'SystemMessageChunk', + 'default': 'system', 'enum': list([ - 'SystemMessageChunk', + 'system', ]), 'title': 'Type', 'type': 'string', @@ -2326,7 +2329,7 @@ 'required': list([ 'content', ]), - 'title': 'SystemMessageChunk', + 'title': 'SystemMessage', 'type': 'object', }), }),