Merge pull request #3 from HazyResearch/laurel/opt

Laurel/opt
This commit is contained in:
Laurel Orr 2022-05-26 23:28:17 -07:00 committed by GitHub
commit 1f6d9250fe
18 changed files with 272 additions and 87 deletions

View File

@ -1,6 +1,7 @@
dev:
poetry install
poetry run pre-commit install
poetry run mypy --install-types
test: dev check
poetry run pytest tests

View File

@ -3,19 +3,19 @@ Prompt programming with FMs.
# Install
Download the code:
```
```bash
git clone git@github.com:HazyResearch/manifest.git
cd manifest
```
Install:
```
```bash
pip install poetry
poetry install
poetry run pre-commit install
```
or
```
```bash
pip install poetry
make dev
```
@ -28,14 +28,14 @@ Manifest is meant to be a very light weight package to help with prompt iteratio
## Prompts
A Manifest prompt is a function that accepts a single input to generate a string prompt to send to a model.
```
```python
from manifest import Prompt
prompt = Prompt(lambda x: "Hello, my name is {x}")
print(promt("Laurel"))
print(prompt("Laurel"))
>>> "Hello, my name is Laurel"
```
We also let you use static strings
```
```python
prompt = Prompt("Hello, my name is static")
print(prompt())
>>> "Hello, my name is static"
@ -46,13 +46,13 @@ print(prompt())
## Sessions
Each Manifest run is a session that connects to a model endpoint and backend database to record prompt queries. To start a Manifest session for OpenAI, make sure you run
```
```bash
export OPENAI_API_KEY=<OPENAIKEY>
```
so we can access OpenAI.
Then, in a notebook, run:
```
```python
from manifest import Manifest
manifest = Manifest(
@ -64,7 +64,7 @@ manifest = Manifest(
This will start a session with OpenAI and save all results to a local file called `sqlite.cache`.
We also support a Redis backend. If you have a Redis database running on port 6379, run
```
```python
manifest = Manifest(
client_name = "openai",
cache_name = "redis",
@ -77,18 +77,18 @@ We will explain [below](#huggingface-models) how to use Manifest for a locally h
Once you have a session open, you can write and develop prompts.
```
```python
prompt = Prompt(lambda x: "Hello, my name is {x}")
result = manifest.run(prompt, "Laurel")
```
You can also run over multiple examples.
```
```python
results = manifest.batch_run(prompt, ["Laurel", "Avanika"])
```
If something doesn't go right, you can also ask to get a raw manifest Response.
```
```python
result_object = manifest.batch_run(prompt, ["Laurel", "Avanika"], return_response=True)
print(result_object.get_request())
print(result_object.is_cached())
@ -96,24 +96,24 @@ print(result_object.get_response())
```
By default, we do not truncate results based on a stop token. You can change this by either passing a new stop token to a Manifest session or to a `run` or `batch_run`. If you set the stop token to `""`, we will not truncate the model output.
```
```python
result = manifest.run(prompt, "Laurel", stop_token="and")
```
If you want to change default parameters to a model, we pass those as `kwargs` to the client.
```
```python
result = manifest.run(prompt, "Laurel", max_tokens=50)
```
# Huggingface Models
To use a HuggingFace generative model, in `manifest/api` we have a Falsk application that hosts the models for you.
In a separate terminal or Tmux/Screen session, run
```
```python
python3 manifest/api/app.py --model_type huggingface --model_name EleutherAI/gpt-j-6B --device 0
```
You will see the Flask session start and output a URL `http://127.0.0.1:5000`. Pass this in to Manifest. If you want to use a different port, set the `FLASK_PORT` environment variable.
```
```python
manifest = Manifest(
client_name = "huggingface",
client_connection = "http://127.0.0.1:5000",
@ -122,11 +122,13 @@ manifest = Manifest(
)
```
If you have a custom model you trained, pass the model path to `--model_name`.
**Auto deployment coming soon**
# Development
Before submitting a PR, run
```
```bash
export REDIS_PORT="6380" # or whatever PORT local redis is running for those tests
cd <REDIS_PATH>
docker run -d -p 127.0.0.1:${REDIS_PORT}:6380 -v `pwd`:`pwd` -w `pwd` --name manifest_redis_test redis
@ -134,12 +136,12 @@ make test
```
To use our development Redis database, email [Laurel](lorr1@cs.stanford.edu). If you have access to our GCP account, in a separate terminal, run
```
```bash
gcloud compute ssh "manifest-connect" --zone "europe-west4-a" --project "hai-gcp-head-models" -- -N -L 6379:10.152.93.107:6379
```
Then if you issue
```
```bash
redis-cli ping
```
You should see a `PONG` response from our database.

View File

@ -81,6 +81,12 @@ def completions() -> Dict:
return OpenAIResponse(results).__dict__()
@app.route("/params", methods=["POST"])
def params() -> Dict:
"""Get model params."""
return model.get_init_params()
@app.route("/")
def index() -> str:
"""Get index completion."""

View File

@ -1,7 +1,10 @@
"""Huggingface model."""
from typing import Any, List
import json
from pathlib import Path
from typing import Any, Dict, List
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
GPT2LMHeadModel,
GPTJForCausalLM,
@ -17,6 +20,18 @@ MODEL_REGISTRY = {
"EleutherAI/gpt-neo-1.3B": GPTNeoForCausalLM,
"EleutherAI/gpt-neo-2.7B": GPTNeoForCausalLM,
"gpt2": GPT2LMHeadModel,
"bigscience/T0pp": AutoModelForSeq2SeqLM,
"bigscience/T0_3B": AutoModelForSeq2SeqLM,
}
MODEL_PIPELINE = {
"EleutherAI/gpt-j-6B": "text-generation",
"EleutherAI/gpt-neo-125M": "text-generation",
"EleutherAI/gpt-neo-1.3B": "text-generation",
"EleutherAI/gpt-neo-2.7B": "text-generation",
"gpt2": "text-generation",
"bigscience/T0pp": "text2text-generation",
"bigscience/T0_3B": "text2text-generation",
}
@ -32,13 +47,27 @@ class HuggingFaceModel(Model):
Args:
model_name: model name string.
"""
# Check if providing path
self.model_path = model_name
if Path(self.model_path).exists() and Path(self.model_path).is_dir():
# Try to find config
if (Path(self.model_path) / "config.json").exists():
config = json.load(open(Path(self.model_path) / "config.json"))
model_name = config["_name_or_path"]
self.model_name = model_name
print("Model Name:", self.model_name, "Model Path:", self.model_path)
model = MODEL_REGISTRY[model_name].from_pretrained(
model_name, cache_dir=cache_dir
self.model_path, cache_dir=cache_dir
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
self.pipeline = pipeline(
"text-generation", model=model, tokenizer=tokenizer, device=device
MODEL_PIPELINE[model_name], model=model, tokenizer=tokenizer, device=device
)
self.returns_input = MODEL_PIPELINE[model_name] == "text-generation"
def get_init_params(self) -> Dict:
"""Return init params to determine what model is being used."""
return {"model_name": self.model_name, "model_path": self.model_path}
def generate(self, prompt: str, **kwargs: Any) -> List[str]:
"""
@ -66,14 +95,12 @@ class HuggingFaceModel(Model):
top_p=kwargs.get("top_p"),
num_return_sequences=num_return,
)
# Removes tokens removed from tokenization
decoded_prompt = self.pipeline.tokenizer.decode(
encoded_prompt, clean_up_tokenization_spaces=True
)
if num_return == 1:
final_results.append(result[0]["generated_text"][len(decoded_prompt) :])
if self.returns_input:
start_idx = len(prompt)
else:
final_results.append(
[r["generated_text"][len(decoded_prompt) :] for r in result]
)
start_idx = 0
if num_return == 1:
final_results.append(result[0]["generated_text"][start_idx:])
else:
final_results.append([r["generated_text"][start_idx:] for r in result])
return final_results

