From 36b6b3cdf6abb7b956d1bd0e919564c95938f300 Mon Sep 17 00:00:00 2001 From: xloem <0xloem@gmail.com> Date: Mon, 23 Jan 2023 02:38:47 -0500 Subject: [PATCH] HuggingFacePipeline: Forward model_kwargs. (#696) Since the tokenizer and model are constructed manually, model_kwargs needs to be passed to their constructors. Additionally, the pipeline has a specific named parameter to pass these with, which can provide forward compatibility if they are used for something other than tokenizer or model construction. --- langchain/llms/huggingface_pipeline.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/langchain/llms/huggingface_pipeline.py b/langchain/llms/huggingface_pipeline.py index 40198a9b8d..22bd3661a5 100644 --- a/langchain/llms/huggingface_pipeline.py +++ b/langchain/llms/huggingface_pipeline.py @@ -68,19 +68,19 @@ class HuggingFacePipeline(LLM, BaseModel): ) from transformers import pipeline as hf_pipeline - tokenizer = AutoTokenizer.from_pretrained(model_id) + _model_kwargs = model_kwargs or {} + tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs) if task == "text-generation": - model = AutoModelForCausalLM.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs) elif task == "text2text-generation": - model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs) else: raise ValueError( f"Got invalid task {task}, " f"currently only {VALID_TASKS} are supported" ) - _model_kwargs = model_kwargs or {} pipeline = hf_pipeline( - task=task, model=model, tokenizer=tokenizer, **_model_kwargs + task=task, model=model, tokenizer=tokenizer, model_kwargs=_model_kwargs ) if pipeline.task not in VALID_TASKS: raise ValueError(