From 6e848b879ad551e930462c71e174be0cc730b4c9 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Tue, 3 Oct 2023 17:28:14 -0700
Subject: [PATCH] add default for async (#11367)

---
 libs/langchain/langchain/chat_models/base.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py
index 7b8b474e2f..4a16393525 100644
--- a/libs/langchain/langchain/chat_models/base.py
+++ b/libs/langchain/langchain/chat_models/base.py
@@ -170,12 +170,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> BaseMessageChunk:
-        if type(self)._agenerate == BaseChatModel._agenerate:
-            # model doesn't implement async generation, so use default implementation
-            return await asyncio.get_running_loop().run_in_executor(
-                None, partial(self.invoke, input, config, stop=stop, **kwargs)
-            )
-
         config = config or {}
         llm_result = await self.agenerate_prompt(
             [self._convert_input(input)],
@@ -582,7 +576,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         **kwargs: Any,
     ) -> ChatResult:
         """Top Level call"""
-        raise NotImplementedError()
+        return await asyncio.get_running_loop().run_in_executor(
+            None,
+            partial(
+                self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
+            ),
+        )

     def _stream(
         self,
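
With this change, BaseChatModel._agenerate gains a default implementation that runs
the synchronous _generate in a thread-pool executor, so the special-case fallback in
the async invoke path is no longer needed: any subclass that only implements
_generate now gets working async entry points (ainvoke, agenerate) for free instead
of hitting NotImplementedError. Below is a minimal sketch of such a sync-only
subclass; EchoChatModel is a hypothetical name, and the import paths assume the
langchain package layout as of this commit.

    from typing import Any, List, Optional

    from langchain.callbacks.manager import CallbackManagerForLLMRun
    from langchain.chat_models.base import BaseChatModel
    from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult


    class EchoChatModel(BaseChatModel):
        """Hypothetical sync-only model: echoes the last message back."""

        @property
        def _llm_type(self) -> str:
            return "echo-chat"

        def _generate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
        ) -> ChatResult:
            # Blocking "inference": echo the content of the last message.
            text = messages[-1].content if messages else ""
            return ChatResult(
                generations=[ChatGeneration(message=AIMessage(content=text))]
            )


    # Before this patch, awaiting the async API on a model with no _agenerate
    # raised NotImplementedError; with the new default, _agenerate delegates to
    # _generate via run_in_executor, so from async code this now works:
    #
    #     result = await EchoChatModel().ainvoke("hello")

Running the blocking _generate in the default executor keeps the event loop free
while the sync call proceeds on a worker thread; providers with a native async
client can still override _agenerate for true non-blocking I/O.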