add text2text generation (#93)

fixes issue #90
Harrison Chase authored 2 years ago, committed by GitHub
parent e48e562ea5
commit b9f61390e9

@@ -10,21 +10,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Justin Beiber was born in New York City on July 1, 1967. He was the son of the late John Beiber and his wife, Mary.\n",
"\n",
"Justin was raised in a small town in the Bronx, New York. He attended the University of New York at Buffalo, where he majored in English.\n",
"\n",
"Justin was a member of the New York Giants from 1967 to 1969. He was a member of the New York Giants from 1969 to 1971.\n",
"\n",
"Justin was a member of the New York Giants from 1971 to 1972. He was a member of the New York Giants from 1972 to 1974.\n",
"\n",
"Justin was a member of the New York Giants from 1974 to 1975. He was a member of the New York Giants from 1975 to 1977.\n",
"\n",
"Justin was a member of the New York Giants from 1977 to 1978. He was a member of the New York Giants from 1978 to 1979.\n",
"\n",
"Justin was a member of the New York Giants from 1979 to\n"
"The Seattle Seahawks won the Super Bowl in 2010. Justin Beiber was born in 2010. The\n"
]
}
],
@@ -35,7 +21,7 @@
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"prompt = Prompt(template=template, input_variables=[\"question\"])\n",
"llm_chain = LLMChain(prompt=prompt, llm=HuggingFaceHub(repo_id=\"gpt2\", temperature=1e-10))\n",
"llm_chain = LLMChain(prompt=prompt, llm=HuggingFaceHub(repo_id=\"google/flan-t5-xl\", model_kwargs={\"temperature\":1e-10}))\n",
"\n",
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
"\n",
@@ -67,7 +53,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.8.7"
}
},
"nbformat": 4,

@@ -1,6 +1,6 @@
"""Wrapper around HuggingFace APIs."""
import os
- from typing import Any, Dict, List, Mapping, Optional
+ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
@@ -8,6 +8,7 @@ from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
DEFAULT_REPO_ID = "gpt2"
+ VALID_TASKS = ("text2text-generation", "text-generation")
class HuggingFaceHub(BaseModel, LLM):
@@ -29,14 +30,10 @@ class HuggingFaceHub(BaseModel, LLM):
client: Any #: :meta private:
repo_id: str = DEFAULT_REPO_ID
"""Model name to use."""
- temperature: float = 0.7
- """What sampling temperature to use."""
- max_new_tokens: int = 200
- """The maximum number of tokens to generate in the completion."""
- top_p: int = 1
- """Total probability mass of tokens to consider at each step."""
- num_return_sequences: int = 1
- """How many completions to generate for each prompt."""
+ task: Optional[str] = None
+ """Task to call the model with. Should be a task that returns `generated_text`."""
+ model_kwargs: Optional[dict] = None
+ """Key word arguments to pass to the model."""
huggingfacehub_api_token: Optional[str] = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
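
The hunk above collapses the four dedicated sampling fields (temperature, max_new_tokens, top_p, num_return_sequences) into a single optional model_kwargs dict that is forwarded untouched to the inference call. From a caller's perspective the migration looks like this sketch:

from langchain.llms.huggingface_hub import HuggingFaceHub

# Before this commit, sampling knobs were typed fields on the wrapper:
#     llm = HuggingFaceHub(temperature=0.7, max_new_tokens=200)
# After it, the same knobs pass through model_kwargs verbatim:
llm = HuggingFaceHub(
    repo_id="gpt2",
    model_kwargs={"temperature": 0.7, "max_new_tokens": 200},
)

The upside is that any parameter the Hub accepts for the chosen task can now be forwarded without the wrapper having to enumerate it.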
@@ -49,7 +46,6 @@ class HuggingFaceHub(BaseModel, LLM):
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
huggingfacehub_api_token = values.get("huggingfacehub_api_token")
-
if huggingfacehub_api_token is None or huggingfacehub_api_token == "":
raise ValueError(
"Did not find HuggingFace API token, please add an environment variable"
@@ -60,11 +56,17 @@ class HuggingFaceHub(BaseModel, LLM):
from huggingface_hub.inference_api import InferenceApi
repo_id = values.get("repo_id", DEFAULT_REPO_ID)
values["client"] = InferenceApi(
client = InferenceApi(
repo_id=repo_id,
token=huggingfacehub_api_token,
task="text-generation",
task=values.get("task"),
)
+ if client.task not in VALID_TASKS:
+ raise ValueError(
+ f"Got invalid task {client.task}, "
+ f"currently only {VALID_TASKS} are supported"
+ )
+ values["client"] = client
except ImportError:
raise ValueError(
"Could not import huggingface_hub python package. "
@@ -72,16 +74,6 @@ class HuggingFaceHub(BaseModel, LLM):
)
return values
- @property
- def _default_params(self) -> Mapping[str, Any]:
- """Get the default parameters for calling HuggingFace Hub API."""
- return {
- "temperature": self.temperature,
- "max_new_tokens": self.max_new_tokens,
- "top_p": self.top_p,
- "num_return_sequences": self.num_return_sequences,
- }
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to HuggingFace Hub's inference endpoint.
@@ -97,10 +89,19 @@ class HuggingFaceHub(BaseModel, LLM):
response = hf("Tell me a joke.")
"""
- response = self.client(inputs=prompt, params=self._default_params)
+ response = self.client(inputs=prompt, params=self.model_kwargs)
if "error" in response:
raise ValueError(f"Error raised by inference API: {response['error']}")
- text = response[0]["generated_text"][len(prompt) :]
+ if self.client.task == "text-generation":
+ # Text generation return includes the starter text.
+ text = response[0]["generated_text"][len(prompt) :]
+ elif self.client.task == "text2text-generation":
+ text = response[0]["generated_text"]
+ else:
+ raise ValueError(
+ f"Got invalid task {self.client.task}, "
+ f"currently only {VALID_TASKS} are supported"
+ )
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
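
The __call__ change is the heart of the fix: text-generation responses echo the prompt at the start of generated_text, while text2text-generation models such as flan-t5 return only the completion. A standalone illustration of the same branch logic follows; the function name is illustrative, not part of the library.

from typing import List

VALID_TASKS = ("text2text-generation", "text-generation")

def extract_text(task: str, prompt: str, response: List[dict]) -> str:
    """Mirror the task-specific response handling added in this commit."""
    if task == "text-generation":
        # The API echoes the prompt, so strip it off the front.
        return response[0]["generated_text"][len(prompt):]
    if task == "text2text-generation":
        # text2text models return only the newly generated completion.
        return response[0]["generated_text"]
    raise ValueError(
        f"Got invalid task {task}, currently only {VALID_TASKS} are supported"
    )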

@@ -5,15 +5,22 @@ import pytest
from langchain.llms.huggingface_hub import HuggingFaceHub
- def test_huggingface_call() -> None:
- """Test valid call to HuggingFace."""
- llm = HuggingFaceHub(max_new_tokens=10)
+ def test_huggingface_text_generation() -> None:
+ """Test valid call to HuggingFace text generation model."""
+ llm = HuggingFaceHub(repo_id="gpt2", model_kwargs={"max_new_tokens": 10})
output = llm("Say foo:")
assert isinstance(output, str)
+ def test_huggingface_text2text_generation() -> None:
+ """Test valid call to HuggingFace text2text model."""
+ llm = HuggingFaceHub(repo_id="google/flan-t5-xl")
+ output = llm("The capital of New York is")
+ assert output == "Albany"
def test_huggingface_call_error() -> None:
"""Test valid call to HuggingFace that errors."""
- llm = HuggingFaceHub(max_new_tokens=-1)
+ llm = HuggingFaceHub(model_kwargs={"max_new_tokens": -1})
with pytest.raises(ValueError):
llm("Say foo:")
