diff --git a/langchain/llms/huggingface_pipeline.py b/langchain/llms/huggingface_pipeline.py index 9b8d604292..615dcd8e3f 100644 --- a/langchain/llms/huggingface_pipeline.py +++ b/langchain/llms/huggingface_pipeline.py @@ -28,7 +28,9 @@ class HuggingFacePipeline(LLM): from langchain.llms import HuggingFacePipeline hf = HuggingFacePipeline.from_model_id( - model_id="gpt2", task="text-generation" + model_id="gpt2", + task="text-generation", + pipeline_kwargs={"max_new_tokens": 10}, ) Example passing pipeline in directly: .. code-block:: python @@ -49,7 +51,9 @@ class HuggingFacePipeline(LLM): model_id: str = DEFAULT_MODEL_ID """Model name to use.""" model_kwargs: Optional[dict] = None - """Key word arguments to pass to the model.""" + """Key word arguments passed to the model.""" + pipeline_kwargs: Optional[dict] = None + """Key word arguments passed to the pipeline.""" class Config: """Configuration for this pydantic object.""" @@ -63,6 +67,7 @@ class HuggingFacePipeline(LLM): task: str, device: int = -1, model_kwargs: Optional[dict] = None, + pipeline_kwargs: Optional[dict] = None, **kwargs: Any, ) -> LLM: """Construct the pipeline object from model_id and task.""" @@ -119,12 +124,14 @@ class HuggingFacePipeline(LLM): _model_kwargs = { k: v for k, v in _model_kwargs.items() if k != "trust_remote_code" } + _pipeline_kwargs = pipeline_kwargs or {} pipeline = hf_pipeline( task=task, model=model, tokenizer=tokenizer, device=device, model_kwargs=_model_kwargs, + **_pipeline_kwargs, ) if pipeline.task not in VALID_TASKS: raise ValueError( @@ -135,6 +142,7 @@ class HuggingFacePipeline(LLM): pipeline=pipeline, model_id=model_id, model_kwargs=_model_kwargs, + pipeline_kwargs=_pipeline_kwargs, **kwargs, ) @@ -142,8 +150,9 @@ class HuggingFacePipeline(LLM): def _identifying_params(self) -> Mapping[str, Any]: """Get the identifying parameters.""" return { - **{"model_id": self.model_id}, - **{"model_kwargs": self.model_kwargs}, + "model_id": self.model_id, + "model_kwargs": self.model_kwargs, + "pipeline_kwargs": self.pipeline_kwargs, } @property