# Add summarization task type for HuggingFace APIs (#4721)

Adds the `summarization` task type to the HuggingFace LLM integrations. The task type is described in the [HuggingFace Inference API documentation](https://huggingface.co/docs/api-inference/detailed_parameters#summarization-task).

My project uses LangChain to connect multiple LLMs, including several HuggingFace models that support the summarization task, so having this task type supported directly in LangChain is very convenient.
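For context, a minimal usage sketch of the new task type through `HuggingFaceHub` (the repo id, sample text, and generation parameters below are illustrative, and a `HUGGINGFACEHUB_API_TOKEN` environment variable is assumed):

```python
from langchain.llms import HuggingFaceHub

# Assumes HUGGINGFACEHUB_API_TOKEN is set in the environment.
# facebook/bart-large-cnn is one example of a summarization-capable model.
llm = HuggingFaceHub(
    repo_id="facebook/bart-large-cnn",
    task="summarization",
    model_kwargs={"min_length": 30, "max_length": 100},
)

text = (
    "LangChain is a framework for developing applications powered by "
    "language models. It offers integrations with many model providers, "
    "including HuggingFace Hub models, inference endpoints, and local "
    "transformers pipelines."
)
print(llm(text))  # prints a short summary string
```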

Fixes #4720
commit 3f0357f94a (parent 580861e7f2)
whuwxl, 2023-05-16 07:26:17 +08:00, committed by GitHub
8 changed files with 62 additions and 12 deletions

## langchain/llms/huggingface_endpoint.py

@@ -9,7 +9,7 @@ from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 
 
 class HuggingFaceEndpoint(LLM):
@@ -37,7 +37,8 @@ class HuggingFaceEndpoint(LLM):
     endpoint_url: str = ""
     """Endpoint URL to use."""
     task: Optional[str] = None
-    """Task to call the model with. Should be a task that returns `generated_text`."""
+    """Task to call the model with.
+    Should be a task that returns `generated_text` or `summary_text`."""
     model_kwargs: Optional[dict] = None
     """Key word arguments to pass to the model."""
 
@@ -138,6 +139,8 @@ class HuggingFaceEndpoint(LLM):
             text = generated_text[0]["generated_text"][len(prompt) :]
         elif self.task == "text2text-generation":
             text = generated_text[0]["generated_text"]
+        elif self.task == "summarization":
+            text = generated_text[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.task}, "

## langchain/llms/huggingface_hub.py

@@ -9,7 +9,7 @@ from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 DEFAULT_REPO_ID = "gpt2"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 
 
 class HuggingFaceHub(LLM):
@@ -19,7 +19,7 @@ class HuggingFaceHub(LLM):
     environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
     it as a named parameter to the constructor.
 
-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
 
     Example:
         .. code-block:: python
@@ -32,7 +32,8 @@ class HuggingFaceHub(LLM):
     repo_id: str = DEFAULT_REPO_ID
     """Model name to use."""
     task: Optional[str] = None
-    """Task to call the model with. Should be a task that returns `generated_text`."""
+    """Task to call the model with.
+    Should be a task that returns `generated_text` or `summary_text`."""
     model_kwargs: Optional[dict] = None
     """Key word arguments to pass to the model."""
 
@@ -114,6 +115,8 @@ class HuggingFaceHub(LLM):
             text = response[0]["generated_text"][len(prompt) :]
         elif self.client.task == "text2text-generation":
             text = response[0]["generated_text"]
+        elif self.client.task == "summarization":
+            text = response[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.client.task}, "

## langchain/llms/huggingface_pipeline.py

@@ -11,7 +11,7 @@ from langchain.llms.utils import enforce_stop_tokens
 
 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 
 logger = logging.getLogger(__name__)
@@ -21,7 +21,7 @@ class HuggingFacePipeline(LLM):
     To use, you should have the ``transformers`` python package installed.
 
-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
 
     Example using from_model_id:
         .. code-block:: python
@@ -86,7 +86,7 @@ class HuggingFacePipeline(LLM):
         try:
             if task == "text-generation":
                 model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
-            elif task == "text2text-generation":
+            elif task in ("text2text-generation", "summarization"):
                 model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
             else:
                 raise ValueError(
@@ -162,6 +162,8 @@ class HuggingFacePipeline(LLM):
             text = response[0]["generated_text"][len(prompt) :]
         elif self.pipeline.task == "text2text-generation":
             text = response[0]["generated_text"]
+        elif self.pipeline.task == "summarization":
+            text = response[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.pipeline.task}, "

## langchain/llms/self_hosted_hugging_face.py

@@ -11,7 +11,7 @@ from langchain.llms.utils import enforce_stop_tokens
 
 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 
 logger = logging.getLogger(__name__)
@@ -35,6 +35,8 @@ def _generate_text(
         text = response[0]["generated_text"][len(prompt) :]
     elif pipeline.task == "text2text-generation":
         text = response[0]["generated_text"]
+    elif pipeline.task == "summarization":
+        text = response[0]["summary_text"]
     else:
         raise ValueError(
             f"Got invalid task {pipeline.task}, "
@@ -64,7 +66,7 @@ def _load_transformer(
     try:
         if task == "text-generation":
             model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
-        elif task == "text2text-generation":
+        elif task in ("text2text-generation", "summarization"):
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
         else:
             raise ValueError(
@@ -119,7 +121,7 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
     To use, you should have the ``runhouse`` python package installed.
 
-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
 
     Example using from_model_id:
         .. code-block:: python
@@ -153,7 +155,8 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
     model_id: str = DEFAULT_MODEL_ID
     """Hugging Face model_id to load the model."""
     task: str = DEFAULT_TASK
-    """Hugging Face task (either "text-generation" or "text2text-generation")."""
+    """Hugging Face task ("text-generation", "text2text-generation" or
+    "summarization")."""
     device: int = 0
     """Device to use for inference. -1 for CPU, 0 for GPU, 1 for second GPU, etc."""
     model_kwargs: Optional[dict] = None
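The self-hosted variant follows the same pattern on remote hardware. A sketch assuming a `runhouse` cluster (the cluster configuration below is illustrative, mirroring the examples in the class docstring, not a verified setup):

```python
import runhouse as rh
from langchain.llms import SelfHostedHuggingFaceLLM

# Illustrative cluster spec; any runhouse-reachable GPU box should work.
gpu = rh.cluster(name="rh-a10x", instance_type="A100:1", use_spot=False)
llm = SelfHostedHuggingFaceLLM(
    model_id="facebook/bart-large-cnn",
    task="summarization",
    hardware=gpu,
)
print(llm("Long article text ..."))
```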

## tests/integration_tests/llms/test_huggingface_endpoint.py

@@ -33,6 +33,16 @@ def test_huggingface_endpoint_text2text_generation() -> None:
     assert output == "Albany"
 
 
+@unittest.skip(
+    "This test requires an inference endpoint. Tested with Hugging Face endpoints"
+)
+def test_huggingface_endpoint_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFaceEndpoint(endpoint_url="", task="summarization")
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_huggingface_endpoint_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
     llm = HuggingFaceEndpoint(model_kwargs={"max_new_tokens": -1})

## tests/integration_tests/llms/test_huggingface_hub.py

@@ -23,6 +23,13 @@ def test_huggingface_text2text_generation() -> None:
     assert output == "Albany"
 
 
+def test_huggingface_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFaceHub(repo_id="facebook/bart-large-cnn")
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_huggingface_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
     llm = HuggingFaceHub(model_kwargs={"max_new_tokens": -1})

## tests/integration_tests/llms/test_huggingface_pipeline.py

@@ -27,6 +27,15 @@ def test_huggingface_pipeline_text2text_generation() -> None:
     assert isinstance(output, str)
 
 
+def test_huggingface_pipeline_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="facebook/bart-large-cnn", task="summarization"
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading an HuggingFaceHub LLM."""
     llm = HuggingFacePipeline.from_model_id(

## tests/integration_tests/llms/test_self_hosted_llm.py

@@ -43,6 +43,19 @@ def test_self_hosted_huggingface_pipeline_text2text_generation() -> None:
     assert isinstance(output, str)
 
 
+def test_self_hosted_huggingface_pipeline_summarization() -> None:
+    """Test valid call to self-hosted HuggingFace summarization model."""
+    gpu = get_remote_instance()
+    llm = SelfHostedHuggingFaceLLM(
+        model_id="facebook/bart-large-cnn",
+        task="summarization",
+        hardware=gpu,
+        model_reqs=model_reqs,
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def load_pipeline() -> Any:
     """Load pipeline for testing."""
     model_id = "gpt2"