|
|
|
@ -11,7 +11,6 @@ from typing import (
|
|
|
|
|
List,
|
|
|
|
|
Optional,
|
|
|
|
|
Sequence,
|
|
|
|
|
Union,
|
|
|
|
|
cast,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -38,12 +37,10 @@ from langchain.schema import (
|
|
|
|
|
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
|
|
|
|
|
from langchain.schema.messages import (
|
|
|
|
|
AIMessage,
|
|
|
|
|
AnyMessage,
|
|
|
|
|
BaseMessage,
|
|
|
|
|
BaseMessageChunk,
|
|
|
|
|
ChatMessage,
|
|
|
|
|
FunctionMessage,
|
|
|
|
|
HumanMessage,
|
|
|
|
|
SystemMessage,
|
|
|
|
|
)
|
|
|
|
|
from langchain.schema.output import ChatGenerationChunk
|
|
|
|
|
from langchain.schema.runnable import RunnableConfig
|
|
|
|
@ -79,7 +76,7 @@ async def _agenerate_from_stream(
|
|
|
|
|
return ChatResult(generations=[generation])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|
|
|
|
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|
|
|
|
"""Base class for Chat models."""
|
|
|
|
|
|
|
|
|
|
cache: Optional[bool] = None
|
|
|
|
@ -116,9 +113,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|
|
|
|
@property
|
|
|
|
|
def OutputType(self) -> Any:
|
|
|
|
|
"""Get the output type for this runnable."""
|
|
|
|
|
return Union[
|
|
|
|
|
HumanMessage, AIMessage, ChatMessage, FunctionMessage, SystemMessage
|
|
|
|
|
]
|
|
|
|
|
return AnyMessage
|
|
|
|
|
|
|
|
|
|
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
|
|
|
|
if isinstance(input, PromptValue):
|
|
|
|
@ -140,23 +135,20 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|
|
|
|
*,
|
|
|
|
|
stop: Optional[List[str]] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> BaseMessageChunk:
|
|
|
|
|
) -> BaseMessage:
|
|
|
|
|
config = config or {}
|
|
|
|
|
return cast(
|
|
|
|
|
BaseMessageChunk,
|
|
|
|
|
cast(
|
|
|
|
|
ChatGeneration,
|
|
|
|
|
self.generate_prompt(
|
|
|
|
|
[self._convert_input(input)],
|
|
|
|
|
stop=stop,
|
|
|
|
|
callbacks=config.get("callbacks"),
|
|
|
|
|
tags=config.get("tags"),
|
|
|
|
|
metadata=config.get("metadata"),
|
|
|
|
|
run_name=config.get("run_name"),
|
|
|
|
|
**kwargs,
|
|
|
|
|
).generations[0][0],
|
|
|
|
|
).message,
|
|
|
|
|
)
|
|
|
|
|
ChatGeneration,
|
|
|
|
|
self.generate_prompt(
|
|
|
|
|
[self._convert_input(input)],
|
|
|
|
|
stop=stop,
|
|
|
|
|
callbacks=config.get("callbacks"),
|
|
|
|
|
tags=config.get("tags"),
|
|
|
|
|
metadata=config.get("metadata"),
|
|
|
|
|
run_name=config.get("run_name"),
|
|
|
|
|
**kwargs,
|
|
|
|
|
).generations[0][0],
|
|
|
|
|
).message
|
|
|
|
|
|
|
|
|
|
async def ainvoke(
|
|
|
|
|
self,
|
|
|
|
@ -165,7 +157,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|
|
|
|
*,
|
|
|
|
|
stop: Optional[List[str]] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> BaseMessageChunk:
|
|
|
|
|
) -> BaseMessage:
|
|
|
|
|
config = config or {}
|
|
|
|
|
llm_result = await self.agenerate_prompt(
|
|
|
|
|
[self._convert_input(input)],
|
|
|
|
@ -176,9 +168,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|
|
|
|
run_name=config.get("run_name"),
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
return cast(
|
|
|
|
|
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
|
|
|
|
|
)
|
|
|
|
|
return cast(ChatGeneration, llm_result.generations[0][0]).message
|
|
|
|
|
|
|
|
|
|
def stream(
|
|
|
|
|
self,
|
|
|
|
@ -190,7 +180,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|
|
|
|
) -> Iterator[BaseMessageChunk]:
|
|
|
|
|
if type(self)._stream == BaseChatModel._stream:
|
|
|
|
|
# model doesn't implement streaming, so use default implementation
|
|
|
|
|
yield self.invoke(input, config=config, stop=stop, **kwargs)
|
|
|
|
|
yield cast(
|
|
|
|
|
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
config = config or {}
|
|
|
|
|
messages = self._convert_input(input).to_messages()
|
|
|
|
@ -241,7 +233,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|
|
|
|
) -> AsyncIterator[BaseMessageChunk]:
|
|
|
|
|
if type(self)._astream == BaseChatModel._astream:
|
|
|
|
|
# model doesn't implement streaming, so use default implementation
|
|
|
|
|
yield self.invoke(input, config=config, stop=stop, **kwargs)
|
|
|
|
|
yield cast(
|
|
|
|
|
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
config = config or {}
|
|
|
|
|
messages = self._convert_input(input).to_messages()
|
|
|
|
|