[bug] pipeline generation gpt3 made generator

This commit is contained in:
Laurel Orr 2022-05-28 01:01:42 -07:00
parent 1f6d9250fe
commit 966fe6b5d4
6 changed files with 116 additions and 12 deletions

View File

@ -1,5 +1,6 @@
"""Huggingface model."""
import json
from functools import partial
from pathlib import Path
from typing import Any, Dict, List
@ -9,11 +10,56 @@ from transformers import (
GPT2LMHeadModel,
GPTJForCausalLM,
GPTNeoForCausalLM,
PreTrainedModel,
PreTrainedTokenizer,
pipeline,
)
from manifest.api.models.model import Model
class GPT2Pipeline:
"""Custom GPT3 Pipeline."""
def __init__(
self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, device: int
):
"""Initialize."""
self.model = model
self.tokenizer = tokenizer
self.device = device
self.model.to(self.device) # type: ignore
def __call__(self, text: str, **kwargs: Any) -> List[Dict[str, str]]:
"""Generate from text.
Args:
text: text to generate.
Returns:
generated text.
"""
encoded_prompt = self.tokenizer.encode(text, return_tensors="pt")
encoded_prompt = encoded_prompt.to(self.device)
output_sequences = self.model.generate( # type: ignore
encoded_prompt,
max_length=kwargs.get("max_length"),
temperature=kwargs.get("temperature"),
top_k=kwargs.get("top_k"),
top_p=kwargs.get("top_p"),
repetition_penalty=kwargs.get("repetition_penalty"),
do_sample=kwargs.get("do_sample"),
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
num_return_sequences=kwargs.get("num_return_sequences"),
)
generated_sequences = [
{"generated_text": self.tokenizer.decode(output_seq)}
for output_seq in output_sequences
]
return generated_sequences
MODEL_REGISTRY = {
"EleutherAI/gpt-j-6B": GPTJForCausalLM,
"EleutherAI/gpt-neo-125M": GPTNeoForCausalLM,
@ -25,13 +71,13 @@ MODEL_REGISTRY = {
}
MODEL_PIPELINE = {
"EleutherAI/gpt-j-6B": "text-generation",
"EleutherAI/gpt-neo-125M": "text-generation",
"EleutherAI/gpt-neo-1.3B": "text-generation",
"EleutherAI/gpt-neo-2.7B": "text-generation",
"gpt2": "text-generation",
"bigscience/T0pp": "text2text-generation",
"bigscience/T0_3B": "text2text-generation",
"EleutherAI/gpt-j-6B": GPT2Pipeline,
"EleutherAI/gpt-neo-125M": GPT2Pipeline,
"EleutherAI/gpt-neo-1.3B": GPT2Pipeline,
"EleutherAI/gpt-neo-2.7B": GPT2Pipeline,
"gpt2": GPT2Pipeline,
"bigscience/T0pp": partial(pipeline, "text2text-generation"),
"bigscience/T0_3B": partial(pipeline, "text2text-generation"),
}
@ -56,14 +102,14 @@ class HuggingFaceModel(Model):
model_name = config["_name_or_path"]
self.model_name = model_name
print("Model Name:", self.model_name, "Model Path:", self.model_path)
model = MODEL_REGISTRY[model_name].from_pretrained(
model = MODEL_REGISTRY[model_name].from_pretrained( # type: ignore
self.model_path, cache_dir=cache_dir
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
self.pipeline = pipeline(
MODEL_PIPELINE[model_name], model=model, tokenizer=tokenizer, device=device
self.pipeline = MODEL_PIPELINE[model_name](
model=model, tokenizer=tokenizer, device=device
)
self.returns_input = MODEL_PIPELINE[model_name] == "text-generation"
self.returns_input = "gpt" in model_name
def get_init_params(self) -> Dict:
"""Return init params to determine what model is being used."""
@ -93,6 +139,7 @@ class HuggingFaceModel(Model):
repetition_penalty=kwargs.get("repetition_penalty"),
top_k=kwargs.get("top_k"),
top_p=kwargs.get("top_p"),
do_sample=kwargs.get("do_sample"),
num_return_sequences=num_return,
)
if self.returns_input:

View File

@ -21,6 +21,8 @@ class SQLiteCache(Cache):
cache_args: cache arguments.
"""
self.cache_file = connection_str
if not self.cache_file:
self.cache_file = ".sqlite.cache"
self.cache = SqliteDict(self.cache_file, autocommit=True)
return

View File

@ -31,6 +31,7 @@ class HuggingFaceClient(Client):
self.top_k = client_args.pop("top_k", 50)
self.repetition_penalty = client_args.pop("repetition_penalty", 1.0)
self.n = client_args.pop("n", 1)
self.do_sample = client_args.pop("do_sample", True)
self.model_params = self.get_model_params()
def close(self) -> None:
@ -67,6 +68,7 @@ class HuggingFaceClient(Client):
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"top_p": kwargs.get("top_p", self.top_p),
"top_k": kwargs.get("top_k", self.top_k),
"do_sample": kwargs.get("do_sample", self.do_sample),
"repetition_penalty": kwargs.get(
"repetition_penalty", self.repetition_penalty
),

View File

@ -47,6 +47,7 @@ class OPTClient(Client):
"""
request_params = {
"prompt": query,
"engine": "opt",
"temperature": kwargs.get("temperature", self.temperature),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"top_p": kwargs.get("top_p", self.top_p),

50
poetry.lock generated
View File

@ -984,6 +984,30 @@ torch-speech = ["torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
torchhub = ["filelock", "huggingface-hub (>=0.1.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.0)", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "tqdm (>=4.27)"]
vision = ["pillow"]
[[package]]
name = "types-protobuf"
version = "3.19.21"
description = "Typing stubs for protobuf"
category = "main"
optional = false
python-versions = "*"
[[package]]
name = "types-python-dateutil"
version = "2.8.16"
description = "Typing stubs for python-dateutil"
category = "main"
optional = false
python-versions = "*"
[[package]]
name = "types-pyyaml"
version = "6.0.7"
description = "Typing stubs for PyYAML"
category = "main"
optional = false
python-versions = "*"
[[package]]
name = "types-redis"
version = "4.2.6"
@ -1003,6 +1027,14 @@ python-versions = "*"
[package.dependencies]
types-urllib3 = "<1.27"
[[package]]
name = "types-setuptools"
version = "57.4.17"
description = "Typing stubs for setuptools"
category = "main"
optional = false
python-versions = "*"
[[package]]
name = "types-urllib3"
version = "1.26.15"
@ -1084,7 +1116,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
content-hash = "9f1c7010530f8668850294c74a7dd60fdaad2086901a5bd0f4438b870666c1a7"
content-hash = "1942e48ea705836d5f0440e47a7461bd3d13e41fc67c546dd7467f9dda5e3dcf"
[metadata.files]
alabaster = [
@ -1714,6 +1746,18 @@ transformers = [
{file = "transformers-4.19.2-py3-none-any.whl", hash = "sha256:1416315b7c5ff1f56d3915f416b67aa254a9907fbb73ef7f7bffc9210446b5fa"},
{file = "transformers-4.19.2.tar.gz", hash = "sha256:e19a4ff07458eda143c738e5259caf48449fcf078a63d6b1bd1aa806543440a3"},
]
types-protobuf = [
{file = "types-protobuf-3.19.21.tar.gz", hash = "sha256:ecbff0b17b4196aecec055ae4bd25eb344aaf5d848ce38043a44d68a0d51ec3e"},
{file = "types_protobuf-3.19.21-py3-none-any.whl", hash = "sha256:6ac25d57f6d674a3f31c9657997aeebdd1196b3ead87e0b798b174243d719856"},
]
types-python-dateutil = [
{file = "types-python-dateutil-2.8.16.tar.gz", hash = "sha256:3aaac4c138eb6b8ecbc2550996ec25d6e45b5d32887d1e693d0842ee2fa659d2"},
{file = "types_python_dateutil-2.8.16-py3-none-any.whl", hash = "sha256:0e7286436d049d2732ba2f01552f84c52184b955552359156fc8835c10714f2a"},
]
types-pyyaml = [
{file = "types-PyYAML-6.0.7.tar.gz", hash = "sha256:59480cf44595d836aaae050f35e3c39f197f3a833679ef3978d97aa9f2fb7def"},
{file = "types_PyYAML-6.0.7-py3-none-any.whl", hash = "sha256:7b273a34f32af9910cf9405728c9d2ad3afc4be63e4048091a1a73d76681fe67"},
]
types-redis = [
{file = "types-redis-4.2.6.tar.gz", hash = "sha256:d6adc77185cf40b300816767a64c0ee9ee0b21dc174e8e5c23b7e83d43189cb8"},
{file = "types_redis-4.2.6-py3-none-any.whl", hash = "sha256:1136af954ade0be33b487f440c8cbcbee29f089a83e685484ec91f363c6c69fe"},
@ -1722,6 +1766,10 @@ types-requests = [
{file = "types-requests-2.27.29.tar.gz", hash = "sha256:fb453b3a76a48eca66381cea8004feaaea12835e838196f5c7ac87c75c5c19ef"},
{file = "types_requests-2.27.29-py3-none-any.whl", hash = "sha256:014f4f82db7b96c41feea9adaea30e68cd64c230eeab34b70c29bebb26ec74ac"},
]
types-setuptools = [
{file = "types-setuptools-57.4.17.tar.gz", hash = "sha256:9d556fcaf6808a1cead4aaa41e5c07a61f0152a875811e1239738eba4e0b7b16"},
{file = "types_setuptools-57.4.17-py3-none-any.whl", hash = "sha256:9c7cdaf0d55113e24ac17103bde2d434472abf1dbf444238e989fe4e798ffa26"},
]
types-urllib3 = [
{file = "types-urllib3-1.26.15.tar.gz", hash = "sha256:c89283541ef92e344b7f59f83ea9b5a295b16366ceee3f25ecfc5593c79f794e"},
{file = "types_urllib3-1.26.15-py3-none-any.whl", hash = "sha256:6011befa13f901fc934f59bb1fd6973be6f3acf4ebfce427593a27e7f492918f"},

View File

@ -28,6 +28,10 @@ requests = "^2.27.1"
tqdm = "^4.64.0"
types-redis = "^4.2.6"
types-requests = "^2.27.29"
types-PyYAML = "^6.0.7"
types-protobuf = "^3.19.21"
types-python-dateutil = "^2.8.16"
types-setuptools = "^57.4.17"
[tool.poetry.dev-dependencies]
black = "^22.3.0"