Fix output types for BaseChatModel (#11670)

* invoke/batch should return non-chunked messages, so OutputType should use the plain message classes.
* After this PR, the stream output type is not represented; do we want to
use the union? (See the sketch below.)
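
For context, a minimal sketch of the distinction (this assumes `FakeListChatModel` from `langchain.chat_models.fake` purely for illustration; any `BaseChatModel` subclass behaves the same way):

```python
from langchain.chat_models.fake import FakeListChatModel
from langchain.schema.messages import AIMessage, AIMessageChunk

model = FakeListChatModel(responses=["hello", "world"])

# invoke/batch return full, non-chunked messages, so OutputType should
# advertise AIMessage/HumanMessage/... rather than the *Chunk variants.
result = model.invoke("hi")
assert isinstance(result, AIMessage)
assert not isinstance(result, AIMessageChunk)

batched = model.batch(["hi", "again"])
assert all(isinstance(m, AIMessage) for m in batched)

# stream(), by contrast, yields *MessageChunk objects for models that
# implement _stream; that path is no longer covered by OutputType.
```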

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Eugene Yurtsev authored 11 months ago; committed by GitHub
parent 7d0dda7e41
commit 539941281d

@@ -38,14 +38,12 @@ from langchain.schema import (
 from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
 from langchain.schema.messages import (
     AIMessage,
-    AIMessageChunk,
     BaseMessage,
     BaseMessageChunk,
-    ChatMessageChunk,
-    FunctionMessageChunk,
+    ChatMessage,
+    FunctionMessage,
     HumanMessage,
-    HumanMessageChunk,
-    SystemMessageChunk,
+    SystemMessage,
 )
 from langchain.schema.output import ChatGenerationChunk
 from langchain.schema.runnable import RunnableConfig
@@ -115,13 +113,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
     @property
     def OutputType(self) -> Any:
-        """Get the input type for this runnable."""
+        """Get the output type for this runnable."""
         return Union[
-            HumanMessageChunk,
-            AIMessageChunk,
-            ChatMessageChunk,
-            FunctionMessageChunk,
-            SystemMessageChunk,
+            HumanMessage, AIMessage, ChatMessage, FunctionMessage, SystemMessage
         ]

     def _convert_input(self, input: LanguageModelInput) -> PromptValue:
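
The snapshot updates below follow from this change: a chat model's `output_schema` is derived from `OutputType`, so its JSON schema now references the plain message definitions. A quick way to check locally (a sketch, again using `FakeListChatModel` as a stand-in; any chat model works):

```python
from langchain.chat_models.fake import FakeListChatModel

model = FakeListChatModel(responses=["ok"])
schema = model.output_schema.schema()

# With this PR the union members are the non-chunked message definitions,
# matching the updated snapshot below.
print([member["$ref"] for member in schema["anyOf"]])
# ['#/definitions/HumanMessage', '#/definitions/AIMessage',
#  '#/definitions/ChatMessage', '#/definitions/FunctionMessage',
#  '#/definitions/SystemMessage']
```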

@@ -2163,24 +2163,24 @@
   dict({
     'anyOf': list([
       dict({
-        '$ref': '#/definitions/HumanMessageChunk',
+        '$ref': '#/definitions/HumanMessage',
       }),
       dict({
-        '$ref': '#/definitions/AIMessageChunk',
+        '$ref': '#/definitions/AIMessage',
       }),
       dict({
-        '$ref': '#/definitions/ChatMessageChunk',
+        '$ref': '#/definitions/ChatMessage',
       }),
       dict({
-        '$ref': '#/definitions/FunctionMessageChunk',
+        '$ref': '#/definitions/FunctionMessage',
       }),
       dict({
-        '$ref': '#/definitions/SystemMessageChunk',
+        '$ref': '#/definitions/SystemMessage',
       }),
     ]),
     'definitions': dict({
-      'AIMessageChunk': dict({
-        'description': 'A Message chunk from an AI.',
+      'AIMessage': dict({
+        'description': 'A Message from an AI.',
         'properties': dict({
           'additional_kwargs': dict({
             'title': 'Additional Kwargs',
@@ -2196,9 +2196,9 @@
             'type': 'boolean',
           }),
           'type': dict({
-            'default': 'AIMessageChunk',
+            'default': 'ai',
             'enum': list([
-              'AIMessageChunk',
+              'ai',
             ]),
             'title': 'Type',
             'type': 'string',
@@ -2207,11 +2207,11 @@
         'required': list([
           'content',
         ]),
-        'title': 'AIMessageChunk',
+        'title': 'AIMessage',
         'type': 'object',
       }),
-      'ChatMessageChunk': dict({
-        'description': 'A Chat Message chunk.',
+      'ChatMessage': dict({
+        'description': 'A Message that can be assigned an arbitrary speaker (i.e. role).',
         'properties': dict({
           'additional_kwargs': dict({
             'title': 'Additional Kwargs',
@@ -2226,9 +2226,9 @@
             'type': 'string',
           }),
           'type': dict({
-            'default': 'ChatMessageChunk',
+            'default': 'chat',
             'enum': list([
-              'ChatMessageChunk',
+              'chat',
             ]),
             'title': 'Type',
             'type': 'string',
@@ -2238,11 +2238,11 @@
           'content',
           'role',
         ]),
-        'title': 'ChatMessageChunk',
+        'title': 'ChatMessage',
         'type': 'object',
       }),
-      'FunctionMessageChunk': dict({
-        'description': 'A Function Message chunk.',
+      'FunctionMessage': dict({
+        'description': 'A Message for passing the result of executing a function back to a model.',
         'properties': dict({
           'additional_kwargs': dict({
             'title': 'Additional Kwargs',
@@ -2257,9 +2257,9 @@
             'type': 'string',
           }),
           'type': dict({
-            'default': 'FunctionMessageChunk',
+            'default': 'function',
             'enum': list([
-              'FunctionMessageChunk',
+              'function',
             ]),
             'title': 'Type',
             'type': 'string',
@@ -2269,11 +2269,11 @@
           'content',
           'name',
         ]),
-        'title': 'FunctionMessageChunk',
+        'title': 'FunctionMessage',
         'type': 'object',
       }),
-      'HumanMessageChunk': dict({
-        'description': 'A Human Message chunk.',
+      'HumanMessage': dict({
+        'description': 'A Message from a human.',
         'properties': dict({
           'additional_kwargs': dict({
             'title': 'Additional Kwargs',
@@ -2289,9 +2289,9 @@
             'type': 'boolean',
           }),
           'type': dict({
-            'default': 'HumanMessageChunk',
+            'default': 'human',
             'enum': list([
-              'HumanMessageChunk',
+              'human',
             ]),
             'title': 'Type',
             'type': 'string',
@@ -2300,11 +2300,14 @@
         'required': list([
           'content',
         ]),
-        'title': 'HumanMessageChunk',
+        'title': 'HumanMessage',
         'type': 'object',
       }),
-      'SystemMessageChunk': dict({
-        'description': 'A System Message chunk.',
+      'SystemMessage': dict({
+        'description': '''
+          A Message for priming AI behavior, usually passed in as the first of a sequence
+          of input messages.
+        ''',
         'properties': dict({
           'additional_kwargs': dict({
             'title': 'Additional Kwargs',
@@ -2315,9 +2318,9 @@
             'type': 'string',
           }),
           'type': dict({
-            'default': 'SystemMessageChunk',
+            'default': 'system',
             'enum': list([
-              'SystemMessageChunk',
+              'system',
             ]),
             'title': 'Type',
             'type': 'string',
@@ -2326,7 +2329,7 @@
         'required': list([
           'content',
         ]),
-        'title': 'SystemMessageChunk',
+        'title': 'SystemMessage',
         'type': 'object',
       }),
     }),
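
The `type` defaults in the updated definitions ('ai', 'chat', 'function', 'human', 'system') simply mirror the discriminator field each non-chunked message class already carries; a quick sanity check:

```python
from langchain.schema.messages import (
    AIMessage,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)

# Each message class sets a fixed `type` discriminator, which is what the
# schema defaults above reflect.
assert AIMessage(content="hi").type == "ai"
assert HumanMessage(content="hi").type == "human"
assert SystemMessage(content="be concise").type == "system"
assert ChatMessage(content="hi", role="assistant").type == "chat"
assert FunctionMessage(content="42", name="add").type == "function"
```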
