From 2e024823d233c1ea1cfc8cd0be187c4d99cde292 Mon Sep 17 00:00:00 2001
From: Mircea Pasoi
Date: Tue, 20 Jun 2023 23:12:24 -0700
Subject: [PATCH] Add async support for HuggingFaceTextGenInference (#6507)

Adding support for async calls in `HuggingFaceTextGenInference`

Co-authored-by: Dev 2049
---
 .../llms/huggingface_text_gen_inference.py    | 86 +++++++++++++++++++-
 1 file changed, 84 insertions(+), 2 deletions(-)

diff --git a/langchain/llms/huggingface_text_gen_inference.py b/langchain/llms/huggingface_text_gen_inference.py
index 56b95636..10bc7361 100644
--- a/langchain/llms/huggingface_text_gen_inference.py
+++ b/langchain/llms/huggingface_text_gen_inference.py
@@ -4,7 +4,10 @@
 from typing import Any, Dict, List, Optional
 
 from pydantic import Extra, Field, root_validator
 
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms.base import LLM
 
@@ -26,10 +29,13 @@ class HuggingFaceTextGenInference(LLM):
     - seed: The seed to use when generating text.
     - inference_server_url: The URL of the inference server to use.
     - timeout: The timeout value in seconds to use while connecting to inference server.
+    - server_kwargs: The keyword arguments to pass to the inference server.
     - client: The client object used to communicate with the inference server.
+    - async_client: The async client object used to communicate with the server.
 
     Methods:
     - _call: Generates text based on a given prompt and stop sequences.
+    - _acall: Asynchronously generates text based on a given prompt and stop sequences.
     - _llm_type: Returns the type of LLM.
     """
 
@@ -78,8 +84,10 @@ class HuggingFaceTextGenInference(LLM):
     seed: Optional[int] = None
     inference_server_url: str = ""
     timeout: int = 120
+    server_kwargs: Dict[str, Any] = Field(default_factory=dict)
     stream: bool = False
     client: Any
+    async_client: Any
 
     class Config:
         """Configuration for this pydantic object."""
@@ -94,7 +102,14 @@ class HuggingFaceTextGenInference(LLM):
             import text_generation
 
             values["client"] = text_generation.Client(
-                values["inference_server_url"], timeout=values["timeout"]
+                values["inference_server_url"],
+                timeout=values["timeout"],
+                **values["server_kwargs"],
+            )
+            values["async_client"] = text_generation.AsyncClient(
+                values["inference_server_url"],
+                timeout=values["timeout"],
+                **values["server_kwargs"],
             )
         except ImportError:
             raise ImportError(
@@ -171,3 +186,70 @@ class HuggingFaceTextGenInference(LLM):
                     text_callback(token.text)
                 text += token.text
         return text
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        if stop is None:
+            stop = self.stop_sequences
+        else:
+            stop += self.stop_sequences
+
+        if not self.stream:
+            res = await self.async_client.generate(
+                prompt,
+                stop_sequences=stop,
+                max_new_tokens=self.max_new_tokens,
+                top_k=self.top_k,
+                top_p=self.top_p,
+                typical_p=self.typical_p,
+                temperature=self.temperature,
+                repetition_penalty=self.repetition_penalty,
+                seed=self.seed,
+                **kwargs,
+            )
+            # remove stop sequences from the end of the generated text
+            for stop_seq in stop:
+                if stop_seq in res.generated_text:
+                    res.generated_text = res.generated_text[
+                        : res.generated_text.index(stop_seq)
+                    ]
+            text: str = res.generated_text
+        else:
+            text_callback = None
+            if run_manager:
+                text_callback = partial(
+                    run_manager.on_llm_new_token, verbose=self.verbose
+                )
+            params = {
+                **{
+                    "stop_sequences": stop,
+                    "max_new_tokens": self.max_new_tokens,
+                    "top_k": self.top_k,
+                    "top_p": self.top_p,
+                    "typical_p": self.typical_p,
+                    "temperature": self.temperature,
+                    "repetition_penalty": self.repetition_penalty,
+                    "seed": self.seed,
+                },
+                **kwargs,
+            }
+            text = ""
+            async for res in self.async_client.generate_stream(prompt, **params):
+                token = res.token
+                is_stop = False
+                for stop_seq in stop:
+                    if stop_seq in token.text:
+                        is_stop = True
+                        break
+                if is_stop:
+                    break
+                if not token.special:
+                    if text_callback:
+                        await text_callback(token.text)
+                    text += token.text
+        return text