diff --git a/langchain/base_language.py b/langchain/base_language.py
index cebf2ffa..1b5bd084 100644
--- a/langchain/base_language.py
+++ b/langchain/base_language.py
@@ -58,6 +58,16 @@ class BaseLanguageModel(BaseModel, ABC):
     ) -> BaseMessage:
         """Predict message from messages."""
 
+    @abstractmethod
+    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+        """Predict text from text."""
+
+    @abstractmethod
+    async def apredict_messages(
+        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+    ) -> BaseMessage:
+        """Predict message from messages."""
+
     def get_token_ids(self, text: str) -> List[int]:
         """Get the token present in the text."""
         return _get_token_ids_default_method(text)
diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py
index 806c1574..de2cdd06 100644
--- a/langchain/chat_models/base.py
+++ b/langchain/chat_models/base.py
@@ -183,6 +183,19 @@ class BaseChatModel(BaseLanguageModel, ABC):
         else:
             raise ValueError("Unexpected generation type")
 
+    async def _call_async(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
+    ) -> BaseMessage:
+        result = await self.agenerate([messages], stop=stop, callbacks=callbacks)
+        generation = result.generations[0][0]
+        if isinstance(generation, ChatGeneration):
+            return generation.message
+        else:
+            raise ValueError("Unexpected generation type")
+
     def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
         return self.predict(message, stop=stop)
 
@@ -203,6 +216,23 @@ class BaseChatModel(BaseLanguageModel, ABC):
             _stop = list(stop)
         return self(messages, stop=_stop)
 
+    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+        if stop is None:
+            _stop = None
+        else:
+            _stop = list(stop)
+        result = await self._call_async([HumanMessage(content=text)], stop=_stop)
+        return result.content
+
+    async def apredict_messages(
+        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+    ) -> BaseMessage:
+        if stop is None:
+            _stop = None
+        else:
+            _stop = list(stop)
+        return await self._call_async(messages, stop=_stop)
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index 2e92c15d..21267a96 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -299,6 +299,13 @@ class BaseLLM(BaseLanguageModel, ABC):
             .text
         )
 
+    async def _call_async(
+        self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
+    ) -> str:
+        """Check Cache and run the LLM on the given prompt and input."""
+        result = await self.agenerate([prompt], stop=stop, callbacks=callbacks)
+        return result.generations[0][0].text
+
     def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
         if stop is None:
             _stop = None
@@ -317,6 +324,24 @@ class BaseLLM(BaseLanguageModel, ABC):
         content = self(text, stop=_stop)
         return AIMessage(content=content)
 
+    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+        if stop is None:
+            _stop = None
+        else:
+            _stop = list(stop)
+        return await self._call_async(text, stop=_stop)
+
+    async def apredict_messages(
+        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+    ) -> BaseMessage:
+        text = get_buffer_string(messages)
+        if stop is None:
+            _stop = None
+        else:
+            _stop = list(stop)
+        content = await self._call_async(text, stop=_stop)
+        return AIMessage(content=content)
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
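Usage sketch (not part of the diff): the new `apredict`/`apredict_messages` methods mirror their sync counterparts but delegate to `agenerate`, so they can be awaited concurrently on one event loop. The snippet below assumes a langchain build containing this change, the `OpenAI`/`ChatOpenAI` providers, and an API key in the environment; model settings are illustrative.

import asyncio

from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.schema import HumanMessage


async def main() -> None:
    llm = OpenAI(temperature=0)      # BaseLLM subclass
    chat = ChatOpenAI(temperature=0)  # BaseChatModel subclass

    # BaseLLM.apredict: plain text in, plain text out.
    # `stop` is keyword-only, matching the sync `predict` signature.
    text = await llm.apredict("Say hello", stop=["\n"])

    # BaseChatModel.apredict_messages: messages in, a BaseMessage out.
    message = await chat.apredict_messages([HumanMessage(content="Say hello")])

    # Because both methods await `agenerate` under the hood, they can run
    # concurrently instead of blocking the loop like `predict` would.
    results = await asyncio.gather(
        llm.apredict("One"),
        chat.apredict("Two"),
    )
    print(text, message.content, results)


asyncio.run(main())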