[feature] better naming models in cache

laurel/helm
Laurel Orr 2 years ago
parent 894042c370
commit a82c8c89df

@ -3,19 +3,19 @@ Prompt programming with FMs.
# Install
Download the code:
```
```bash
git clone git@github.com:HazyResearch/manifest.git
cd manifest
```
Install:
```
```bash
pip install poetry
poetry install
poetry run pre-commit install
```
or
```
```bash
pip install poetry
make dev
```
@ -28,14 +28,14 @@ Manifest is meant to be a very light weight package to help with prompt iteratio
## Prompts
A Manifest prompt is a function that accepts a single input to generate a string prompt to send to a model.
```
```python
from manifest import Prompt
prompt = Prompt(lambda x: f"Hello, my name is {x}")
print(prompt("Laurel"))
>>> "Hello, my name is Laurel"
```
We also let you use static strings.
```
```python
prompt = Prompt("Hello, my name is static")
print(prompt())
>>> "Hello, my name is static"
@ -46,13 +46,13 @@ print(prompt())
## Sessions
Each Manifest run is a session that connects to a model endpoint and backend database to record prompt queries. To start a Manifest session for OpenAI, make sure you run
```
```bash
export OPENAI_API_KEY=<OPENAIKEY>
```
so we can access OpenAI.
Then, in a notebook, run:
```
```python
from manifest import Manifest
manifest = Manifest(
@ -64,7 +64,7 @@ manifest = Manifest(
This will start a session with OpenAI and save all results to a local file called `sqlite.cache`.
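For completeness, a minimal sketch of the full call, assuming the cache arguments mirror the Redis example below (the `sqlite` cache name and the `cache_connection` value are assumptions):
```python
from manifest import Manifest

# Sketch: assumes cache_name/cache_connection mirror the Redis example below.
manifest = Manifest(
    client_name = "openai",
    cache_name = "sqlite",
    cache_connection = "sqlite.cache",
)
```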
We also support a Redis backend. If you have a Redis database running on port 6379, run
```
```python
manifest = Manifest(
client_name = "openai",
cache_name = "redis",
@ -77,18 +77,18 @@ We will explain [below](#huggingface-models) how to use Manifest for a locally h
Once you have a session open, you can write and develop prompts.
```
```python
prompt = Prompt(lambda x: f"Hello, my name is {x}")
result = manifest.run(prompt, "Laurel")
```
You can also run over multiple examples.
```
```python
results = manifest.batch_run(prompt, ["Laurel", "Avanika"])
```
If something doesn't go right, you can also ask for the raw Manifest `Response`.
```
```python
result_object = manifest.batch_run(prompt, ["Laurel", "Avanika"], return_response=True)
print(result_object.get_request())
print(result_object.is_cached())
@ -96,24 +96,24 @@ print(result_object.get_response())
```
By default, we do not truncate results based on a stop token. You can change this by passing a new stop token either to a Manifest session or to a `run` or `batch_run` call. If you set the stop token to `""`, we will not truncate the model output.
```
```python
result = manifest.run(prompt, "Laurel", stop_token="and")
```
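For example, a session-wide stop token might look like this (a sketch; the constructor keyword name is assumed to match the `stop_token` argument used by `run` and `batch_run`):
```python
# Sketch: session-level stop token (constructor keyword name assumed).
manifest = Manifest(
    client_name = "openai",
    stop_token = "\n",
)
result = manifest.run(prompt, "Laurel")  # output truncated at the stop token
```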
If you want to change a model's default parameters, pass them as `kwargs`; we forward them to the client.
```
```python
result = manifest.run(prompt, "Laurel", max_tokens=50)
```
# Huggingface Models
To use a HuggingFace generative model, we have a Flask application in `manifest/api` that hosts the models for you.
In a separate terminal or Tmux/Screen session, run
```
```bash
python3 manifest/api/app.py --model_type huggingface --model_name EleutherAI/gpt-j-6B --device 0
```
You will see the Flask session start and output a URL `http://127.0.0.1:5000`. Pass this in to Manifest. If you want to use a different port, set the `FLASK_PORT` environment variable.
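For example, a sketch of serving on a different port (5001 here is arbitrary):
```bash
export FLASK_PORT=5001
python3 manifest/api/app.py --model_type huggingface --model_name EleutherAI/gpt-j-6B --device 0
```
If you do this, pass the matching URL (e.g. `http://127.0.0.1:5001`) as `client_connection` below.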
```
```python
manifest = Manifest(
client_name = "huggingface",
client_connection = "http://127.0.0.1:5000",
@ -122,11 +122,13 @@ manifest = Manifest(
)
```
If you have a custom model you trained, pass the model path to `--model_name`.
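For example (a sketch; the checkpoint directory is hypothetical):
```bash
python3 manifest/api/app.py --model_type huggingface --model_name /path/to/my_finetuned_model --device 0
```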
**Auto deployment coming soon**
# Development
Before submitting a PR, run
```
```bash
export REDIS_PORT="6380" # or whatever port your local Redis is running on for these tests
cd <REDIS_PATH>
docker run -d -p 127.0.0.1:${REDIS_PORT}:6379 -v `pwd`:`pwd` -w `pwd` --name manifest_redis_test redis
@ -134,12 +136,12 @@ make test
```
To use our development Redis database, email [Laurel](mailto:lorr1@cs.stanford.edu). If you have access to our GCP account, in a separate terminal, run
```
```bash
gcloud compute ssh "manifest-connect" --zone "europe-west4-a" --project "hai-gcp-head-models" -- -N -L 6379:10.152.93.107:6379
```
Then if you issue
```
```bash
redis-cli ping
```
you should see a `PONG` response from our database.

@ -81,6 +81,12 @@ def completions() -> Dict:
    return OpenAIResponse(results).__dict__()


@app.route("/params", methods=["POST"])
def params() -> Dict:
    """Get model params."""
    return model.get_init_params()


@app.route("/")
def index() -> str:
    """Get index completion."""
@ -1,5 +1,7 @@
"""Huggingface model."""
from typing import Any, List
import json
from pathlib import Path
from typing import Any, Dict, List
from transformers import (
    AutoModelForSeq2SeqLM,
@ -19,6 +21,7 @@ MODEL_REGISTRY = {
    "EleutherAI/gpt-neo-2.7B": GPTNeoForCausalLM,
    "gpt2": GPT2LMHeadModel,
    "bigscience/T0pp": AutoModelForSeq2SeqLM,
    "bigscience/T0_3B": AutoModelForSeq2SeqLM,
}
MODEL_PIPELINE = {
@ -28,6 +31,7 @@ MODEL_PIPELINE = {
"EleutherAI/gpt-neo-2.7B": "text-generation",
"gpt2": "text-generation",
"bigscience/T0pp": "text2text-generation",
"bigscience/T0_3B": "text2text-generation",
}
@ -43,13 +47,27 @@ class HuggingFaceModel(Model):
        Args:
            model_name: model name string.
        """
        # Check if providing path
        self.model_path = model_name
        if Path(self.model_path).exists() and Path(self.model_path).is_dir():
            # Try to find config
            if (Path(self.model_path) / "config.json").exists():
                config = json.load(open(Path(self.model_path) / "config.json"))
                model_name = config["_name_or_path"]
        self.model_name = model_name
        print("Model Name:", self.model_name, "Model Path:", self.model_path)
        model = MODEL_REGISTRY[model_name].from_pretrained(
            model_name, cache_dir=cache_dir
            self.model_path, cache_dir=cache_dir
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.pipeline = pipeline(
            MODEL_PIPELINE[model_name], model=model, tokenizer=tokenizer, device=device
        )
        self.returns_input = MODEL_PIPELINE[model_name] == "text-generation"

    def get_init_params(self) -> Dict:
        """Return init params to determine what model is being used."""
        return {"model_name": self.model_name, "model_path": self.model_path}

    def generate(self, prompt: str, **kwargs: Any) -> List[str]:
        """
@ -77,14 +95,12 @@ class HuggingFaceModel(Model):
top_p=kwargs.get("top_p"),
num_return_sequences=num_return,
)
# Removes tokens removed from tokenization
decoded_prompt = self.pipeline.tokenizer.decode(
encoded_prompt, clean_up_tokenization_spaces=True
)
if self.returns_input:
start_idx = len(prompt)
else:
start_idx = 0
if num_return == 1:
final_results.append(result[0]["generated_text"][len(decoded_prompt) :])
final_results.append(result[0]["generated_text"][start_idx:])
else:
final_results.append(
[r["generated_text"][len(decoded_prompt) :] for r in result]
)
final_results.append([r["generated_text"][start_idx:] for r in result])
return final_results

@ -1,6 +1,6 @@
"""Model class."""
from abc import ABC, abstractmethod
from typing import Any, List
from typing import Any, Dict, List
class Model(ABC):
@ -18,6 +18,11 @@ class Model(ABC):
"""
raise NotImplementedError()
@abstractmethod
def get_init_params(self) -> Dict:
"""Return init params to determine what model is being used."""
raise NotImplementedError()
@abstractmethod
def generate(self, prompt: str, **kwargs: Any) -> List[str]:
"""

@ -25,17 +25,31 @@ class HuggingFaceClient(Client):
            client_args: client arguments.
        """
        self.host = connection_str.rstrip("/")
        self.temperature = client_args.pop("temperature", 1.0)
        self.temperature = client_args.pop("temperature", 0.00001)
        self.max_tokens = client_args.pop("max_tokens", 10)
        self.top_p = client_args.pop("top_p", 0)
        self.top_k = client_args.pop("top_k", 0)
        self.top_p = client_args.pop("top_p", 1.0)
        self.top_k = client_args.pop("top_k", 50)
        self.repetition_penalty = client_args.pop("repetition_penalty", 1.0)
        self.n = client_args.pop("n", 1)
        self.model_params = self.get_model_params()

    def close(self) -> None:
        """Close the client."""
        pass

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params.
        """
        res = requests.post(self.host + "/params")
        return res.json()

    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
        """
        Get request string function.
@ -58,6 +72,7 @@ class HuggingFaceClient(Client):
            ),
            "n": kwargs.get("n", self.n),
        }
        request_params.update(self.model_params)

        def _run_completion() -> Dict:
            post_str = self.host + "/completions"
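For context, a sketch of querying the hosted model's identity directly (assumes the Flask server from the README is running locally; the exact values depend on the model being served):
```python
import requests

# Sketch: the /params endpoint returns the dict from get_init_params(),
# e.g. {"model_name": "...", "model_path": "..."}; merging it into
# request_params makes cache keys unique per model.
res = requests.post("http://127.0.0.1:5000/params")
print(res.json())
```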

poetry.lock generated

@ -984,6 +984,33 @@ torch-speech = ["torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
torchhub = ["filelock", "huggingface-hub (>=0.1.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.0)", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "tqdm (>=4.27)"]
vision = ["pillow"]
[[package]]
name = "types-redis"
version = "4.2.6"
description = "Typing stubs for redis"
category = "main"
optional = false
python-versions = "*"
[[package]]
name = "types-requests"
version = "2.27.29"
description = "Typing stubs for requests"
category = "main"
optional = false
python-versions = "*"
[package.dependencies]
types-urllib3 = "<1.27"
[[package]]
name = "types-urllib3"
version = "1.26.15"
description = "Typing stubs for urllib3"
category = "main"
optional = false
python-versions = "*"
[[package]]
name = "typing-extensions"
version = "4.2.0"
@ -1057,7 +1084,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
content-hash = "086335f2b487e195c42ac25f142c0cca1318550afcd8680e5c90be63b265c820"
content-hash = "9f1c7010530f8668850294c74a7dd60fdaad2086901a5bd0f4438b870666c1a7"
[metadata.files]
alabaster = [
@ -1687,6 +1714,18 @@ transformers = [
{file = "transformers-4.19.2-py3-none-any.whl", hash = "sha256:1416315b7c5ff1f56d3915f416b67aa254a9907fbb73ef7f7bffc9210446b5fa"},
{file = "transformers-4.19.2.tar.gz", hash = "sha256:e19a4ff07458eda143c738e5259caf48449fcf078a63d6b1bd1aa806543440a3"},
]
types-redis = [
{file = "types-redis-4.2.6.tar.gz", hash = "sha256:d6adc77185cf40b300816767a64c0ee9ee0b21dc174e8e5c23b7e83d43189cb8"},
{file = "types_redis-4.2.6-py3-none-any.whl", hash = "sha256:1136af954ade0be33b487f440c8cbcbee29f089a83e685484ec91f363c6c69fe"},
]
types-requests = [
{file = "types-requests-2.27.29.tar.gz", hash = "sha256:fb453b3a76a48eca66381cea8004feaaea12835e838196f5c7ac87c75c5c19ef"},
{file = "types_requests-2.27.29-py3-none-any.whl", hash = "sha256:014f4f82db7b96c41feea9adaea30e68cd64c230eeab34b70c29bebb26ec74ac"},
]
types-urllib3 = [
{file = "types-urllib3-1.26.15.tar.gz", hash = "sha256:c89283541ef92e344b7f59f83ea9b5a295b16366ceee3f25ecfc5593c79f794e"},
{file = "types_urllib3-1.26.15-py3-none-any.whl", hash = "sha256:6011befa13f901fc934f59bb1fd6973be6f3acf4ebfce427593a27e7f492918f"},
]
typing-extensions = [
{file = "typing_extensions-4.2.0-py3-none-any.whl", hash = "sha256:6657594ee297170d19f67d55c05852a874e7eb634f4f753dbd667855e07c1708"},
{file = "typing_extensions-4.2.0.tar.gz", hash = "sha256:f1c24655a0da0d1b67f07e17a5e6b2a105894e6824b92096378bb3668ef02376"},

@ -26,6 +26,8 @@ transformers = "^4.19.2"
torch = "^1.11.0"
requests = "^2.27.1"
tqdm = "^4.64.0"
types-redis = "^4.2.6"
types-requests = "^2.27.29"
[tool.poetry.dev-dependencies]
black = "^22.3.0"
