# Added support for streaming output response to HuggingFaceTextGenInference LLM class (#4633)

The current implementation does not support streaming output; this change updates the class to add it. Tagging @agola11 for visibility.
Daniel Barker, 1 year ago, committed by GitHub
parent 435b70da47
commit c70ae562b4
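
To make the new behavior concrete, here is a minimal consumption sketch: it wires the streaming mode into a custom callback handler instead of the stdout handler shown in the docstring below. `TokenCollector` and the server URL are illustrative assumptions; `BaseCallbackHandler.on_llm_new_token` is the hook the streaming branch invokes.

```python
# Minimal sketch (not part of this diff): collecting streamed tokens with a
# custom callback handler. TokenCollector and the server URL are hypothetical;
# on_llm_new_token is the callback fired by the streaming branch in the diff.
from typing import Any, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms import HuggingFaceTextGenInference


class TokenCollector(BaseCallbackHandler):
    """Hypothetical handler that records each token as it streams in."""

    def __init__(self) -> None:
        self.tokens: List[str] = []

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        self.tokens.append(token)


collector = TokenCollector()
llm = HuggingFaceTextGenInference(
    inference_server_url="http://localhost:8010/",  # assumed local TGI server
    max_new_tokens=512,
    stream=True,
    callbacks=[collector],
)
print(llm("What is Deep Learning?"))
print(f"received {len(collector.tokens)} tokens")
```

The docstring example added in the diff uses the built-in `StreamingStdOutCallbackHandler` for the same purpose.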

@@ -1,4 +1,5 @@
 """Wrapper around Huggingface text generation inference API."""
+from functools import partial
 from typing import Any, Dict, List, Optional
 
 from pydantic import Extra, Field, root_validator
@@ -36,6 +37,7 @@ class HuggingFaceTextGenInference(LLM):
     Example:
         .. code-block:: python
 
+            # Basic Example (no streaming)
             llm = HuggingFaceTextGenInference(
                 inference_server_url = "http://localhost:8010/",
                 max_new_tokens = 512,
@@ -45,6 +47,25 @@ class HuggingFaceTextGenInference(LLM):
                 temperature = 0.01,
                 repetition_penalty = 1.03,
             )
+            print(llm("What is Deep Learning?"))
+
+            # Streaming response example
+            from langchain.callbacks import streaming_stdout
+
+            callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
+            llm = HuggingFaceTextGenInference(
+                inference_server_url = "http://localhost:8010/",
+                max_new_tokens = 512,
+                top_k = 10,
+                top_p = 0.95,
+                typical_p = 0.95,
+                temperature = 0.01,
+                repetition_penalty = 1.03,
+                callbacks = callbacks,
+                stream = True
+            )
+            print(llm("What is Deep Learning?"))
+
     """
 
     max_new_tokens: int = 512
@@ -57,6 +78,7 @@ class HuggingFaceTextGenInference(LLM):
     seed: Optional[int] = None
     inference_server_url: str = ""
     timeout: int = 120
+    stream: bool = False
     client: Any
 
     class Config:
@@ -97,22 +119,54 @@ class HuggingFaceTextGenInference(LLM):
         else:
             stop += self.stop_sequences
 
-        res = self.client.generate(
-            prompt,
-            stop_sequences=stop,
-            max_new_tokens=self.max_new_tokens,
-            top_k=self.top_k,
-            top_p=self.top_p,
-            typical_p=self.typical_p,
-            temperature=self.temperature,
-            repetition_penalty=self.repetition_penalty,
-            seed=self.seed,
-        )
-        # remove stop sequences from the end of the generated text
-        for stop_seq in stop:
-            if stop_seq in res.generated_text:
-                res.generated_text = res.generated_text[
-                    : res.generated_text.index(stop_seq)
-                ]
-
-        return res.generated_text
+        if not self.stream:
+            res = self.client.generate(
+                prompt,
+                stop_sequences=stop,
+                max_new_tokens=self.max_new_tokens,
+                top_k=self.top_k,
+                top_p=self.top_p,
+                typical_p=self.typical_p,
+                temperature=self.temperature,
+                repetition_penalty=self.repetition_penalty,
+                seed=self.seed,
+            )
+            # remove stop sequences from the end of the generated text
+            for stop_seq in stop:
+                if stop_seq in res.generated_text:
+                    res.generated_text = res.generated_text[
+                        : res.generated_text.index(stop_seq)
+                    ]
+            text = res.generated_text
+        else:
+            text_callback = None
+            if run_manager:
+                text_callback = partial(
+                    run_manager.on_llm_new_token, verbose=self.verbose
+                )
+            params = {
+                "stop_sequences": stop,
+                "max_new_tokens": self.max_new_tokens,
+                "top_k": self.top_k,
+                "top_p": self.top_p,
+                "typical_p": self.typical_p,
+                "temperature": self.temperature,
+                "repetition_penalty": self.repetition_penalty,
+                "seed": self.seed,
+            }
+            text = ""
+            for res in self.client.generate_stream(prompt, **params):
+                token = res.token
+                is_stop = False
+                for stop_seq in stop:
+                    if stop_seq in token.text:
+                        is_stop = True
+                        break
+                if is_stop:
+                    break
+                if not token.special:
+                    if text_callback:
+                        text_callback(token.text)
+                    # accumulate streamed tokens so the final text is returned
+                    text += token.text
+        return text
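
For reference, the streaming branch wraps `generate_stream` on the `text_generation` client that the class holds in `self.client`. Below is a standalone sketch of that underlying loop, assuming a text-generation-inference server at an illustrative local URL.

```python
# Standalone sketch of the text_generation client loop the streaming branch
# wraps. Server URL and prompt are illustrative assumptions.
from text_generation import Client

client = Client("http://localhost:8010/", timeout=120)  # assumed local TGI server

text = ""
for response in client.generate_stream("What is Deep Learning?", max_new_tokens=512):
    token = response.token
    if not token.special:  # skip special tokens (e.g. end-of-sequence markers)
        print(token.text, end="", flush=True)  # emit each token as it arrives
        text += token.text
print()
```

Each streamed token carries `text` and `special` fields; the diff filters on `special` so end-of-sequence markers never reach the callback.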
