diff --git a/langchain/llms/huggingface_pipeline.py b/langchain/llms/huggingface_pipeline.py
index 40198a9b8d..22bd3661a5 100644
--- a/langchain/llms/huggingface_pipeline.py
+++ b/langchain/llms/huggingface_pipeline.py
@@ -68,19 +68,19 @@ class HuggingFacePipeline(LLM, BaseModel):
             )
             from transformers import pipeline as hf_pipeline
 
-            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            _model_kwargs = model_kwargs or {}
+            tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
             if task == "text-generation":
-                model = AutoModelForCausalLM.from_pretrained(model_id)
+                model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
             elif task == "text2text-generation":
-                model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+                model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
             else:
                 raise ValueError(
                     f"Got invalid task {task}, "
                     f"currently only {VALID_TASKS} are supported"
                 )
-            _model_kwargs = model_kwargs or {}
             pipeline = hf_pipeline(
-                task=task, model=model, tokenizer=tokenizer, **_model_kwargs
+                task=task, model=model, tokenizer=tokenizer, model_kwargs=_model_kwargs
            )
             if pipeline.task not in VALID_TASKS:
                 raise ValueError(
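
For reviewers, a minimal usage sketch of the patched `from_model_id` classmethod. The model id and the specific kwargs below are illustrative only, not taken from this diff; the point is that `model_kwargs` now configures model/tokenizer loading via `from_pretrained` and is passed to `hf_pipeline` under its `model_kwargs` parameter, rather than being splatted into `hf_pipeline()` where unknown keys could collide with the pipeline's own arguments.

```python
# Hypothetical usage of the patched HuggingFacePipeline.from_model_id;
# "gpt2" and the kwargs dict are illustrative examples.
from langchain.llms import HuggingFacePipeline

llm = HuggingFacePipeline.from_model_id(
    model_id="gpt2",
    task="text-generation",
    # With this patch, these kwargs are forwarded to
    # AutoTokenizer/AutoModelForCausalLM.from_pretrained() and to the
    # transformers pipeline via its model_kwargs argument.
    model_kwargs={"temperature": 0, "max_length": 64},
)

print(llm("Once upon a time"))
```

Note that `_model_kwargs = model_kwargs or {}` had to move above the `from_pretrained` calls so the same dict is available when the model and tokenizer are loaded.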