mirror of
https://github.com/HazyResearch/manifest
synced 2024-11-02 09:40:58 +00:00
[bug] pipeline generation gpt3 made generator
This commit is contained in:
parent
1f6d9250fe
commit
966fe6b5d4
@ -1,5 +1,6 @@
|
||||
"""Huggingface model."""
|
||||
import json
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
@ -9,11 +10,56 @@ from transformers import (
|
||||
GPT2LMHeadModel,
|
||||
GPTJForCausalLM,
|
||||
GPTNeoForCausalLM,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
pipeline,
|
||||
)
|
||||
|
||||
from manifest.api.models.model import Model
|
||||
|
||||
|
||||
class GPT2Pipeline:
    """Custom GPT2-style text generation pipeline.

    Wraps a model and its tokenizer behind a callable that takes a prompt
    string and returns a list of ``{"generated_text": ...}`` dictionaries.
    """

    def __init__(
        self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, device: int
    ):
        """Initialize.

        Args:
            model: model whose ``generate`` method produces the output tokens.
            tokenizer: tokenizer used to encode prompts and decode outputs.
            device: device ordinal. A negative value (or ``None``) selects
                the CPU; a non-negative value selects that CUDA device.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # The same integer is also passed to ``transformers.pipeline``, where
        # -1 means "run on CPU". A negative ordinal is not a valid torch
        # device, so translate it into an explicit device string before moving
        # the model or any tensors.
        if device is None or device < 0:
            self._torch_device = "cpu"
        else:
            self._torch_device = f"cuda:{device}"
        self.model.to(self._torch_device)  # type: ignore

    def __call__(self, text: str, **kwargs: Any) -> List[Dict[str, str]]:
        """Generate from text.

        Args:
            text: text to generate.
            **kwargs: generation options. The keys read are ``max_length``,
                ``temperature``, ``top_k``, ``top_p``, ``repetition_penalty``,
                ``do_sample`` and ``num_return_sequences``; a missing key is
                forwarded to ``generate`` as ``None``.

        Returns:
            generated text: one ``{"generated_text": str}`` dictionary per
            sequence returned by the model, each decoded from the full
            output sequence.
        """
        encoded_prompt = self.tokenizer.encode(text, return_tensors="pt")
        # Keep the prompt tensor on the same device as the model.
        encoded_prompt = encoded_prompt.to(self._torch_device)
        output_sequences = self.model.generate(  # type: ignore
            encoded_prompt,
            max_length=kwargs.get("max_length"),
            temperature=kwargs.get("temperature"),
            top_k=kwargs.get("top_k"),
            top_p=kwargs.get("top_p"),
            repetition_penalty=kwargs.get("repetition_penalty"),
            do_sample=kwargs.get("do_sample"),
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            num_return_sequences=kwargs.get("num_return_sequences"),
        )
        generated_sequences = [
            {"generated_text": self.tokenizer.decode(output_seq)}
            for output_seq in output_sequences
        ]
        return generated_sequences
|
||||
|
||||
|
||||
MODEL_REGISTRY = {
|
||||
"EleutherAI/gpt-j-6B": GPTJForCausalLM,
|
||||
"EleutherAI/gpt-neo-125M": GPTNeoForCausalLM,
|
||||
@ -25,13 +71,13 @@ MODEL_REGISTRY = {
|
||||
}
|
||||
|
||||
MODEL_PIPELINE = {
|
||||
"EleutherAI/gpt-j-6B": "text-generation",
|
||||
"EleutherAI/gpt-neo-125M": "text-generation",
|
||||
"EleutherAI/gpt-neo-1.3B": "text-generation",
|
||||
"EleutherAI/gpt-neo-2.7B": "text-generation",
|
||||
"gpt2": "text-generation",
|
||||
"bigscience/T0pp": "text2text-generation",
|
||||
"bigscience/T0_3B": "text2text-generation",
|
||||
"EleutherAI/gpt-j-6B": GPT2Pipeline,
|
||||
"EleutherAI/gpt-neo-125M": GPT2Pipeline,
|
||||
"EleutherAI/gpt-neo-1.3B": GPT2Pipeline,
|
||||
"EleutherAI/gpt-neo-2.7B": GPT2Pipeline,
|
||||
"gpt2": GPT2Pipeline,
|
||||
"bigscience/T0pp": partial(pipeline, "text2text-generation"),
|
||||
"bigscience/T0_3B": partial(pipeline, "text2text-generation"),
|
||||
}
|
||||
|
||||
|
||||
@ -56,14 +102,14 @@ class HuggingFaceModel(Model):
|
||||
model_name = config["_name_or_path"]
|
||||
self.model_name = model_name
|
||||
print("Model Name:", self.model_name, "Model Path:", self.model_path)
|
||||
model = MODEL_REGISTRY[model_name].from_pretrained(
|
||||
model = MODEL_REGISTRY[model_name].from_pretrained( # type: ignore
|
||||
self.model_path, cache_dir=cache_dir
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.pipeline = pipeline(
|
||||
MODEL_PIPELINE[model_name], model=model, tokenizer=tokenizer, device=device
|
||||
self.pipeline = MODEL_PIPELINE[model_name](
|
||||
model=model, tokenizer=tokenizer, device=device
|
||||
)
|
||||
self.returns_input = MODEL_PIPELINE[model_name] == "text-generation"
|
||||
self.returns_input = "gpt" in model_name
|
||||
|
||||
def get_init_params(self) -> Dict:
|
||||
"""Return init params to determine what model is being used."""
|
||||
@ -93,6 +139,7 @@ class HuggingFaceModel(Model):
|
||||
repetition_penalty=kwargs.get("repetition_penalty"),
|
||||
top_k=kwargs.get("top_k"),
|
||||
top_p=kwargs.get("top_p"),
|
||||
do_sample=kwargs.get("do_sample"),
|
||||
num_return_sequences=num_return,
|
||||
)
|
||||
if self.returns_input:
|
||||
|
@ -21,6 +21,8 @@ class SQLiteCache(Cache):
|
||||
cache_args: cache arguments.
|
||||
"""
|
||||
self.cache_file = connection_str
|
||||
if not self.cache_file:
|
||||
self.cache_file = ".sqlite.cache"
|
||||
self.cache = SqliteDict(self.cache_file, autocommit=True)
|
||||
return
|
||||
|
||||
|
@ -31,6 +31,7 @@ class HuggingFaceClient(Client):
|
||||
self.top_k = client_args.pop("top_k", 50)
|
||||
self.repetition_penalty = client_args.pop("repetition_penalty", 1.0)
|
||||
self.n = client_args.pop("n", 1)
|
||||
self.do_sample = client_args.pop("do_sample", True)
|
||||
self.model_params = self.get_model_params()
|
||||
|
||||
def close(self) -> None:
|
||||
@ -67,6 +68,7 @@ class HuggingFaceClient(Client):
|
||||
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
"top_p": kwargs.get("top_p", self.top_p),
|
||||
"top_k": kwargs.get("top_k", self.top_k),
|
||||
"do_sample": kwargs.get("do_sample", self.do_sample),
|
||||
"repetition_penalty": kwargs.get(
|
||||
"repetition_penalty", self.repetition_penalty
|
||||
),
|
||||
|
@ -47,6 +47,7 @@ class OPTClient(Client):
|
||||
"""
|
||||
request_params = {
|
||||
"prompt": query,
|
||||
"engine": "opt",
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
"top_p": kwargs.get("top_p", self.top_p),
|
||||
|
50
poetry.lock
generated
50
poetry.lock
generated
@ -984,6 +984,30 @@ torch-speech = ["torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
|
||||
torchhub = ["filelock", "huggingface-hub (>=0.1.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.0)", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "tqdm (>=4.27)"]
|
||||
vision = ["pillow"]
|
||||
|
||||
[[package]]
|
||||
name = "types-protobuf"
|
||||
version = "3.19.21"
|
||||
description = "Typing stubs for protobuf"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "types-python-dateutil"
|
||||
version = "2.8.16"
|
||||
description = "Typing stubs for python-dateutil"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "types-pyyaml"
|
||||
version = "6.0.7"
|
||||
description = "Typing stubs for PyYAML"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "types-redis"
|
||||
version = "4.2.6"
|
||||
@ -1003,6 +1027,14 @@ python-versions = "*"
|
||||
[package.dependencies]
|
||||
types-urllib3 = "<1.27"
|
||||
|
||||
[[package]]
|
||||
name = "types-setuptools"
|
||||
version = "57.4.17"
|
||||
description = "Typing stubs for setuptools"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "types-urllib3"
|
||||
version = "1.26.15"
|
||||
@ -1084,7 +1116,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.8"
|
||||
content-hash = "9f1c7010530f8668850294c74a7dd60fdaad2086901a5bd0f4438b870666c1a7"
|
||||
content-hash = "1942e48ea705836d5f0440e47a7461bd3d13e41fc67c546dd7467f9dda5e3dcf"
|
||||
|
||||
[metadata.files]
|
||||
alabaster = [
|
||||
@ -1714,6 +1746,18 @@ transformers = [
|
||||
{file = "transformers-4.19.2-py3-none-any.whl", hash = "sha256:1416315b7c5ff1f56d3915f416b67aa254a9907fbb73ef7f7bffc9210446b5fa"},
|
||||
{file = "transformers-4.19.2.tar.gz", hash = "sha256:e19a4ff07458eda143c738e5259caf48449fcf078a63d6b1bd1aa806543440a3"},
|
||||
]
|
||||
types-protobuf = [
|
||||
{file = "types-protobuf-3.19.21.tar.gz", hash = "sha256:ecbff0b17b4196aecec055ae4bd25eb344aaf5d848ce38043a44d68a0d51ec3e"},
|
||||
{file = "types_protobuf-3.19.21-py3-none-any.whl", hash = "sha256:6ac25d57f6d674a3f31c9657997aeebdd1196b3ead87e0b798b174243d719856"},
|
||||
]
|
||||
types-python-dateutil = [
|
||||
{file = "types-python-dateutil-2.8.16.tar.gz", hash = "sha256:3aaac4c138eb6b8ecbc2550996ec25d6e45b5d32887d1e693d0842ee2fa659d2"},
|
||||
{file = "types_python_dateutil-2.8.16-py3-none-any.whl", hash = "sha256:0e7286436d049d2732ba2f01552f84c52184b955552359156fc8835c10714f2a"},
|
||||
]
|
||||
types-pyyaml = [
|
||||
{file = "types-PyYAML-6.0.7.tar.gz", hash = "sha256:59480cf44595d836aaae050f35e3c39f197f3a833679ef3978d97aa9f2fb7def"},
|
||||
{file = "types_PyYAML-6.0.7-py3-none-any.whl", hash = "sha256:7b273a34f32af9910cf9405728c9d2ad3afc4be63e4048091a1a73d76681fe67"},
|
||||
]
|
||||
types-redis = [
|
||||
{file = "types-redis-4.2.6.tar.gz", hash = "sha256:d6adc77185cf40b300816767a64c0ee9ee0b21dc174e8e5c23b7e83d43189cb8"},
|
||||
{file = "types_redis-4.2.6-py3-none-any.whl", hash = "sha256:1136af954ade0be33b487f440c8cbcbee29f089a83e685484ec91f363c6c69fe"},
|
||||
@ -1722,6 +1766,10 @@ types-requests = [
|
||||
{file = "types-requests-2.27.29.tar.gz", hash = "sha256:fb453b3a76a48eca66381cea8004feaaea12835e838196f5c7ac87c75c5c19ef"},
|
||||
{file = "types_requests-2.27.29-py3-none-any.whl", hash = "sha256:014f4f82db7b96c41feea9adaea30e68cd64c230eeab34b70c29bebb26ec74ac"},
|
||||
]
|
||||
types-setuptools = [
|
||||
{file = "types-setuptools-57.4.17.tar.gz", hash = "sha256:9d556fcaf6808a1cead4aaa41e5c07a61f0152a875811e1239738eba4e0b7b16"},
|
||||
{file = "types_setuptools-57.4.17-py3-none-any.whl", hash = "sha256:9c7cdaf0d55113e24ac17103bde2d434472abf1dbf444238e989fe4e798ffa26"},
|
||||
]
|
||||
types-urllib3 = [
|
||||
{file = "types-urllib3-1.26.15.tar.gz", hash = "sha256:c89283541ef92e344b7f59f83ea9b5a295b16366ceee3f25ecfc5593c79f794e"},
|
||||
{file = "types_urllib3-1.26.15-py3-none-any.whl", hash = "sha256:6011befa13f901fc934f59bb1fd6973be6f3acf4ebfce427593a27e7f492918f"},
|
||||
|
@ -28,6 +28,10 @@ requests = "^2.27.1"
|
||||
tqdm = "^4.64.0"
|
||||
types-redis = "^4.2.6"
|
||||
types-requests = "^2.27.29"
|
||||
types-PyYAML = "^6.0.7"
|
||||
types-protobuf = "^3.19.21"
|
||||
types-python-dateutil = "^2.8.16"
|
||||
types-setuptools = "^57.4.17"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
black = "^22.3.0"
|
||||
|
Loading…
Reference in New Issue
Block a user