HuggingFacePipeline: Forward model_kwargs. (#696)

Since the tokenizer and model are constructed manually, model_kwargs needs
to be passed to their constructors. Additionally, the pipeline has a
specific named parameter to pass these with, which can provide forward
compatibility if they are used for something other than tokenizer or model
construction.
pull/701/head
xloem 2 years ago committed by GitHub
parent 3a30e6daa8
commit 36b6b3cdf6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -68,19 +68,19 @@ class HuggingFacePipeline(LLM, BaseModel):
) )
from transformers import pipeline as hf_pipeline from transformers import pipeline as hf_pipeline
tokenizer = AutoTokenizer.from_pretrained(model_id) _model_kwargs = model_kwargs or {}
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
if task == "text-generation": if task == "text-generation":
model = AutoModelForCausalLM.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
elif task == "text2text-generation": elif task == "text2text-generation":
model = AutoModelForSeq2SeqLM.from_pretrained(model_id) model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
else: else:
raise ValueError( raise ValueError(
f"Got invalid task {task}, " f"Got invalid task {task}, "
f"currently only {VALID_TASKS} are supported" f"currently only {VALID_TASKS} are supported"
) )
_model_kwargs = model_kwargs or {}
pipeline = hf_pipeline( pipeline = hf_pipeline(
task=task, model=model, tokenizer=tokenizer, **_model_kwargs task=task, model=model, tokenizer=tokenizer, model_kwargs=_model_kwargs
) )
if pipeline.task not in VALID_TASKS: if pipeline.task not in VALID_TASKS:
raise ValueError( raise ValueError(

Loading…
Cancel
Save