# Add summarization task type for HuggingFace APIs (#4721)

Adds the `summarization` task type for the HuggingFace integrations, as described in the [HuggingFace inference API docs](https://huggingface.co/docs/api-inference/detailed_parameters#summarization-task). My project uses LangChain to connect multiple LLMs, including several HuggingFace models that support the summarization task, so integrating this task type is convenient and beneficial.

Fixes #4720
parent 580861e7f2 · commit 3f0357f94a
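The change is the same in every wrapper: the inference API keys its JSON response by task, and summarization responses carry `summary_text` instead of `generated_text`. A minimal sketch of that dispatch, with `extract_text` as a hypothetical stand-alone helper (the diffs below inline this logic into each class):

```python
from typing import Any, List, Mapping


def extract_text(task: str, response: List[Mapping[str, Any]], prompt: str) -> str:
    """Hypothetical helper mirroring the per-task dispatch added in this commit."""
    if task == "text-generation":
        # Causal models echo the prompt, so strip it from the front.
        return response[0]["generated_text"][len(prompt) :]
    elif task == "text2text-generation":
        return response[0]["generated_text"]
    elif task == "summarization":
        # Summarization responses are keyed `summary_text`, not `generated_text`.
        return response[0]["summary_text"]
    raise ValueError(f"Got invalid task {task}")
```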
langchain/llms/huggingface_endpoint.py

@@ -9,7 +9,7 @@ from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 
 
 class HuggingFaceEndpoint(LLM):
@@ -37,7 +37,8 @@ class HuggingFaceEndpoint(LLM):
     endpoint_url: str = ""
     """Endpoint URL to use."""
     task: Optional[str] = None
-    """Task to call the model with. Should be a task that returns `generated_text`."""
+    """Task to call the model with.
+    Should be a task that returns `generated_text` or `summary_text`."""
     model_kwargs: Optional[dict] = None
     """Key word arguments to pass to the model."""
 
@@ -138,6 +139,8 @@ class HuggingFaceEndpoint(LLM):
             text = generated_text[0]["generated_text"][len(prompt) :]
         elif self.task == "text2text-generation":
             text = generated_text[0]["generated_text"]
+        elif self.task == "summarization":
+            text = generated_text[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.task}, "
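With this change, a dedicated inference endpoint serving a summarization model can be called directly. A sketch, assuming a deployed endpoint; the URL and token below are placeholders:

```python
from langchain.llms import HuggingFaceEndpoint

llm = HuggingFaceEndpoint(
    endpoint_url="https://your-endpoint.endpoints.huggingface.cloud",  # placeholder
    task="summarization",
    huggingfacehub_api_token="hf_...",  # placeholder
)
summary = llm(
    "LangChain is a framework for developing applications powered by "
    "language models. It provides prompt management, chains, and agents."
)
```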
langchain/llms/huggingface_hub.py

@@ -9,7 +9,7 @@ from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 DEFAULT_REPO_ID = "gpt2"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 
 
 class HuggingFaceHub(LLM):
@@ -19,7 +19,7 @@ class HuggingFaceHub(LLM):
     environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
     it as a named parameter to the constructor.
 
-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
 
     Example:
         .. code-block:: python
@@ -32,7 +32,8 @@ class HuggingFaceHub(LLM):
     repo_id: str = DEFAULT_REPO_ID
     """Model name to use."""
     task: Optional[str] = None
-    """Task to call the model with. Should be a task that returns `generated_text`."""
+    """Task to call the model with.
+    Should be a task that returns `generated_text` or `summary_text`."""
     model_kwargs: Optional[dict] = None
     """Key word arguments to pass to the model."""
 
@@ -114,6 +115,8 @@ class HuggingFaceHub(LLM):
             text = response[0]["generated_text"][len(prompt) :]
         elif self.client.task == "text2text-generation":
             text = response[0]["generated_text"]
+        elif self.client.task == "summarization":
+            text = response[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.client.task}, "
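`HuggingFaceHub` resolves the task from the repo's metadata when `task` is not given, which the integration test later in this commit relies on. A sketch, assuming `HUGGINGFACEHUB_API_TOKEN` is set in the environment:

```python
from langchain.llms import HuggingFaceHub

# facebook/bart-large-cnn is registered on the Hub as a summarization model,
# so the underlying InferenceApi client resolves task="summarization" itself.
llm = HuggingFaceHub(repo_id="facebook/bart-large-cnn")
print(llm(
    "LangChain is a framework for developing applications powered by "
    "language models. It provides prompt management, chains, and agents."
))
```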
langchain/llms/huggingface_pipeline.py

@@ -11,7 +11,7 @@ from langchain.llms.utils import enforce_stop_tokens
 
 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 
 logger = logging.getLogger(__name__)
 
@@ -21,7 +21,7 @@ class HuggingFacePipeline(LLM):
 
     To use, you should have the ``transformers`` python package installed.
 
-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
 
     Example using from_model_id:
         .. code-block:: python
@@ -86,7 +86,7 @@ class HuggingFacePipeline(LLM):
         try:
             if task == "text-generation":
                 model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
-            elif task == "text2text-generation":
+            elif task in ("text2text-generation", "summarization"):
                 model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
             else:
                 raise ValueError(
@@ -162,6 +162,8 @@ class HuggingFacePipeline(LLM):
             text = response[0]["generated_text"][len(prompt) :]
         elif self.pipeline.task == "text2text-generation":
             text = response[0]["generated_text"]
+        elif self.pipeline.task == "summarization":
+            text = response[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.pipeline.task}, "
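For local pipelines the interesting part is model loading: summarization models are sequence-to-sequence, so they share the `AutoModelForSeq2SeqLM` branch with `text2text-generation`. A sketch that runs entirely locally (downloads the model on first use):

```python
from langchain.llms import HuggingFacePipeline

llm = HuggingFacePipeline.from_model_id(
    model_id="facebook/bart-large-cnn",
    task="summarization",
)
text = (
    "LangChain is a framework for developing applications powered by language "
    "models. It offers prompt management, chains, agents, and integrations "
    "with many model providers."
)
print(llm(text))
```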
langchain/llms/self_hosted_hugging_face.py

@@ -11,7 +11,7 @@ from langchain.llms.utils import enforce_stop_tokens
 
 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 
 logger = logging.getLogger(__name__)
 
@@ -35,6 +35,8 @@ def _generate_text(
         text = response[0]["generated_text"][len(prompt) :]
     elif pipeline.task == "text2text-generation":
         text = response[0]["generated_text"]
+    elif pipeline.task == "summarization":
+        text = response[0]["summary_text"]
     else:
         raise ValueError(
             f"Got invalid task {pipeline.task}, "
@@ -64,7 +66,7 @@ def _load_transformer(
     try:
         if task == "text-generation":
             model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
-        elif task == "text2text-generation":
+        elif task in ("text2text-generation", "summarization"):
             model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
         else:
             raise ValueError(
@@ -119,7 +121,7 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
 
     To use, you should have the ``runhouse`` python package installed.
 
-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
 
     Example using from_model_id:
         .. code-block:: python
@@ -153,7 +155,8 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
     model_id: str = DEFAULT_MODEL_ID
     """Hugging Face model_id to load the model."""
    task: str = DEFAULT_TASK
-    """Hugging Face task (either "text-generation" or "text2text-generation")."""
+    """Hugging Face task ("text-generation", "text2text-generation" or
+    "summarization")."""
     device: int = 0
     """Device to use for inference. -1 for CPU, 0 for GPU, 1 for second GPU, etc."""
     model_kwargs: Optional[dict] = None
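The self-hosted wrapper reuses the same `_load_transformer`/`_generate_text` pair, just executed on remote hardware via `runhouse`. A sketch following the cluster setup used elsewhere in the LangChain docs (the cluster spec is a placeholder):

```python
import runhouse as rh
from langchain.llms import SelfHostedHuggingFaceLLM

# Placeholder cluster; any GPU-backed rh.cluster handle works the same way.
gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")
llm = SelfHostedHuggingFaceLLM(
    model_id="facebook/bart-large-cnn",
    task="summarization",
    hardware=gpu,
)
print(llm("LangChain is a framework for developing applications powered by LLMs."))
```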
tests/integration_tests/llms/test_huggingface_endpoint.py

@@ -33,6 +33,16 @@ def test_huggingface_endpoint_text2text_generation() -> None:
     assert output == "Albany"
 
 
+@unittest.skip(
+    "This test requires an inference endpoint. Tested with Hugging Face endpoints"
+)
+def test_huggingface_endpoint_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFaceEndpoint(endpoint_url="", task="summarization")
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_huggingface_endpoint_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
     llm = HuggingFaceEndpoint(model_kwargs={"max_new_tokens": -1})
tests/integration_tests/llms/test_huggingface_hub.py

@@ -23,6 +23,13 @@ def test_huggingface_text2text_generation() -> None:
     assert output == "Albany"
 
 
+def test_huggingface_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFaceHub(repo_id="facebook/bart-large-cnn")
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_huggingface_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
     llm = HuggingFaceHub(model_kwargs={"max_new_tokens": -1})
tests/integration_tests/llms/test_huggingface_pipeline.py

@@ -27,6 +27,15 @@ def test_huggingface_pipeline_text2text_generation() -> None:
     assert isinstance(output, str)
 
 
+def test_huggingface_pipeline_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="facebook/bart-large-cnn", task="summarization"
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading an HuggingFaceHub LLM."""
     llm = HuggingFacePipeline.from_model_id(
tests/integration_tests/llms/test_self_hosted_llm.py

@@ -43,6 +43,19 @@ def test_self_hosted_huggingface_pipeline_text2text_generation() -> None:
     assert isinstance(output, str)
 
 
+def test_self_hosted_huggingface_pipeline_summarization() -> None:
+    """Test valid call to self-hosted HuggingFace summarization model."""
+    gpu = get_remote_instance()
+    llm = SelfHostedHuggingFaceLLM(
+        model_id="facebook/bart-large-cnn",
+        task="summarization",
+        hardware=gpu,
+        model_reqs=model_reqs,
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def load_pipeline() -> Any:
     """Load pipeline for testing."""
     model_id = "gpt2"
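These are integration tests, so they need real credentials or hardware (the endpoint test is even `unittest.skip`-ed). One way to run just the new summarization tests, assuming the repo's test layout:

```python
import pytest

# Hypothetical invocation; selects only tests with "summarization" in the name.
pytest.main(["tests/integration_tests/llms", "-k", "summarization"])
```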