From 2ef5579eae206801c4b23dcbfe7171f4a9feabca Mon Sep 17 00:00:00 2001
From: Abdelsalam ElTamawy <40343437+solomspd@users.noreply.github.com>
Date: Fri, 26 May 2023 03:54:52 +0300
Subject: [PATCH] Added pipeline args to `HuggingFacePipeline.from_model_id`
 (#5268)

`HuggingFacePipeline.from_model_id` currently does not allow pipeline
arguments to be passed through to the underlying transformers pipeline. This
PR adds a `pipeline_kwargs` parameter so that important pipeline settings,
such as `max_new_tokens`, can be supplied directly. Previously, it was
necessary to create the pipeline manually through huggingface transformers
and then hand it to langchain. For example, instead of this

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from langchain.llms import HuggingFacePipeline

model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
pipe = pipeline(
    "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=10
)
hf = HuggingFacePipeline(pipeline=pipe)
```

you can now write this

```py
hf = HuggingFacePipeline.from_model_id(
    model_id="gpt2",
    task="text-generation",
    pipeline_kwargs={"max_new_tokens": 10},
)
```

A runnable end-to-end sketch is included after the diff below.

Co-authored-by: Dev 2049
---
 langchain/llms/huggingface_pipeline.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/langchain/llms/huggingface_pipeline.py b/langchain/llms/huggingface_pipeline.py
index 9b8d604292..615dcd8e3f 100644
--- a/langchain/llms/huggingface_pipeline.py
+++ b/langchain/llms/huggingface_pipeline.py
@@ -28,7 +28,9 @@ class HuggingFacePipeline(LLM):
 
             from langchain.llms import HuggingFacePipeline
             hf = HuggingFacePipeline.from_model_id(
-                model_id="gpt2", task="text-generation"
+                model_id="gpt2",
+                task="text-generation",
+                pipeline_kwargs={"max_new_tokens": 10},
             )
     Example passing pipeline in directly:
         .. code-block:: python
@@ -49,7 +51,9 @@ class HuggingFacePipeline(LLM):
     model_id: str = DEFAULT_MODEL_ID
     """Model name to use."""
     model_kwargs: Optional[dict] = None
-    """Key word arguments to pass to the model."""
+    """Key word arguments passed to the model."""
+    pipeline_kwargs: Optional[dict] = None
+    """Key word arguments passed to the pipeline."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -63,6 +67,7 @@ class HuggingFacePipeline(LLM):
         task: str,
         device: int = -1,
         model_kwargs: Optional[dict] = None,
+        pipeline_kwargs: Optional[dict] = None,
         **kwargs: Any,
     ) -> LLM:
         """Construct the pipeline object from model_id and task."""
@@ -119,12 +124,14 @@ class HuggingFacePipeline(LLM):
         _model_kwargs = {
             k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
         }
+        _pipeline_kwargs = pipeline_kwargs or {}
         pipeline = hf_pipeline(
             task=task,
             model=model,
             tokenizer=tokenizer,
             device=device,
             model_kwargs=_model_kwargs,
+            **_pipeline_kwargs,
         )
         if pipeline.task not in VALID_TASKS:
             raise ValueError(
@@ -135,6 +142,7 @@ class HuggingFacePipeline(LLM):
             pipeline=pipeline,
             model_id=model_id,
             model_kwargs=_model_kwargs,
+            pipeline_kwargs=_pipeline_kwargs,
             **kwargs,
         )
 
@@ -142,8 +150,9 @@ class HuggingFacePipeline(LLM):
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
        return {
-            **{"model_id": self.model_id},
-            **{"model_kwargs": self.model_kwargs},
+            "model_id": self.model_id,
+            "model_kwargs": self.model_kwargs,
+            "pipeline_kwargs": self.pipeline_kwargs,
         }
 
     @property
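
As referenced above, a minimal end-to-end sketch of the new path (assumes `gpt2` is downloaded on first use; the prompt string is illustrative):

```py
from langchain.llms import HuggingFacePipeline

# Generation settings now travel with the LLM via pipeline_kwargs,
# instead of being baked into a manually constructed pipeline.
hf = HuggingFacePipeline.from_model_id(
    model_id="gpt2",
    task="text-generation",
    pipeline_kwargs={"max_new_tokens": 10},
)

# LLM instances are callable with a prompt string and return generated text;
# here at most 10 new tokens are produced per the pipeline_kwargs above.
print(hf("Once upon a time"))
```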