HuggingFacePipeline: Forward model_kwargs. (#696)

Since the tokenizer and model are constructed manually, model_kwargs
needs to
be passed to their constructors. Additionally, the pipeline has a
specific
named parameter to pass these with, which can provide forward
compatibility if
they are used for something other than tokenizer or model construction.
pull/701/head
xloem 1 year ago committed by GitHub
parent 3a30e6daa8
commit 36b6b3cdf6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -68,19 +68,19 @@ class HuggingFacePipeline(LLM, BaseModel):
)
from transformers import pipeline as hf_pipeline
tokenizer = AutoTokenizer.from_pretrained(model_id)
_model_kwargs = model_kwargs or {}
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
if task == "text-generation":
model = AutoModelForCausalLM.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
elif task == "text2text-generation":
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
else:
raise ValueError(
f"Got invalid task {task}, "
f"currently only {VALID_TASKS} are supported"
)
_model_kwargs = model_kwargs or {}
pipeline = hf_pipeline(
task=task, model=model, tokenizer=tokenizer, **_model_kwargs
task=task, model=model, tokenizer=tokenizer, model_kwargs=_model_kwargs
)
if pipeline.task not in VALID_TASKS:
raise ValueError(

Loading…
Cancel
Save