|
|
|
@ -190,9 +190,10 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
|
|
|
|
messages: List[BaseMessage],
|
|
|
|
|
stop: Optional[List[str]] = None,
|
|
|
|
|
callbacks: Callbacks = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> BaseMessage:
|
|
|
|
|
generation = self.generate(
|
|
|
|
|
[messages], stop=stop, callbacks=callbacks
|
|
|
|
|
[messages], stop=stop, callbacks=callbacks, **kwargs
|
|
|
|
|
).generations[0][0]
|
|
|
|
|
if isinstance(generation, ChatGeneration):
|
|
|
|
|
return generation.message
|
|
|
|
@ -227,7 +228,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
|
|
|
|
_stop = None
|
|
|
|
|
else:
|
|
|
|
|
_stop = list(stop)
|
|
|
|
|
result = self([HumanMessage(content=text)], stop=_stop)
|
|
|
|
|
result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
|
|
|
|
|
return result.content
|
|
|
|
|
|
|
|
|
|
def predict_messages(
|
|
|
|
|