forked from Archives/langchain
Add async versions of predict() and predict_messages() (#4867)
# Add async versions of predict() and predict_messages() #4615 introduced a unifying interface for "base" and "chat" LLM models via the new `predict()` and `predict_messages()` methods that allow both types of models to operate on string and message-based inputs, respectively. This PR adds async versions of the same (`apredict()` and `apredict_messages()`) that are identical except for their use of `agenerate()` in place of `generate()`, which means they repurpose all existing work on the async backend. ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: @hwchase17 (follows his work on #4615) @agola11 (async) --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
9242998db1
commit
925dd3e59e
@ -58,6 +58,16 @@ class BaseLanguageModel(BaseModel, ABC):
|
||||
) -> BaseMessage:
|
||||
"""Predict message from messages."""
|
||||
|
||||
@abstractmethod
async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
    """Asynchronously predict a text completion for the given input text.

    Async counterpart of ``predict``; implementations are expected to
    delegate to the async generation backend.

    Args:
        text: The input prompt string.
        stop: Optional sequence of stop strings; semantics are defined by
            the concrete implementation.

    Returns:
        The predicted completion as a string.
    """
|
||||
|
||||
@abstractmethod
async def apredict_messages(
    self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
) -> BaseMessage:
    """Asynchronously predict the next message for a message history.

    Async counterpart of ``predict_messages``.

    Args:
        messages: The conversation so far, as a list of messages.
        stop: Optional sequence of stop strings; semantics are defined by
            the concrete implementation.

    Returns:
        The model's reply as a message object.
    """
|
||||
|
||||
def get_token_ids(self, text: str) -> List[int]:
    """Return the token ids present in *text*.

    Delegates to the module-level default tokenization helper.
    """
    token_ids = _get_token_ids_default_method(text)
    return token_ids
|
||||
|
@ -183,6 +183,19 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
else:
|
||||
raise ValueError("Unexpected generation type")
|
||||
|
||||
async def _call_async(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    callbacks: Callbacks = None,
) -> BaseMessage:
    """Asynchronously generate one chat reply for a single message list.

    Wraps ``agenerate`` (which operates on batches) around a single
    conversation and unwraps the first generation back into a message.

    Args:
        messages: The conversation to respond to.
        stop: Optional list of stop strings forwarded to ``agenerate``.
        callbacks: Callback handlers forwarded to ``agenerate``.

    Returns:
        The generated chat message.

    Raises:
        ValueError: If the backend produced a non-chat generation.
    """
    result = await self.agenerate([messages], stop=stop, callbacks=callbacks)
    generation = result.generations[0][0]
    # Guard clause: anything other than a ChatGeneration is unexpected here.
    if not isinstance(generation, ChatGeneration):
        raise ValueError("Unexpected generation type")
    return generation.message
|
||||
|
||||
def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
    """Treat this chat model as a plain text-in/text-out LLM.

    Thin convenience wrapper that simply delegates to ``predict``.
    """
    reply = self.predict(message, stop=stop)
    return reply
|
||||
|
||||
@ -203,6 +216,23 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
_stop = list(stop)
|
||||
return self(messages, stop=_stop)
|
||||
|
||||
async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
    """Asynchronously predict the reply text for a single human message.

    Wraps *text* in a ``HumanMessage``, runs the async chat call, and
    returns the reply's string content.
    """
    # Normalise the stop sequence to a concrete list (or None) before the call.
    _stop = None if stop is None else list(stop)
    result = await self._call_async([HumanMessage(content=text)], stop=_stop)
    return result.content
|
||||
|
||||
async def apredict_messages(
    self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
) -> BaseMessage:
    """Asynchronously predict the next message for a conversation.

    Normalises *stop* to a concrete list (or None) and delegates to the
    async single-call helper.
    """
    _stop = list(stop) if stop is not None else None
    return await self._call_async(messages, stop=_stop)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
|
@ -299,6 +299,13 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
.text
|
||||
)
|
||||
|
||||
async def _call_async(
    self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
) -> str:
    """Asynchronously run the LLM on a single prompt and return its text.

    Wraps ``agenerate`` (which operates on prompt batches) around one
    prompt and unwraps the first generation's text.
    """
    result = await self.agenerate([prompt], stop=stop, callbacks=callbacks)
    first_generation = result.generations[0][0]
    return first_generation.text
|
||||
|
||||
def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
|
||||
if stop is None:
|
||||
_stop = None
|
||||
@ -317,6 +324,24 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
content = self(text, stop=_stop)
|
||||
return AIMessage(content=content)
|
||||
|
||||
async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
    """Asynchronously predict a completion string for *text*.

    Normalises *stop* to a concrete list (or None) and delegates to the
    async single-prompt helper.
    """
    _stop = list(stop) if stop is not None else None
    return await self._call_async(text, stop=_stop)
|
||||
|
||||
async def apredict_messages(
    self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
) -> BaseMessage:
    """Asynchronously predict an AI reply message for a message history.

    Flattens the chat history into one prompt string (this is a base LLM,
    not a chat model), runs the async call, and wraps the completion in an
    ``AIMessage``.
    """
    prompt = get_buffer_string(messages)
    _stop = None if stop is None else list(stop)
    content = await self._call_async(prompt, stop=_stop)
    return AIMessage(content=content)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
|
Loading…
Reference in New Issue
Block a user