From 925dd3e59e112dec4e3dad459b05c69aa054e1bc Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Tue, 23 May 2023 20:22:49 -0400 Subject: [PATCH] Add async versions of predict() and predict_messages() (#4867) # Add async versions of predict() and predict_messages() #4615 introduced a unifying interface for "base" and "chat" LLM models via the new `predict()` and `predict_messages()` methods that allow both types of models to operate on string and message-based inputs, respectively. This PR adds async versions of the same (`apredict()` and `apredict_messages()`) that are identical except for their use of `agenerate()` in place of `generate()`, which means they repurpose all existing work on the async backend. ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: @hwchase17 (follows his work on #4615) @agola11 (async) --------- Co-authored-by: Harrison Chase --- langchain/base_language.py | 10 ++++++++++ langchain/chat_models/base.py | 30 ++++++++++++++++++++++++++++++ langchain/llms/base.py | 25 +++++++++++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/langchain/base_language.py b/langchain/base_language.py index cebf2ffa..1b5bd084 100644 --- a/langchain/base_language.py +++ b/langchain/base_language.py @@ -58,6 +58,16 @@ class BaseLanguageModel(BaseModel, ABC): ) -> BaseMessage: """Predict message from messages.""" + @abstractmethod + async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str: + """Predict text from text.""" + + @abstractmethod + async def apredict_messages( + self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None + ) -> BaseMessage: + """Predict message from messages.""" + def get_token_ids(self, text: str) -> List[int]: """Get the token present in the text.""" return _get_token_ids_default_method(text) diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index 806c1574..de2cdd06 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -183,6 +183,19 @@ class BaseChatModel(BaseLanguageModel, ABC): else: raise ValueError("Unexpected generation type") + async def _call_async( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + ) -> BaseMessage: + result = await self.agenerate([messages], stop=stop, callbacks=callbacks) + generation = result.generations[0][0] + if isinstance(generation, ChatGeneration): + return generation.message + else: + raise ValueError("Unexpected generation type") + def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str: return self.predict(message, stop=stop) @@ -203,6 +216,23 @@ class BaseChatModel(BaseLanguageModel, ABC): _stop = list(stop) return self(messages, stop=_stop) + async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str: + if stop is None: + _stop = None + else: + _stop = list(stop) + result = await self._call_async([HumanMessage(content=text)], stop=_stop) + return result.content + + async def apredict_messages( + self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None + ) -> BaseMessage: + if stop is None: + _stop = None + else: + _stop = list(stop) + return await self._call_async(messages, stop=_stop) + @property def _identifying_params(self) -> Mapping[str, Any]: """Get the identifying parameters.""" diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 2e92c15d..21267a96 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -299,6 +299,13 @@ class BaseLLM(BaseLanguageModel, ABC): .text ) + async def _call_async( + self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None + ) -> str: + """Check Cache and run the LLM on the given prompt and input.""" + result = await self.agenerate([prompt], stop=stop, callbacks=callbacks) + return result.generations[0][0].text + def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str: if stop is None: _stop = None @@ -317,6 +324,24 @@ class BaseLLM(BaseLanguageModel, ABC): content = self(text, stop=_stop) return AIMessage(content=content) + async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str: + if stop is None: + _stop = None + else: + _stop = list(stop) + return await self._call_async(text, stop=_stop) + + async def apredict_messages( + self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None + ) -> BaseMessage: + text = get_buffer_string(messages) + if stop is None: + _stop = None + else: + _stop = list(stop) + content = await self._call_async(text, stop=_stop) + return AIMessage(content=content) + @property def _identifying_params(self) -> Mapping[str, Any]: """Get the identifying parameters."""