mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
Added support for streaming output response to HuggingFaceTextgenInference LLM class (#4633)
# Added support for streaming output response to HuggingFaceTextgenInference LLM class Current implementation does not support streaming output. Updated to incorporate this feature. Tagging @agola11 for visibility.
This commit is contained in:
parent
435b70da47
commit
c70ae562b4
@ -1,4 +1,5 @@
|
||||
"""Wrapper around Huggingface text generation inference API."""
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
@ -36,6 +37,7 @@ class HuggingFaceTextGenInference(LLM):
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# Basic Example (no streaming)
|
||||
llm = HuggingFaceTextGenInference(
|
||||
inference_server_url = "http://localhost:8010/",
|
||||
max_new_tokens = 512,
|
||||
@ -45,6 +47,25 @@ class HuggingFaceTextGenInference(LLM):
|
||||
temperature = 0.01,
|
||||
repetition_penalty = 1.03,
|
||||
)
|
||||
print(llm("What is Deep Learning?"))
|
||||
|
||||
# Streaming response example
|
||||
from langchain.callbacks import streaming_stdout
|
||||
|
||||
callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
|
||||
llm = HuggingFaceTextGenInference(
|
||||
inference_server_url = "http://localhost:8010/",
|
||||
max_new_tokens = 512,
|
||||
top_k = 10,
|
||||
top_p = 0.95,
|
||||
typical_p = 0.95,
|
||||
temperature = 0.01,
|
||||
repetition_penalty = 1.03,
|
||||
callbacks = callbacks,
|
||||
stream = True
|
||||
)
|
||||
print(llm("What is Deep Learning?"))
|
||||
|
||||
"""
|
||||
|
||||
max_new_tokens: int = 512
|
||||
@ -57,6 +78,7 @@ class HuggingFaceTextGenInference(LLM):
|
||||
seed: Optional[int] = None
|
||||
inference_server_url: str = ""
|
||||
timeout: int = 120
|
||||
stream: bool = False
|
||||
client: Any
|
||||
|
||||
class Config:
|
||||
@ -97,22 +119,52 @@ class HuggingFaceTextGenInference(LLM):
|
||||
else:
|
||||
stop += self.stop_sequences
|
||||
|
||||
res = self.client.generate(
|
||||
prompt,
|
||||
stop_sequences=stop,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
top_k=self.top_k,
|
||||
top_p=self.top_p,
|
||||
typical_p=self.typical_p,
|
||||
temperature=self.temperature,
|
||||
repetition_penalty=self.repetition_penalty,
|
||||
seed=self.seed,
|
||||
)
|
||||
# remove stop sequences from the end of the generated text
|
||||
for stop_seq in stop:
|
||||
if stop_seq in res.generated_text:
|
||||
res.generated_text = res.generated_text[
|
||||
: res.generated_text.index(stop_seq)
|
||||
]
|
||||
|
||||
return res.generated_text
|
||||
if not self.stream:
|
||||
res = self.client.generate(
|
||||
prompt,
|
||||
stop_sequences=stop,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
top_k=self.top_k,
|
||||
top_p=self.top_p,
|
||||
typical_p=self.typical_p,
|
||||
temperature=self.temperature,
|
||||
repetition_penalty=self.repetition_penalty,
|
||||
seed=self.seed,
|
||||
)
|
||||
# remove stop sequences from the end of the generated text
|
||||
for stop_seq in stop:
|
||||
if stop_seq in res.generated_text:
|
||||
res.generated_text = res.generated_text[
|
||||
: res.generated_text.index(stop_seq)
|
||||
]
|
||||
text = res.generated_text
|
||||
else:
|
||||
text_callback = None
|
||||
if run_manager:
|
||||
text_callback = partial(
|
||||
run_manager.on_llm_new_token, verbose=self.verbose
|
||||
)
|
||||
params = {
|
||||
"stop_sequences": stop,
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
"typical_p": self.typical_p,
|
||||
"temperature": self.temperature,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"seed": self.seed,
|
||||
}
|
||||
text = ""
|
||||
for res in self.client.generate_stream(prompt, **params):
|
||||
token = res.token
|
||||
is_stop = False
|
||||
for stop_seq in stop:
|
||||
if stop_seq in token.text:
|
||||
is_stop = True
|
||||
break
|
||||
if is_stop:
|
||||
break
|
||||
if not token.special:
|
||||
if text_callback:
|
||||
text_callback(token.text)
|
||||
return text
|
||||
|
Loading…
Reference in New Issue
Block a user