View File

@ -1,6 +1,6 @@
"""Model class."""
from abc import ABC, abstractmethod
from typing import Any, List
from typing import Any, Dict, List
class Model(ABC):
@ -18,6 +18,11 @@ class Model(ABC):
"""
raise NotImplementedError()
@abstractmethod
def get_init_params(self) -> Dict:
"""Return init params to determine what model is being used."""
raise NotImplementedError()
@abstractmethod
def generate(self, prompt: str, **kwargs: Any) -> List[str]:
"""

View File

@ -61,19 +61,20 @@ def key_to_response(key: str) -> Dict:
class Cache(ABC):
"""A cache for request/response pairs."""
def __init__(self, connection_str: str, **kwargs: Any):
def __init__(self, connection_str: str, cache_args: Dict[str, Any] = {}):
"""
Initialize client.
kwargs are passed to client as default parameters.
cache_args are passed to client as default parameters.
For clients like OpenAI that do not require a connection,
the connection_str can be None.
Args:
connection_str: connection string for client.
cache_args: cache arguments.
"""
self.connect(connection_str, **kwargs)
self.connect(connection_str, cache_args)
@abstractmethod
def close(self) -> None:
@ -81,7 +82,7 @@ class Cache(ABC):
raise NotImplementedError()
@abstractmethod
def connect(self, connection_str: str, **kwargs: Any) -> None:
def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
"""
Connect to client.

View File

@ -1,5 +1,5 @@
"""Redis cache."""
from typing import Any, Union
from typing import Any, Dict, Union
import redis
@ -9,12 +9,13 @@ from manifest.caches import Cache
class RedisCache(Cache):
"""A Redis cache for request/response pairs."""
def connect(self, connection_str: str, **kwargs: Any) -> None:
def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
"""
Connect to client.
Args:
connection_str: connection string.
cache_args: cache arguments.
"""
host, port = connection_str.split(":")
self.redis = redis.Redis(host=host, port=int(port), db=0)

View File

@ -1,6 +1,6 @@
"""SQLite cache."""
import logging
from typing import Any, Union
from typing import Any, Dict, Union
from sqlitedict import SqliteDict
@ -12,12 +12,13 @@ logging.getLogger("sqlitedict").setLevel(logging.WARNING)
class SQLiteCache(Cache):
"""A SQLite cache for request/response pairs."""
def connect(self, connection_str: str, **kwargs: Any) -> None:
def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
"""
Connect to client.
Args:
connection_str: connection string.
cache_args: cache arguments.
"""
self.cache_file = connection_str
self.cache = SqliteDict(self.cache_file, autocommit=True)

View File

@ -6,7 +6,9 @@ from typing import Any, Callable, Dict, Optional, Tuple
class Client(ABC):
"""Client class."""
def __init__(self, connection_str: Optional[str] = None, **kwargs: Any):
def __init__(
self, connection_str: Optional[str] = None, client_args: Dict[str, Any] = {}
):
"""
Initialize client.
@ -17,8 +19,9 @@ class Client(ABC):
Args:
connection_str: connection string for client.
client_args: client arguments.
"""
self.connect(connection_str, **kwargs)
self.connect(connection_str, client_args)
@abstractmethod
def close(self) -> None:
@ -26,7 +29,7 @@ class Client(ABC):
raise NotImplementedError()
@abstractmethod
def connect(self, connection_str: str, **kwargs: Any) -> None:
def connect(self, connection_str: str, client_args: Dict[str, Any]) -> None:
"""
Connect to client.

