@@ -1,4 +1,5 @@
 """Wrapper around Huggingface text generation inference API."""
+from functools import partial
 from typing import Any, Dict, List, Optional
 from pydantic import Extra, Field, root_validator
@@ -36,6 +37,7 @@ class HuggingFaceTextGenInference(LLM):
     Example:
         .. code-block:: python

+            # Basic Example (no streaming)
             llm = HuggingFaceTextGenInference(
                 inference_server_url="http://localhost:8010/",
                 max_new_tokens=512,
@@ -45,6 +47,25 @@ class HuggingFaceTextGenInference(LLM):
                 temperature=0.01,
                 repetition_penalty=1.03,
             )
+            print(llm("What is Deep Learning?"))
+
+            # Streaming response example
+            from langchain.callbacks import streaming_stdout
+
+            callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
+            llm = HuggingFaceTextGenInference(
+                inference_server_url="http://localhost:8010/",
+                max_new_tokens=512,
+                top_k=10,
+                top_p=0.95,
+                typical_p=0.95,
+                temperature=0.01,
+                repetition_penalty=1.03,
+                callbacks=callbacks,
+                stream=True
+            )
+            print(llm("What is Deep Learning?"))
     """
     max_new_tokens: int = 512
@@ -57,6 +78,7 @@ class HuggingFaceTextGenInference(LLM):
     seed: Optional[int] = None
     inference_server_url: str = ""
     timeout: int = 120
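+    # when True, _call streams tokens through the callback manager as they
+    # are generated instead of returning the completion in one blocking call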
+    stream: bool = False
     client: Any

     class Config:
@@ -97,22 +119,52 @@ class HuggingFaceTextGenInference(LLM):
         else:
             stop += self.stop_sequences
-        res = self.client.generate(
-            prompt,
-            stop_sequences=stop,
-            max_new_tokens=self.max_new_tokens,
-            top_k=self.top_k,
-            top_p=self.top_p,
-            typical_p=self.typical_p,
-            temperature=self.temperature,
-            repetition_penalty=self.repetition_penalty,
-            seed=self.seed,
-        )
-        # remove stop sequences from the end of the generated text
-        for stop_seq in stop:
-            if stop_seq in res.generated_text:
-                res.generated_text = res.generated_text[
-                    : res.generated_text.index(stop_seq)
-                ]
-        return res.generated_text
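+        # the stream flag selects between a single blocking generate() call
+        # and token-by-token streaming through the callback manager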
+        if not self.stream:
+            res = self.client.generate(
+                prompt,
+                stop_sequences=stop,
+                max_new_tokens=self.max_new_tokens,
+                top_k=self.top_k,
+                top_p=self.top_p,
+                typical_p=self.typical_p,
+                temperature=self.temperature,
+                repetition_penalty=self.repetition_penalty,
+                seed=self.seed,
+            )
+            # remove stop sequences from the end of the generated text
+            for stop_seq in stop:
+                if stop_seq in res.generated_text:
+                    res.generated_text = res.generated_text[
+                        : res.generated_text.index(stop_seq)
+                    ]
+            text = res.generated_text
+        else:
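+            # bind the run manager's on_llm_new_token hook so each streamed
+            # token can be forwarded to the registered callbacks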
+            text_callback = None
+            if run_manager:
+                text_callback = partial(
+                    run_manager.on_llm_new_token, verbose=self.verbose
+                )
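+            # same sampling parameters as the non-streaming branch above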
+            params = {
+                "stop_sequences": stop,
+                "max_new_tokens": self.max_new_tokens,
+                "top_k": self.top_k,
+                "top_p": self.top_p,
+                "typical_p": self.typical_p,
+                "temperature": self.temperature,
+                "repetition_penalty": self.repetition_penalty,
+                "seed": self.seed,
+            }
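+            # accumulate streamed tokens, aborting as soon as a stop sequence
+            # shows up in a token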
text = " "
for res in self . client . generate_stream ( prompt , * * params ) :
token = res . token
is_stop = False
for stop_seq in stop :
if stop_seq in token . text :
is_stop = True
break
if is_stop :
break
+                if not token.special:
+                    if text_callback:
+                        text_callback(token.text)
+                    text += token.text
+        return text