diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index f3521df6..440336f2 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -190,9 +190,10 @@ class BaseChatModel(BaseLanguageModel, ABC): messages: List[BaseMessage], stop: Optional[List[str]] = None, callbacks: Callbacks = None, + **kwargs: Any, ) -> BaseMessage: generation = self.generate( - [messages], stop=stop, callbacks=callbacks + [messages], stop=stop, callbacks=callbacks, **kwargs ).generations[0][0] if isinstance(generation, ChatGeneration): return generation.message @@ -227,7 +228,7 @@ class BaseChatModel(BaseLanguageModel, ABC): _stop = None else: _stop = list(stop) - result = self([HumanMessage(content=text)], stop=_stop) + result = self([HumanMessage(content=text)], stop=_stop, **kwargs) return result.content def predict_messages(