View File

@ -13,15 +13,18 @@ class DummyClient(Client):
def connect(
self,
connection_str: Optional[str] = None,
num_results: Optional[int] = 1,
**kwargs: Any,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to dummpy server.
This is a dummy client that returns identity responses. Used for testing.
Args:
connection_str: connection string.
client_args: client arguments.
"""
self.num_results = num_results
self.num_results = client_args.pop("num_results", 1)
def close(self) -> None:
"""Close the client."""

View File

@ -15,27 +15,41 @@ class HuggingFaceClient(Client):
def connect(
self,
connection_str: Optional[str] = None,
temperature: Optional[float] = 1.0,
max_tokens: Optional[int] = 10,
top_p: Optional[float] = 1.0,
top_k: Optional[int] = 0,
repetition_penalty: Optional[float] = 1.0,
n: Optional[int] = 1,
**kwargs: Any,
client_args: Dict[str, Any] = {},
) -> None:
"""Connect to the HuggingFace url."""
"""
Connect to the HuggingFace url.
Arsg:
connection_str: connection string.
client_args: client arguments.
"""
self.host = connection_str.rstrip("/")
self.temperature = temperature
self.max_tokens = max_tokens
self.top_p = top_p
self.top_k = top_k
self.repetition_penalty = repetition_penalty
self.n = n
self.temperature = client_args.pop("temperature", 0.00001)
self.max_tokens = client_args.pop("max_tokens", 10)
self.top_p = client_args.pop("top_p", 1.0)
self.top_k = client_args.pop("top_k", 50)
self.repetition_penalty = client_args.pop("repetition_penalty", 1.0)
self.n = client_args.pop("n", 1)
self.model_params = self.get_model_params()
def close(self) -> None:
"""Close the client."""
pass
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
res = requests.post(self.host + "/params")
return res.json()
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@ -58,6 +72,7 @@ class HuggingFaceClient(Client):
),
"n": kwargs.get("n", self.n),
}
request_params.update(self.model_params)
def _run_completion() -> Dict:
post_str = self.host + "/completions"

View File

@ -24,35 +24,32 @@ class OpenAIClient(Client):
def connect(
self,
connection_str: Optional[str] = None,
engine: Optional[str] = "text-ada-001",
temperature: Optional[float] = 0.0,
max_tokens: Optional[int] = 10,
top_p: Optional[float] = 1.0,
frequency_penalty: Optional[int] = 0,
presence_penalty: Optional[int] = 0,
n: Optional[int] = 1,
**kwargs: Any,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the OpenAI server.
connection_str is passed as default OPENAI_API_KEY if variable not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
openai.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
if openai.api_key is None:
raise ValueError(
"OpenAI API key not set. Set OPENAI_API_KEY environment ",
"svariable or pass through `connection_str`.",
"OpenAI API key not set. Set OPENAI_API_KEY environment "
"variable or pass through `connection_str`."
)
self.engine = engine
self.engine = client_args.pop("engine", "text-davinci-002")
if self.engine not in OPENAI_ENGINES:
raise ValueError(f"Invalid engine {self.engine}. Must be {OPENAI_ENGINES}.")
self.temperature = temperature
self.max_tokens = max_tokens
self.top_p = top_p
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.n = n
self.temperature = client_args.pop("temperature", 0.0)
self.max_tokens = client_args.pop("max_tokens", 10)
self.top_p = client_args.pop("top_p", 1.0)
self.frequency_penalty = client_args.pop("frequency_penalty", 0.0)
self.presence_penalty = client_args.pop("presence_penalty", 0.0)
self.n = client_args.pop("n", 1)
def close(self) -> None:
"""Close the client."""

61
manifest/clients/opt.py Normal file
View File

@ -0,0 +1,61 @@
"""OpenAI client."""
import logging
from typing import Any, Callable, Dict, Optional, Tuple
import requests
from manifest.clients.client import Client
logger = logging.getLogger(__name__)
class OPTClient(Client):
"""OPT client."""
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the OPT url.
Arsg:
connection_str: connection string.
client_args: client arguments.
"""
self.host = connection_str.rstrip("/")
self.temperature = client_args.pop("temperature", 1.0)
self.max_tokens = client_args.pop("max_tokens", 10)
self.top_p = client_args.pop("top_p", 0)
self.n = client_args.pop("n", 1)
def close(self) -> None:
"""Close the client."""
pass
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
Args:
query: query string.
Returns:
request function that takes no input.
request parameters as dict.
"""
request_params = {
"prompt": query,
"temperature": kwargs.get("temperature", self.temperature),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"top_p": kwargs.get("top_p", self.top_p),
"n": kwargs.get("n", self.n),
}
def _run_completion() -> Dict:
post_str = self.host + "/completions"
res = requests.post(post_str, json=request_params)
return res.json()
return _run_completion, request_params

View File

@ -2,21 +2,24 @@
import logging
from typing import Any, Iterable, List, Optional, Union
from manifest.response import Response
logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
from tqdm.auto import tqdm
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.dummy import DummyClient
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.openai import OpenAIClient
from manifest.clients.opt import OPTClient
from manifest.prompt import Prompt
from manifest.response import Response
logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
CLIENT_CONSTRUCTORS = {
"openai": OpenAIClient,
"huggingface": HuggingFaceClient,
"opt": OPTClient,
"dummy": DummyClient,
}
@ -62,10 +65,14 @@ class Manifest:
f"Choices are {list(CACHE_CONSTRUCTORS.keys())}"
)
self.client_name = client_name
# Must pass kwargs as dict to client "pop" methods removed used arguments
self.client = CLIENT_CONSTRUCTORS[client_name]( # type: ignore
client_connection, **kwargs
client_connection, client_args=kwargs
)
self.cache = CACHE_CONSTRUCTORS[cache_name](cache_connection, **kwargs)
self.cache = CACHE_CONSTRUCTORS[cache_name](cache_connection, cache_args=kwargs)
if len(kwargs) > 0:
raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")
self.stop_token = stop_token
def close(self) -> None:
@ -119,6 +126,7 @@ class Manifest:
overwrite_cache: bool = False,
stop_token: Optional[str] = None,
return_response: bool = False,
verbose: bool = False,
**kwargs: Any,
) -> Iterable[Union[str, List[str], Response]]:
"""
@ -141,7 +149,7 @@ class Manifest:
self.run(
prompt, inp, overwrite_cache, stop_token, return_response, **kwargs
)
for inp in input
for inp in tqdm(input, desc="Running batch", disable=not verbose)
]
def save_prompt(self, name: str, prompt: Prompt) -> None:

