@ -11,7 +11,7 @@ from langchain.llms.utils import enforce_stop_tokens
DEFAULT_MODEL_ID = " gpt2 "
DEFAULT_MODEL_ID = " gpt2 "
DEFAULT_TASK = " text-generation "
DEFAULT_TASK = " text-generation "
VALID_TASKS = ( " text2text-generation " , " text-generation " )
VALID_TASKS = ( " text2text-generation " , " text-generation " , " summarization " )
logger = logging . getLogger ( __name__ )
logger = logging . getLogger ( __name__ )
@ -35,6 +35,8 @@ def _generate_text(
text = response [ 0 ] [ " generated_text " ] [ len ( prompt ) : ]
text = response [ 0 ] [ " generated_text " ] [ len ( prompt ) : ]
elif pipeline . task == " text2text-generation " :
elif pipeline . task == " text2text-generation " :
text = response [ 0 ] [ " generated_text " ]
text = response [ 0 ] [ " generated_text " ]
elif pipeline . task == " summarization " :
text = response [ 0 ] [ " summary_text " ]
else :
else :
raise ValueError (
raise ValueError (
f " Got invalid task { pipeline . task } , "
f " Got invalid task { pipeline . task } , "
@ -64,7 +66,7 @@ def _load_transformer(
try :
try :
if task == " text-generation " :
if task == " text-generation " :
model = AutoModelForCausalLM . from_pretrained ( model_id , * * _model_kwargs )
model = AutoModelForCausalLM . from_pretrained ( model_id , * * _model_kwargs )
elif task == " text2text-generation " :
elif task in ( " text2text-generation " , " summarization " ) :
model = AutoModelForSeq2SeqLM . from_pretrained ( model_id , * * _model_kwargs )
model = AutoModelForSeq2SeqLM . from_pretrained ( model_id , * * _model_kwargs )
else :
else :
raise ValueError (
raise ValueError (
@ -119,7 +121,7 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
To use , you should have the ` ` runhouse ` ` python package installed .
To use , you should have the ` ` runhouse ` ` python package installed .
Only supports ` text - generation ` and ` text2text - generation ` for now .
Only supports ` text - generation ` , ` text2text - generation ` and ` summarization ` for now .
Example using from_model_id :
Example using from_model_id :
. . code - block : : python
. . code - block : : python
@ -153,7 +155,8 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
model_id : str = DEFAULT_MODEL_ID
model_id : str = DEFAULT_MODEL_ID
""" Hugging Face model_id to load the model. """
""" Hugging Face model_id to load the model. """
task : str = DEFAULT_TASK
task : str = DEFAULT_TASK
""" Hugging Face task (either " text-generation " or " text2text-generation " ). """
""" Hugging Face task ( " text-generation " , " text2text-generation " or
" summarization " ) . """
device : int = 0
device : int = 0
""" Device to use for inference. -1 for CPU, 0 for GPU, 1 for second GPU, etc. """
""" Device to use for inference. -1 for CPU, 0 for GPU, 1 for second GPU, etc. """
model_kwargs : Optional [ dict ] = None
model_kwargs : Optional [ dict ] = None