diff --git a/langchain/llms/huggingface_endpoint.py b/langchain/llms/huggingface_endpoint.py
index cc549da2..03b0467b 100644
--- a/langchain/llms/huggingface_endpoint.py
+++ b/langchain/llms/huggingface_endpoint.py
@@ -9,7 +9,7 @@ from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env

-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")


 class HuggingFaceEndpoint(LLM):
@@ -37,7 +37,8 @@ class HuggingFaceEndpoint(LLM):
     endpoint_url: str = ""
     """Endpoint URL to use."""
     task: Optional[str] = None
-    """Task to call the model with. Should be a task that returns `generated_text`."""
+    """Task to call the model with.
+    Should be a task that returns `generated_text` or `summary_text`."""
     model_kwargs: Optional[dict] = None
     """Key word arguments to pass to the model."""

@@ -138,6 +139,8 @@ class HuggingFaceEndpoint(LLM):
             text = generated_text[0]["generated_text"][len(prompt) :]
         elif self.task == "text2text-generation":
             text = generated_text[0]["generated_text"]
+        elif self.task == "summarization":
+            text = generated_text[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.task}, "
diff --git a/langchain/llms/huggingface_hub.py b/langchain/llms/huggingface_hub.py
index 2838b858..5cd7e242 100644
--- a/langchain/llms/huggingface_hub.py
+++ b/langchain/llms/huggingface_hub.py
@@ -9,7 +9,7 @@ from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env

 DEFAULT_REPO_ID = "gpt2"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")


 class HuggingFaceHub(LLM):
@@ -19,7 +19,7 @@ class HuggingFaceHub(LLM):
     environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
     it as a named parameter to the constructor.

-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.

     Example:
         .. code-block:: python
@@ -32,7 +32,8 @@ class HuggingFaceHub(LLM):
     repo_id: str = DEFAULT_REPO_ID
     """Model name to use."""
     task: Optional[str] = None
-    """Task to call the model with. Should be a task that returns `generated_text`."""
+    """Task to call the model with.
+    Should be a task that returns `generated_text` or `summary_text`."""
     model_kwargs: Optional[dict] = None
     """Key word arguments to pass to the model."""

@@ -114,6 +115,8 @@ class HuggingFaceHub(LLM):
             text = response[0]["generated_text"][len(prompt) :]
         elif self.client.task == "text2text-generation":
             text = response[0]["generated_text"]
+        elif self.client.task == "summarization":
+            text = response[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.client.task}, "
diff --git a/langchain/llms/huggingface_pipeline.py b/langchain/llms/huggingface_pipeline.py
index 9d63aa04..9b8d6042 100644
--- a/langchain/llms/huggingface_pipeline.py
+++ b/langchain/llms/huggingface_pipeline.py
@@ -11,7 +11,7 @@ from langchain.llms.utils import enforce_stop_tokens

 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")

 logger = logging.getLogger(__name__)

@@ -21,7 +21,7 @@ class HuggingFacePipeline(LLM):

     To use, you should have the ``transformers`` python package installed.

-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.

     Example using from_model_id:
         .. code-block:: python
@@ -86,7 +86,7 @@ class HuggingFacePipeline(LLM):
         try:
             if task == "text-generation":
                 model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
-            elif task == "text2text-generation":
+            elif task in ("text2text-generation", "summarization"):
                 model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
             else:
                 raise ValueError(
@@ -162,6 +162,8 @@ class HuggingFacePipeline(LLM):
             text = response[0]["generated_text"][len(prompt) :]
         elif self.pipeline.task == "text2text-generation":
             text = response[0]["generated_text"]
+        elif self.pipeline.task == "summarization":
+            text = response[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.pipeline.task}, "
diff --git a/langchain/llms/self_hosted_hugging_face.py b/langchain/llms/self_hosted_hugging_face.py
index 49bd8536..1ef685a5 100644
--- a/langchain/llms/self_hosted_hugging_face.py
+++ b/langchain/llms/self_hosted_hugging_face.py
@@ -11,7 +11,7 @@ from langchain.llms.utils import enforce_stop_tokens

 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")

 logger = logging.getLogger(__name__)

@@ -35,6 +35,8 @@ def _generate_text(
         text = response[0]["generated_text"][len(prompt) :]
     elif pipeline.task == "text2text-generation":
         text = response[0]["generated_text"]
+    elif pipeline.task == "summarization":
+        text = response[0]["summary_text"]
     else:
         raise ValueError(
             f"Got invalid task {pipeline.task}, "
@@ -64,7 +66,7 @@ def _load_transformer(
     try:
         if task == "text-generation":
             model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
-        elif task == "text2text-generation":
+        elif task in ("text2text-generation", "summarization"):
             model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
         else:
             raise ValueError(
@@ -119,7 +121,7 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):

     To use, you should have the ``runhouse`` python package installed.

-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.

     Example using from_model_id:
         .. code-block:: python
@@ -153,7 +155,8 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
     model_id: str = DEFAULT_MODEL_ID
     """Hugging Face model_id to load the model."""
     task: str = DEFAULT_TASK
-    """Hugging Face task (either "text-generation" or "text2text-generation")."""
+    """Hugging Face task ("text-generation", "text2text-generation" or
+    "summarization")."""
     device: int = 0
     """Device to use for inference. -1 for CPU, 0 for GPU, 1 for second GPU, etc."""
     model_kwargs: Optional[dict] = None
diff --git a/tests/integration_tests/llms/test_huggingface_endpoint.py b/tests/integration_tests/llms/test_huggingface_endpoint.py
index 61639669..e50a7179 100644
--- a/tests/integration_tests/llms/test_huggingface_endpoint.py
+++ b/tests/integration_tests/llms/test_huggingface_endpoint.py
@@ -33,6 +33,16 @@ def test_huggingface_endpoint_text2text_generation() -> None:
     assert output == "Albany"


+@unittest.skip(
+    "This test requires an inference endpoint. Tested with Hugging Face endpoints"
+)
+def test_huggingface_endpoint_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFaceEndpoint(endpoint_url="", task="summarization")
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_huggingface_endpoint_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
     llm = HuggingFaceEndpoint(model_kwargs={"max_new_tokens": -1})
diff --git a/tests/integration_tests/llms/test_huggingface_hub.py b/tests/integration_tests/llms/test_huggingface_hub.py
index df0b4416..2d83f4b6 100644
--- a/tests/integration_tests/llms/test_huggingface_hub.py
+++ b/tests/integration_tests/llms/test_huggingface_hub.py
@@ -23,6 +23,13 @@ def test_huggingface_text2text_generation() -> None:
     assert output == "Albany"


+def test_huggingface_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFaceHub(repo_id="facebook/bart-large-cnn")
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_huggingface_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
     llm = HuggingFaceHub(model_kwargs={"max_new_tokens": -1})
diff --git a/tests/integration_tests/llms/test_huggingface_pipeline.py b/tests/integration_tests/llms/test_huggingface_pipeline.py
index b224a0a9..8dfc9c8b 100644
--- a/tests/integration_tests/llms/test_huggingface_pipeline.py
+++ b/tests/integration_tests/llms/test_huggingface_pipeline.py
@@ -27,6 +27,15 @@ def test_huggingface_pipeline_text2text_generation() -> None:
     assert isinstance(output, str)


+def test_huggingface_pipeline_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="facebook/bart-large-cnn", task="summarization"
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading an HuggingFaceHub LLM."""
     llm = HuggingFacePipeline.from_model_id(
diff --git a/tests/integration_tests/llms/test_self_hosted_llm.py b/tests/integration_tests/llms/test_self_hosted_llm.py
index 0dc753ab..73457a01 100644
--- a/tests/integration_tests/llms/test_self_hosted_llm.py
+++ b/tests/integration_tests/llms/test_self_hosted_llm.py
@@ -43,6 +43,19 @@ def test_self_hosted_huggingface_pipeline_text2text_generation() -> None:
     assert isinstance(output, str)


+def test_self_hosted_huggingface_pipeline_summarization() -> None:
+    """Test valid call to self-hosted HuggingFace summarization model."""
+    gpu = get_remote_instance()
+    llm = SelfHostedHuggingFaceLLM(
+        model_id="facebook/bart-large-cnn",
+        task="summarization",
+        hardware=gpu,
+        model_reqs=model_reqs,
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def load_pipeline() -> Any:
     """Load pipeline for testing."""
     model_id = "gpt2"
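For reference, a minimal usage sketch of the new `summarization` task, mirroring the added
integration tests. It assumes the `transformers` package is installed and that the
facebook/bart-large-cnn checkpoint can be downloaded locally; the sample text is
illustrative only and not part of the change.

    # Sketch only: exercises the new summarization task via the local pipeline wrapper.
    from langchain.llms import HuggingFacePipeline

    # task="summarization" now loads a seq2seq model and reads `summary_text`
    # from the pipeline output instead of `generated_text`.
    llm = HuggingFacePipeline.from_model_id(
        model_id="facebook/bart-large-cnn", task="summarization"
    )

    article = (
        "LangChain wraps Hugging Face models behind a single LLM interface. "
        "With this change, the Hub, Endpoint, Pipeline and self-hosted wrappers "
        "all accept the summarization task."
    )
    print(llm(article))  # -> a short summary string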