@@ -1,3 +1,4 @@
import json
from typing import Any, Dict, List, Mapping, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun
@@ -7,8 +8,15 @@ from langchain_core.utils import get_from_dict_or_env
from langchain_community.llms.utils import enforce_stop_tokens
DEFAULT_REPO_ID = "gpt2"
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
# key: task
# value: key in the output dictionary
VALID_TASKS_DICT = {
"translation": "translation_text",
"summarization": "summary_text",
"conversational": "generated_text",
"text-generation": "generated_text",
"text2text-generation": "generated_text",
}
class HuggingFaceHub(LLM):
@@ -18,7 +26,8 @@ class HuggingFaceHub(LLM):
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.
Only supports `text-generation`, `text2text-generation` and `summarization` for now.
Supports `text-generation`, `text2text-generation`, `conversational`, `translation`,
and `summarization`.
Example:
.. code-block:: python
@@ -28,11 +37,13 @@ class HuggingFaceHub(LLM):
"""
client: Any  #: :meta private:
repo_id: str = DEFAULT_REPO_ID
"""Model name to use."""
repo_id: Optional[str] = None
"""Model name to use.
If not provided, the default model for the chosen task will be used."""
task: Optional[str] = None
"""Task to call the model with.
Should be a task that returns `generated_text` or `summary_text`."""
Should be a task that returns `generated_text`, `summary_text`,
or `translation_text`."""
model_kwargs: Optional[dict] = None
"""Keyword arguments to pass to the model."""
@@ -50,18 +61,27 @@ class HuggingFaceHub(LLM):
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
try:
from huggingface_hub.inference_api import InferenceApi
from huggingface_hub import HfApi, InferenceClient
repo_id = values["repo_id"]
client = InferenceApi(
repo_id=repo_id,
client = InferenceClient(
model=repo_id,
token=huggingfacehub_api_token,
task=values.get("task"),
)
if client.task not in VALID_TASKS:
if not values["task"]:
if not repo_id:
raise ValueError(
"Must specify either `repo_id` or `task`, or both."
)
# Use the recommended task for the chosen model
model_info = HfApi(token=huggingfacehub_api_token).model_info(
repo_id=repo_id
)
values["task"] = model_info.pipeline_tag
if values["task"] not in VALID_TASKS_DICT:
raise ValueError(
f"Got invalid task {client.task}, "
f"currently only {VALID_TASKS} are supported"
f"Got invalid task {values['task']}, "
f"currently only {VALID_TASKS_DICT.keys()} are supported"
)
values["client"] = client
except ImportError:
@@ -108,23 +128,20 @@ class HuggingFaceHub(LLM):
"""
_model_kwargs = self.model_kwargs or {}
params = {**_model_kwargs, **kwargs}
response = self.client(inputs=prompt, params=params)
response = self.client.post(
json={"inputs": prompt, "parameters": params}, task=self.task
)
response = json.loads(response.decode())
if "error" in response:
raise ValueError(f"Error raised by inference API: {response['error']}")
if self.client.task == "text-generation":
# Text generation output sometimes includes the starter text.
text = response[0]["generated_text"]
if text.startswith(prompt):
text = response[0]["generated_text"][len(prompt):]
elif self.client.task == "text2text-generation":
text = response[0]["generated_text"]
elif self.client.task == "summarization":
text = response[0]["summary_text"]
response_key = VALID_TASKS_DICT[self.task]  # type: ignore
if isinstance(response, list):
text = response[0][response_key]
else:
raise ValueError(
f"Got invalid task {self.client.task}, "
f"currently only {VALID_TASKS} are supported"
)
text = response[response_key]
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
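
A minimal usage sketch of the updated class, assuming `langchain_community` and `huggingface_hub` are installed, `HUGGINGFACEHUB_API_TOKEN` is set, and the model id shown is purely illustrative:

from langchain_community.llms import HuggingFaceHub

# Only repo_id given: the task is inferred from the model's pipeline_tag.
summarizer = HuggingFaceHub(
    repo_id="facebook/bart-large-cnn",  # illustrative summarization model
    model_kwargs={"max_length": 60},
)
print(summarizer.invoke("Long article text to summarize ..."))

# Only task given: the default model for that task is used instead.
translator = HuggingFaceHub(task="translation")
print(translator.invoke("Bonjour tout le monde"))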