diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index bc62535a..806c1574 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -2,6 +2,7 @@ import asyncio import inspect import warnings from abc import ABC, abstractmethod +from functools import partial from typing import Any, Dict, List, Mapping, Optional, Sequence from pydantic import Extra, Field, root_validator @@ -239,3 +240,12 @@ class SimpleChatModel(BaseChatModel): run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> str: """Simpler interface.""" + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + ) -> ChatResult: + func = partial(self._generate, messages, stop=stop, run_manager=run_manager) + return await asyncio.get_event_loop().run_in_executor(None, func)