From 0b743f005b118397d67a64113c08fd6ca4660d7e Mon Sep 17 00:00:00 2001
From: hsuyuming
Date: Wed, 11 Oct 2023 20:09:03 -0600
Subject: [PATCH] Feature/enhance huggingfacepipeline to handle different
 return type (#11394)

**Description:** Avoid HuggingFacePipeline truncating the response when the
user sets `return_full_text=False` in the underlying Hugging Face pipeline.

**Dependencies:** None

**Tag maintainer:** Maybe @sam-h-bean?

---------

Co-authored-by: Bagatur
---
 .../langchain/llms/huggingface_pipeline.py    | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/libs/langchain/langchain/llms/huggingface_pipeline.py b/libs/langchain/langchain/llms/huggingface_pipeline.py
index 2b0d792bec..291fb5a69d 100644
--- a/libs/langchain/langchain/llms/huggingface_pipeline.py
+++ b/libs/langchain/langchain/llms/huggingface_pipeline.py
@@ -202,8 +202,23 @@ class HuggingFacePipeline(BaseLLM):
                     response = response[0]
 
                 if self.pipeline.task == "text-generation":
-                    # Text generation return includes the starter text
-                    text = response["generated_text"][len(batch_prompts[j]) :]
+                    try:
+                        from transformers.pipelines.text_generation import ReturnType
+
+                        remove_prompt = (
+                            self.pipeline._postprocess_params.get("return_type")
+                            != ReturnType.NEW_TEXT
+                        )
+                    except Exception as e:
+                        logger.warning(
+                            f"Unable to extract pipeline return_type. "
+                            f"Received error:\n\n{e}"
+                        )
+                        remove_prompt = True
+                    if remove_prompt:
+                        text = response["generated_text"][len(batch_prompts[j]) :]
+                    else:
+                        text = response["generated_text"]
                 elif self.pipeline.task == "text2text-generation":
                     text = response["generated_text"]
                 elif self.pipeline.task == "summarization":
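
Usage note (not part of the patch): a minimal sketch of how the changed behavior is exercised. The model name, generation parameters, and prompt below are illustrative assumptions; the point is that with `return_full_text=False` the transformers pipeline already returns only the new text, so the wrapper should no longer slice off `len(prompt)` characters from it.

```python
# Hypothetical usage sketch: model name and prompt are placeholders.
from transformers import pipeline
from langchain.llms.huggingface_pipeline import HuggingFacePipeline

# return_full_text=False makes the transformers text-generation pipeline
# return only the newly generated text (ReturnType.NEW_TEXT), without the
# prompt prefix.
hf_pipeline = pipeline(
    "text-generation",
    model="gpt2",
    max_new_tokens=32,
    return_full_text=False,
)

llm = HuggingFacePipeline(pipeline=hf_pipeline)

# Before this patch the wrapper always dropped the first len(prompt)
# characters of "generated_text", which cut into real output when the
# pipeline had already stripped the prompt. After the patch it inspects the
# pipeline's return_type and only strips the prompt when the full text is
# actually returned.
print(llm("Tell me a short fact about the moon:"))
```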