41
poetry.lock generated
View File

@ -984,6 +984,33 @@ torch-speech = ["torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
torchhub = ["filelock", "huggingface-hub (>=0.1.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.0)", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "tqdm (>=4.27)"]
vision = ["pillow"]
[[package]]
name = "types-redis"
version = "4.2.6"
description = "Typing stubs for redis"
category = "main"
optional = false
python-versions = "*"
[[package]]
name = "types-requests"
version = "2.27.29"
description = "Typing stubs for requests"
category = "main"
optional = false
python-versions = "*"
[package.dependencies]
types-urllib3 = "<1.27"
[[package]]
name = "types-urllib3"
version = "1.26.15"
description = "Typing stubs for urllib3"
category = "main"
optional = false
python-versions = "*"
[[package]]
name = "typing-extensions"
version = "4.2.0"
@ -1057,7 +1084,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
content-hash = "b0f92eed3a0f80cc9976f68098b772a19d19b4742426299bd63cb1c636d8a84a"
content-hash = "9f1c7010530f8668850294c74a7dd60fdaad2086901a5bd0f4438b870666c1a7"
[metadata.files]
alabaster = [
@ -1687,6 +1714,18 @@ transformers = [
{file = "transformers-4.19.2-py3-none-any.whl", hash = "sha256:1416315b7c5ff1f56d3915f416b67aa254a9907fbb73ef7f7bffc9210446b5fa"},
{file = "transformers-4.19.2.tar.gz", hash = "sha256:e19a4ff07458eda143c738e5259caf48449fcf078a63d6b1bd1aa806543440a3"},
]
types-redis = [
{file = "types-redis-4.2.6.tar.gz", hash = "sha256:d6adc77185cf40b300816767a64c0ee9ee0b21dc174e8e5c23b7e83d43189cb8"},
{file = "types_redis-4.2.6-py3-none-any.whl", hash = "sha256:1136af954ade0be33b487f440c8cbcbee29f089a83e685484ec91f363c6c69fe"},
]
types-requests = [
{file = "types-requests-2.27.29.tar.gz", hash = "sha256:fb453b3a76a48eca66381cea8004feaaea12835e838196f5c7ac87c75c5c19ef"},
{file = "types_requests-2.27.29-py3-none-any.whl", hash = "sha256:014f4f82db7b96c41feea9adaea30e68cd64c230eeab34b70c29bebb26ec74ac"},
]
types-urllib3 = [
{file = "types-urllib3-1.26.15.tar.gz", hash = "sha256:c89283541ef92e344b7f59f83ea9b5a295b16366ceee3f25ecfc5593c79f794e"},
{file = "types_urllib3-1.26.15-py3-none-any.whl", hash = "sha256:6011befa13f901fc934f59bb1fd6973be6f3acf4ebfce427593a27e7f492918f"},
]
typing-extensions = [
{file = "typing_extensions-4.2.0-py3-none-any.whl", hash = "sha256:6657594ee297170d19f67d55c05852a874e7eb634f4f753dbd667855e07c1708"},
{file = "typing_extensions-4.2.0.tar.gz", hash = "sha256:f1c24655a0da0d1b67f07e17a5e6b2a105894e6824b92096378bb3668ef02376"},

View File

@ -25,6 +25,9 @@ Flask = "^2.1.2"
transformers = "^4.19.2"
torch = "^1.11.0"
requests = "^2.27.1"
tqdm = "^4.64.0"
types-redis = "^4.2.6"
types-requests = "^2.27.29"
[tool.poetry.dev-dependencies]
black = "^22.3.0"
@ -55,6 +58,7 @@ module = [
"tqdm",
"sqlitedict",
"dill",
"tqdm.auto",
]
[tool.isort]

View File

@ -9,13 +9,15 @@ from manifest.clients.dummy import DummyClient
def test_init():
"""Test client initialization."""
client = DummyClient(connection_str=None, num_results=3)
args = {"num_results": 3}
client = DummyClient(connection_str=None, client_args=args)
assert client.num_results == 3
def test_get_request():
"""Test client get request."""
client = DummyClient(connection_str=None, num_results=3)
args = {"num_results": 3}
client = DummyClient(connection_str=None, client_args=args)
request_func, request_params = client.get_request("hello")
assert request_params == {"prompt": "hello", "num_results": 3}
assert request_func() == {"choices": [{"text": "hello"}] * 3}

View File

@ -10,6 +10,15 @@ from manifest.clients.dummy import DummyClient
@pytest.mark.usefixtures("sqlite_cache")
def test_init(sqlite_cache):
"""Test manifest initialization."""
with pytest.raises(ValueError) as exc_info:
Manifest(
client_name="dummy",
cache_name="sqlite",
cache_connection=sqlite_cache,
sep_tok="",
)
assert str(exc_info.value) == "[('sep_tok', '')] arguments are not recognized."
manifest = Manifest(
client_name="dummy",
cache_name="sqlite",