Add summarization task type for HuggingFace APIs (#4721)
# Add summarization task type for HuggingFace APIs

Adds the summarization task type for the HuggingFace APIs. This task type is described by the [HuggingFace inference API](https://huggingface.co/docs/api-inference/detailed_parameters#summarization-task). My project uses LangChain to connect multiple LLMs, including various HuggingFace models that support the summarization task, so integrating this task type is highly convenient and beneficial.

Fixes #4720
parent 580861e7f2
commit 3f0357f94a
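For context, a minimal sketch of how the new task type is exercised end to end; `HUGGINGFACEHUB_API_TOKEN` is assumed to be set, and the repo id is an example summarization model (the same one the new tests use):

```python
# Minimal sketch: summarization via the HuggingFaceHub wrapper.
# Assumes HUGGINGFACEHUB_API_TOKEN is set in the environment; the repo id and
# the min_length/max_length parameters are example choices, not requirements.
from langchain.llms import HuggingFaceHub

llm = HuggingFaceHub(
    repo_id="facebook/bart-large-cnn",
    model_kwargs={"min_length": 30, "max_length": 120},
)
print(llm("LangChain is a framework for developing applications powered by LLMs. ..."))
```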
langchain/llms/huggingface_endpoint.py
@@ -9,7 +9,7 @@ from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 
 
 class HuggingFaceEndpoint(LLM):
@@ -37,7 +37,8 @@ class HuggingFaceEndpoint(LLM):
     endpoint_url: str = ""
     """Endpoint URL to use."""
     task: Optional[str] = None
-    """Task to call the model with. Should be a task that returns `generated_text`."""
+    """Task to call the model with.
+    Should be a task that returns `generated_text` or `summary_text`."""
     model_kwargs: Optional[dict] = None
     """Key word arguments to pass to the model."""
 
@@ -138,6 +139,8 @@ class HuggingFaceEndpoint(LLM):
             text = generated_text[0]["generated_text"][len(prompt) :]
         elif self.task == "text2text-generation":
            text = generated_text[0]["generated_text"]
+        elif self.task == "summarization":
+            text = generated_text[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.task}, "
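The new endpoint branch mirrors the payload shape of the Inference API: a summarization endpoint responds with `summary_text` where generation tasks return `generated_text`. A hedged sketch of the raw exchange the branch parses (endpoint URL and token are placeholders):

```python
# Sketch of the raw request/response the new elif branch handles.
# The endpoint URL and bearer token below are placeholders, not real values.
import requests

response = requests.post(
    "https://<your-endpoint>.endpoints.huggingface.cloud",  # placeholder
    headers={"Authorization": "Bearer <hf_token>"},  # placeholder
    json={"inputs": "Long article text ..."},
)
generated_text = response.json()  # e.g. [{"summary_text": "Short summary."}]
text = generated_text[0]["summary_text"]
```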
langchain/llms/huggingface_hub.py
@@ -9,7 +9,7 @@ from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
 
 DEFAULT_REPO_ID = "gpt2"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 
 
 class HuggingFaceHub(LLM):
@@ -19,7 +19,7 @@ class HuggingFaceHub(LLM):
     environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
     it as a named parameter to the constructor.
 
-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
 
     Example:
         .. code-block:: python
@@ -32,7 +32,8 @@ class HuggingFaceHub(LLM):
     repo_id: str = DEFAULT_REPO_ID
     """Model name to use."""
     task: Optional[str] = None
-    """Task to call the model with. Should be a task that returns `generated_text`."""
+    """Task to call the model with.
+    Should be a task that returns `generated_text` or `summary_text`."""
     model_kwargs: Optional[dict] = None
     """Key word arguments to pass to the model."""
 
@@ -114,6 +115,8 @@ class HuggingFaceHub(LLM):
             text = response[0]["generated_text"][len(prompt) :]
         elif self.client.task == "text2text-generation":
             text = response[0]["generated_text"]
+        elif self.client.task == "summarization":
+            text = response[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.client.task}, "
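Note that `HuggingFaceHub` needs no explicit `task` argument for this branch to fire: the wrapped `InferenceApi` client infers the task from the model repo's pipeline tag. A sketch, assuming the `huggingface_hub` client API of the time:

```python
# Sketch: the InferenceApi client that HuggingFaceHub wraps infers the task
# from the repo's pipeline tag, so self.client.task becomes "summarization"
# for a summarization model. Token is a placeholder.
from huggingface_hub.inference_api import InferenceApi

client = InferenceApi(repo_id="facebook/bart-large-cnn", token="<hf_token>")  # placeholder token
print(client.task)  # expected: "summarization"
```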
langchain/llms/huggingface_pipeline.py
@@ -11,7 +11,7 @@ from langchain.llms.utils import enforce_stop_tokens
 
 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 
 logger = logging.getLogger(__name__)
 
@@ -21,7 +21,7 @@ class HuggingFacePipeline(LLM):
 
     To use, you should have the ``transformers`` python package installed.
 
-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
 
     Example using from_model_id:
         .. code-block:: python
@@ -86,7 +86,7 @@ class HuggingFacePipeline(LLM):
         try:
             if task == "text-generation":
                 model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
-            elif task == "text2text-generation":
+            elif task in ("text2text-generation", "summarization"):
                 model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
             else:
                 raise ValueError(
@@ -162,6 +162,8 @@ class HuggingFacePipeline(LLM):
             text = response[0]["generated_text"][len(prompt) :]
         elif self.pipeline.task == "text2text-generation":
             text = response[0]["generated_text"]
+        elif self.pipeline.task == "summarization":
+            text = response[0]["summary_text"]
         else:
             raise ValueError(
                 f"Got invalid task {self.pipeline.task}, "
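Reusing `AutoModelForSeq2SeqLM` here is sound because summarization models such as BART are encoder-decoder models, the same architecture family as `text2text-generation`. A standalone sketch of the equivalent `transformers` usage (the model id is an example choice):

```python
# Sketch: loading a summarization model with the same seq2seq auto class the
# patched from_model_id uses, then running the transformers pipeline directly.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

model_id = "facebook/bart-large-cnn"  # example summarization model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)
# The pipeline returns [{"summary_text": ...}], matching the new elif branch.
print(summarizer("Long article text ...")[0]["summary_text"])
```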
langchain/llms/self_hosted_hugging_face.py
@@ -11,7 +11,7 @@ from langchain.llms.utils import enforce_stop_tokens
 
 DEFAULT_MODEL_ID = "gpt2"
 DEFAULT_TASK = "text-generation"
-VALID_TASKS = ("text2text-generation", "text-generation")
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
 
 logger = logging.getLogger(__name__)
 
@@ -35,6 +35,8 @@ def _generate_text(
         text = response[0]["generated_text"][len(prompt) :]
     elif pipeline.task == "text2text-generation":
         text = response[0]["generated_text"]
+    elif pipeline.task == "summarization":
+        text = response[0]["summary_text"]
     else:
         raise ValueError(
             f"Got invalid task {pipeline.task}, "
@@ -64,7 +66,7 @@ def _load_transformer(
     try:
         if task == "text-generation":
             model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
-        elif task == "text2text-generation":
+        elif task in ("text2text-generation", "summarization"):
             model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **_model_kwargs)
         else:
             raise ValueError(
@@ -119,7 +121,7 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
 
     To use, you should have the ``runhouse`` python package installed.
 
-    Only supports `text-generation` and `text2text-generation` for now.
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
 
     Example using from_model_id:
         .. code-block:: python
@@ -153,7 +155,8 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
     model_id: str = DEFAULT_MODEL_ID
     """Hugging Face model_id to load the model."""
     task: str = DEFAULT_TASK
-    """Hugging Face task (either "text-generation" or "text2text-generation")."""
+    """Hugging Face task ("text-generation", "text2text-generation" or
+    "summarization")."""
     device: int = 0
     """Device to use for inference. -1 for CPU, 0 for GPU, 1 for second GPU, etc."""
     model_kwargs: Optional[dict] = None
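For the self-hosted path, usage mirrors the class docstring example; a sketch assuming a runhouse cluster (the cluster name and instance type are placeholders):

```python
# Sketch: the new task on a self-hosted model. Hardware setup follows the
# runhouse pattern from the class docstring; cluster settings are placeholders.
import runhouse as rh

from langchain.llms import SelfHostedHuggingFaceLLM

gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")  # placeholder cluster
llm = SelfHostedHuggingFaceLLM(
    model_id="facebook/bart-large-cnn",
    task="summarization",
    hardware=gpu,
)
print(llm("Long article text ..."))
```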
tests/integration_tests/llms/test_huggingface_endpoint.py
@@ -33,6 +33,16 @@ def test_huggingface_endpoint_text2text_generation() -> None:
     assert output == "Albany"
 
 
+@unittest.skip(
+    "This test requires an inference endpoint. Tested with Hugging Face endpoints"
+)
+def test_huggingface_endpoint_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFaceEndpoint(endpoint_url="", task="summarization")
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_huggingface_endpoint_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
     llm = HuggingFaceEndpoint(model_kwargs={"max_new_tokens": -1})
tests/integration_tests/llms/test_huggingface_hub.py
@@ -23,6 +23,13 @@ def test_huggingface_text2text_generation() -> None:
     assert output == "Albany"
 
 
+def test_huggingface_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFaceHub(repo_id="facebook/bart-large-cnn")
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_huggingface_call_error() -> None:
     """Test valid call to HuggingFace that errors."""
     llm = HuggingFaceHub(model_kwargs={"max_new_tokens": -1})
tests/integration_tests/llms/test_huggingface_pipeline.py
@@ -27,6 +27,15 @@ def test_huggingface_pipeline_text2text_generation() -> None:
     assert isinstance(output, str)
 
 
+def test_huggingface_pipeline_summarization() -> None:
+    """Test valid call to HuggingFace summarization model."""
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="facebook/bart-large-cnn", task="summarization"
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading an HuggingFaceHub LLM."""
     llm = HuggingFacePipeline.from_model_id(
tests/integration_tests/llms/test_self_hosted_llm.py
@@ -43,6 +43,19 @@ def test_self_hosted_huggingface_pipeline_text2text_generation() -> None:
     assert isinstance(output, str)
 
 
+def test_self_hosted_huggingface_pipeline_summarization() -> None:
+    """Test valid call to self-hosted HuggingFace summarization model."""
+    gpu = get_remote_instance()
+    llm = SelfHostedHuggingFaceLLM(
+        model_id="facebook/bart-large-cnn",
+        task="summarization",
+        hardware=gpu,
+        model_reqs=model_reqs,
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def load_pipeline() -> Any:
     """Load pipeline for testing."""
     model_id = "gpt2"