diff --git a/libs/langchain/langchain/llms/huggingface_pipeline.py b/libs/langchain/langchain/llms/huggingface_pipeline.py
index 2b0d792bec..291fb5a69d 100644
--- a/libs/langchain/langchain/llms/huggingface_pipeline.py
+++ b/libs/langchain/langchain/llms/huggingface_pipeline.py
@@ -202,8 +202,23 @@ class HuggingFacePipeline(BaseLLM):
                     response = response[0]
 
                 if self.pipeline.task == "text-generation":
-                    # Text generation return includes the starter text
-                    text = response["generated_text"][len(batch_prompts[j]) :]
+                    try:
+                        from transformers.pipelines.text_generation import ReturnType
+
+                        remove_prompt = (
+                            self.pipeline._postprocess_params.get("return_type")
+                            != ReturnType.NEW_TEXT
+                        )
+                    except Exception as e:
+                        logger.warning(
+                            f"Unable to extract pipeline return_type. "
+                            f"Received error:\n\n{e}"
+                        )
+                        remove_prompt = True
+                    if remove_prompt:
+                        text = response["generated_text"][len(batch_prompts[j]) :]
+                    else:
+                        text = response["generated_text"]
                 elif self.pipeline.task == "text2text-generation":
                     text = response["generated_text"]
                 elif self.pipeline.task == "summarization":
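
For context, a minimal usage sketch of the behavior this patch fixes (not part of the diff; the model name and prompt below are illustrative placeholders). When a transformers text-generation pipeline is built with return_full_text=False, its return_type is ReturnType.NEW_TEXT and the output already excludes the prompt; before this change, HuggingFacePipeline unconditionally sliced off len(prompt) characters and silently truncated such completions, whereas with the patch the slice is skipped.

# Sketch only, assuming a locally available text-generation model ("gpt2"
# here is illustrative, not prescribed by the patch).
from transformers import pipeline

from langchain.llms.huggingface_pipeline import HuggingFacePipeline

generator = pipeline(
    "text-generation",
    model="gpt2",
    max_new_tokens=32,
    return_full_text=False,  # pipeline returns only the newly generated text
)
llm = HuggingFacePipeline(pipeline=generator)

# Pre-patch: the first len(prompt) characters of the completion were dropped
# by the unconditional slice. Post-patch: the completion is returned intact.
print(llm("Once upon a time"))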