Add async versions of predict() and predict_messages() (#4867)

# Add async versions of predict() and predict_messages()

#4615 introduced a unifying interface for "base" and "chat" LLM models
via the new `predict()` and `predict_messages()` methods that allow both
types of models to operate on string and message-based inputs,
respectively.

This PR adds async versions of the same (`apredict()` and
`apredict_messages()`) that are identical except for their use of
`agenerate()` in place of `generate()`, which means they repurpose all
existing work on the async backend.


## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:
        @hwchase17 (follows his work on #4615)
        @agola11 (async)

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Jeremiah Lowin 2023-05-23 20:22:49 -04:00 committed by GitHub
parent 9242998db1
commit 925dd3e59e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 0 deletions

View File

@ -58,6 +58,16 @@ class BaseLanguageModel(BaseModel, ABC):
) -> BaseMessage:
"""Predict message from messages."""
@abstractmethod
async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
"""Predict text from text."""
@abstractmethod
async def apredict_messages(
self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
) -> BaseMessage:
"""Predict message from messages."""
def get_token_ids(self, text: str) -> List[int]:
"""Get the token present in the text."""
return _get_token_ids_default_method(text)

View File

@ -183,6 +183,19 @@ class BaseChatModel(BaseLanguageModel, ABC):
else:
raise ValueError("Unexpected generation type")
async def _call_async(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> BaseMessage:
result = await self.agenerate([messages], stop=stop, callbacks=callbacks)
generation = result.generations[0][0]
if isinstance(generation, ChatGeneration):
return generation.message
else:
raise ValueError("Unexpected generation type")
def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
return self.predict(message, stop=stop)
@ -203,6 +216,23 @@ class BaseChatModel(BaseLanguageModel, ABC):
_stop = list(stop)
return self(messages, stop=_stop)
async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
if stop is None:
_stop = None
else:
_stop = list(stop)
result = await self._call_async([HumanMessage(content=text)], stop=_stop)
return result.content
async def apredict_messages(
self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
) -> BaseMessage:
if stop is None:
_stop = None
else:
_stop = list(stop)
return await self._call_async(messages, stop=_stop)
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""

View File

@ -299,6 +299,13 @@ class BaseLLM(BaseLanguageModel, ABC):
.text
)
async def _call_async(
self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
) -> str:
"""Check Cache and run the LLM on the given prompt and input."""
result = await self.agenerate([prompt], stop=stop, callbacks=callbacks)
return result.generations[0][0].text
def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
if stop is None:
_stop = None
@ -317,6 +324,24 @@ class BaseLLM(BaseLanguageModel, ABC):
content = self(text, stop=_stop)
return AIMessage(content=content)
async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
if stop is None:
_stop = None
else:
_stop = list(stop)
return await self._call_async(text, stop=_stop)
async def apredict_messages(
self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
) -> BaseMessage:
text = get_buffer_string(messages)
if stop is None:
_stop = None
else:
_stop = list(stop)
content = await self._call_async(text, stop=_stop)
return AIMessage(content=content)
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""