Update chat model output type (#11833)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/12016/head^2
Nuno Campos 12 months ago committed by GitHub
parent ed62984cb2
commit 7db6aabf65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -11,7 +11,6 @@ from typing import (
List,
Optional,
Sequence,
Union,
cast,
)
@ -38,12 +37,10 @@ from langchain.schema import (
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain.schema.messages import (
AIMessage,
AnyMessage,
BaseMessage,
BaseMessageChunk,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig
@ -79,7 +76,7 @@ async def _agenerate_from_stream(
return ChatResult(generations=[generation])
class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Base class for Chat models."""
cache: Optional[bool] = None
@ -116,9 +113,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
@property
def OutputType(self) -> Any:
"""Get the output type for this runnable."""
return Union[
HumanMessage, AIMessage, ChatMessage, FunctionMessage, SystemMessage
]
return AnyMessage
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
@ -140,23 +135,20 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessageChunk:
) -> BaseMessage:
config = config or {}
return cast(
BaseMessageChunk,
cast(
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
).generations[0][0],
).message,
)
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)],
stop=stop,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
).generations[0][0],
).message
async def ainvoke(
self,
@ -165,7 +157,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> BaseMessageChunk:
) -> BaseMessage:
config = config or {}
llm_result = await self.agenerate_prompt(
[self._convert_input(input)],
@ -176,9 +168,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
run_name=config.get("run_name"),
**kwargs,
)
return cast(
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
)
return cast(ChatGeneration, llm_result.generations[0][0]).message
def stream(
self,
@ -190,7 +180,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
) -> Iterator[BaseMessageChunk]:
if type(self)._stream == BaseChatModel._stream:
# model doesn't implement streaming, so use default implementation
yield self.invoke(input, config=config, stop=stop, **kwargs)
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
@ -241,7 +233,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
) -> AsyncIterator[BaseMessageChunk]:
if type(self)._astream == BaseChatModel._astream:
# model doesn't implement streaming, so use default implementation
yield self.invoke(input, config=config, stop=stop, **kwargs)
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
)
else:
config = config or {}
messages = self._convert_input(input).to_messages()

@ -2163,19 +2163,19 @@
dict({
'anyOf': list([
dict({
'$ref': '#/definitions/HumanMessage',
'$ref': '#/definitions/AIMessage',
}),
dict({
'$ref': '#/definitions/AIMessage',
'$ref': '#/definitions/HumanMessage',
}),
dict({
'$ref': '#/definitions/ChatMessage',
}),
dict({
'$ref': '#/definitions/FunctionMessage',
'$ref': '#/definitions/SystemMessage',
}),
dict({
'$ref': '#/definitions/SystemMessage',
'$ref': '#/definitions/FunctionMessage',
}),
]),
'definitions': dict({

Loading…
Cancel
Save