diff --git a/langchain/llms/huggingface_text_gen_inference.py b/langchain/llms/huggingface_text_gen_inference.py
index a2489865fa..987db8421b 100644
--- a/langchain/llms/huggingface_text_gen_inference.py
+++ b/langchain/llms/huggingface_text_gen_inference.py
@@ -1,4 +1,5 @@
 """Wrapper around Huggingface text generation inference API."""
+from functools import partial
 from typing import Any, Dict, List, Optional
 
 from pydantic import Extra, Field, root_validator
@@ -36,6 +37,7 @@ class HuggingFaceTextGenInference(LLM):
     Example:
         .. code-block:: python
 
+            # Basic example (no streaming)
             llm = HuggingFaceTextGenInference(
                 inference_server_url = "http://localhost:8010/",
                 max_new_tokens = 512,
@@ -45,6 +47,25 @@ class HuggingFaceTextGenInference(LLM):
                 temperature = 0.01,
                 repetition_penalty = 1.03,
             )
+            print(llm("What is Deep Learning?"))
+
+            # Streaming response example
+            from langchain.callbacks import streaming_stdout
+
+            callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
+            llm = HuggingFaceTextGenInference(
+                inference_server_url = "http://localhost:8010/",
+                max_new_tokens = 512,
+                top_k = 10,
+                top_p = 0.95,
+                typical_p = 0.95,
+                temperature = 0.01,
+                repetition_penalty = 1.03,
+                callbacks = callbacks,
+                stream = True
+            )
+            print(llm("What is Deep Learning?"))
+
     """
 
     max_new_tokens: int = 512
@@ -57,6 +78,7 @@ class HuggingFaceTextGenInference(LLM):
     seed: Optional[int] = None
     inference_server_url: str = ""
     timeout: int = 120
+    stream: bool = False
     client: Any
 
     class Config:
@@ -97,22 +119,54 @@ class HuggingFaceTextGenInference(LLM):
         else:
             stop += self.stop_sequences
 
-        res = self.client.generate(
-            prompt,
-            stop_sequences=stop,
-            max_new_tokens=self.max_new_tokens,
-            top_k=self.top_k,
-            top_p=self.top_p,
-            typical_p=self.typical_p,
-            temperature=self.temperature,
-            repetition_penalty=self.repetition_penalty,
-            seed=self.seed,
-        )
-        # remove stop sequences from the end of the generated text
-        for stop_seq in stop:
-            if stop_seq in res.generated_text:
-                res.generated_text = res.generated_text[
-                    : res.generated_text.index(stop_seq)
-                ]
-
-        return res.generated_text
+        if not self.stream:
+            res = self.client.generate(
+                prompt,
+                stop_sequences=stop,
+                max_new_tokens=self.max_new_tokens,
+                top_k=self.top_k,
+                top_p=self.top_p,
+                typical_p=self.typical_p,
+                temperature=self.temperature,
+                repetition_penalty=self.repetition_penalty,
+                seed=self.seed,
+            )
+            # remove stop sequences from the end of the generated text
+            for stop_seq in stop:
+                if stop_seq in res.generated_text:
+                    res.generated_text = res.generated_text[
+                        : res.generated_text.index(stop_seq)
+                    ]
+            text = res.generated_text
+        else:
+            text_callback = None
+            if run_manager:
+                text_callback = partial(
+                    run_manager.on_llm_new_token, verbose=self.verbose
+                )
+            params = {
+                "stop_sequences": stop,
+                "max_new_tokens": self.max_new_tokens,
+                "top_k": self.top_k,
+                "top_p": self.top_p,
+                "typical_p": self.typical_p,
+                "temperature": self.temperature,
+                "repetition_penalty": self.repetition_penalty,
+                "seed": self.seed,
+            }
+            text = ""
+            for res in self.client.generate_stream(prompt, **params):
+                token = res.token
+                is_stop = False
+                for stop_seq in stop:
+                    if stop_seq in token.text:
+                        is_stop = True
+                        break
+                if is_stop:
+                    break
+                if not token.special:
+                    if text_callback:
+                        text_callback(token.text)
+                    # accumulate streamed tokens so the full text is returned
+                    text += token.text
+        return text
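
Usage note (not part of the patch): below is a minimal sketch of driving the new streaming path with a custom callback handler instead of StreamingStdOutCallbackHandler, exercising the on_llm_new_token hook that the patch wires up via partial(...). The TokenCollector class name and server URL are illustrative assumptions; a text-generation-inference server is assumed to be reachable at that address.

    from typing import Any, List

    from langchain.callbacks.base import BaseCallbackHandler
    from langchain.llms import HuggingFaceTextGenInference


    class TokenCollector(BaseCallbackHandler):
        """Illustrative handler: records every token the run manager
        forwards to on_llm_new_token during streaming."""

        def __init__(self) -> None:
            self.tokens: List[str] = []

        def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
            self.tokens.append(token)


    collector = TokenCollector()
    llm = HuggingFaceTextGenInference(
        inference_server_url="http://localhost:8010/",  # assumed local TGI server
        max_new_tokens=512,
        temperature=0.01,
        stream=True,  # the new flag added by this patch
        callbacks=[collector],
    )
    text = llm("What is Deep Learning?")
    # With no stop sequences set, the return value is the concatenation
    # of the non-special streamed tokens
    assert text == "".join(collector.tokens)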