Mirror of https://github.com/HazyResearch/manifest, synced 2024-11-02 09:40:58 +00:00
Commit 1f6d9250fe

Makefile (1 change)
@@ -1,6 +1,7 @@
 dev:
 	poetry install
 	poetry run pre-commit install
+	poetry run mypy --install-types

 test: dev check
 	poetry run pytest tests
README.md (40 changes)
@@ -3,19 +3,19 @@ Prompt programming with FMs.

 # Install
 Download the code:
-```
+```bash
 git clone git@github.com:HazyResearch/manifest.git
 cd manifest
 ```

 Install:
-```
+```bash
 pip install poetry
 poetry install
 poetry run pre-commit install
 ```
 or
-```
+```bash
 pip install poetry
 make dev
 ```
@@ -28,14 +28,14 @@ Manifest is meant to be a very light weight package to help with prompt iteration

 ## Prompts
 A Manifest prompt is a function that accepts a single input to generate a string prompt to send to a model.
-```
+```python
 from manifest import Prompt
 prompt = Prompt(lambda x: "Hello, my name is {x}")
-print(promt("Laurel"))
+print(prompt("Laurel"))
 >>> "Hello, my name is Laurel"
 ```
 We also let you use static strings
-```
+```python
 prompt = Prompt("Hello, my name is static")
 print(prompt())
 >>> "Hello, my name is static"
@@ -46,13 +46,13 @@ print(prompt())
 ## Sessions

 Each Manifest run is a session that connects to a model endpoint and backend database to record prompt queries. To start a Manifest session for OpenAI, make sure you run
-```
+```bash
 export OPENAI_API_KEY=<OPENAIKEY>
 ```
 so we can access OpenAI.

 Then, in a notebook, run:
-```
+```python
 from manifest import Manifest

 manifest = Manifest(
@@ -64,7 +64,7 @@ manifest = Manifest(
 This will start a session with OpenAI and save all results to a local file called `sqlite.cache`.

 We also support a Redis backend. If you have a Redis database running on port 6379, run
-```
+```python
 manifest = Manifest(
     client_name = "openai",
     cache_name = "redis",
@@ -77,18 +77,18 @@ We will explain [below](#huggingface-models) how to use Manifest for a locally hosted

 Once you have a session open, you can write and develop prompts.

-```
+```python
 prompt = Prompt(lambda x: "Hello, my name is {x}")
 result = manifest.run(prompt, "Laurel")
 ```

 You can also run over multiple examples.
-```
+```python
 results = manifest.batch_run(prompt, ["Laurel", "Avanika"])
 ```

 If something doesn't go right, you can also ask to get a raw manifest Response.
-```
+```python
 result_object = manifest.batch_run(prompt, ["Laurel", "Avanika"], return_response=True)
 print(result_object.get_request())
 print(result_object.is_cached())
@@ -96,24 +96,24 @@ print(result_object.get_response())
 ```

 By default, we do not truncate results based on a stop token. You can change this by either passing a new stop token to a Manifest session or to a `run` or `batch_run`. If you set the stop token to `""`, we will not truncate the model output.
-```
+```python
 result = manifest.run(prompt, "Laurel", stop_token="and")
 ```

 If you want to change default parameters to a model, we pass those as `kwargs` to the client.
-```
+```python
 result = manifest.run(prompt, "Laurel", max_tokens=50)
 ```
 # Huggingface Models
 To use a HuggingFace generative model, in `manifest/api` we have a Flask application that hosts the models for you.

 In a separate terminal or Tmux/Screen session, run
-```
+```python
 python3 manifest/api/app.py --model_type huggingface --model_name EleutherAI/gpt-j-6B --device 0
 ```
 You will see the Flask session start and output a URL `http://127.0.0.1:5000`. Pass this in to Manifest. If you want to use a different port, set the `FLASK_PORT` environment variable.

-```
+```python
 manifest = Manifest(
     client_name = "huggingface",
     client_connection = "http://127.0.0.1:5000",
@@ -122,11 +122,13 @@ manifest = Manifest(
 )
 ```

+If you have a custom model you trained, pass the model path to `--model_name`.
+
 **Auto deployment coming soon**

 # Development
 Before submitting a PR, run
-```
+```bash
 export REDIS_PORT="6380" # or whatever PORT local redis is running for those tests
 cd <REDIS_PATH>
 docker run -d -p 127.0.0.1:${REDIS_PORT}:6380 -v `pwd`:`pwd` -w `pwd` --name manifest_redis_test redis
@@ -134,12 +136,12 @@ make test
 ```

 To use our development Redis database, email [Laurel](lorr1@cs.stanford.edu). If you have access to our GCP account, in a separate terminal, run
-```
+```bash
 gcloud compute ssh "manifest-connect" --zone "europe-west4-a" --project "hai-gcp-head-models" -- -N -L 6379:10.152.93.107:6379
 ```

 Then if you issue
-```
+```bash
 redis-cli ping
 ```
 You should see a `PONG` response from our database.
manifest/api/app.py
@@ -81,6 +81,12 @@ def completions() -> Dict:
     return OpenAIResponse(results).__dict__()


+@app.route("/params", methods=["POST"])
+def params() -> Dict:
+    """Get model params."""
+    return model.get_init_params()
+
+
 @app.route("/")
 def index() -> str:
     """Get index completion."""
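The new `/params` route lets a caller ask a running server which model it is serving. A quick sketch of calling it, assuming the Flask app from the README is up on its default port:

```python
import requests

# Assumes manifest/api/app.py is running locally on the default Flask port.
params = requests.post("http://127.0.0.1:5000/params").json()
print(params)  # e.g. {"model_name": "EleutherAI/gpt-j-6B", "model_path": "EleutherAI/gpt-j-6B"}
```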
manifest/api/models/huggingface.py
@@ -1,7 +1,10 @@
 """Huggingface model."""
-from typing import Any, List
+import json
+from pathlib import Path
+from typing import Any, Dict, List

 from transformers import (
+    AutoModelForSeq2SeqLM,
     AutoTokenizer,
     GPT2LMHeadModel,
     GPTJForCausalLM,
@@ -17,6 +20,18 @@ MODEL_REGISTRY = {
     "EleutherAI/gpt-neo-1.3B": GPTNeoForCausalLM,
     "EleutherAI/gpt-neo-2.7B": GPTNeoForCausalLM,
     "gpt2": GPT2LMHeadModel,
+    "bigscience/T0pp": AutoModelForSeq2SeqLM,
+    "bigscience/T0_3B": AutoModelForSeq2SeqLM,
 }

+MODEL_PIPELINE = {
+    "EleutherAI/gpt-j-6B": "text-generation",
+    "EleutherAI/gpt-neo-125M": "text-generation",
+    "EleutherAI/gpt-neo-1.3B": "text-generation",
+    "EleutherAI/gpt-neo-2.7B": "text-generation",
+    "gpt2": "text-generation",
+    "bigscience/T0pp": "text2text-generation",
+    "bigscience/T0_3B": "text2text-generation",
+}
@@ -32,13 +47,27 @@ class HuggingFaceModel(Model):
         Args:
             model_name: model name string.
         """
+        # Check if providing path
+        self.model_path = model_name
+        if Path(self.model_path).exists() and Path(self.model_path).is_dir():
+            # Try to find config
+            if (Path(self.model_path) / "config.json").exists():
+                config = json.load(open(Path(self.model_path) / "config.json"))
+                model_name = config["_name_or_path"]
+        self.model_name = model_name
+        print("Model Name:", self.model_name, "Model Path:", self.model_path)
         model = MODEL_REGISTRY[model_name].from_pretrained(
-            model_name, cache_dir=cache_dir
+            self.model_path, cache_dir=cache_dir
         )
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.pipeline = pipeline(
-            "text-generation", model=model, tokenizer=tokenizer, device=device
+            MODEL_PIPELINE[model_name], model=model, tokenizer=tokenizer, device=device
         )
+        self.returns_input = MODEL_PIPELINE[model_name] == "text-generation"

+    def get_init_params(self) -> Dict:
+        """Return init params to determine what model is being used."""
+        return {"model_name": self.model_name, "model_path": self.model_path}
+
     def generate(self, prompt: str, **kwargs: Any) -> List[str]:
         """

@@ -66,14 +95,12 @@ class HuggingFaceModel(Model):
             top_p=kwargs.get("top_p"),
             num_return_sequences=num_return,
         )
-        # Removes tokens removed from tokenization
-        decoded_prompt = self.pipeline.tokenizer.decode(
-            encoded_prompt, clean_up_tokenization_spaces=True
-        )
-        if num_return == 1:
-            final_results.append(result[0]["generated_text"][len(decoded_prompt) :])
+        if self.returns_input:
+            start_idx = len(prompt)
         else:
-            final_results.append(
-                [r["generated_text"][len(decoded_prompt) :] for r in result]
-            )
+            start_idx = 0
+        if num_return == 1:
+            final_results.append(result[0]["generated_text"][start_idx:])
+        else:
+            final_results.append([r["generated_text"][start_idx:] for r in result])
         return final_results
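The reworked truncation logic drops the decode-the-prompt round trip: causal `text-generation` pipelines echo the prompt at the front of `generated_text`, while `text2text-generation` models such as T0 emit only the completion. A standalone sketch of the same slicing (illustrative values, not repo code):

```python
def strip_prompt(generated_text: str, prompt: str, returns_input: bool) -> str:
    """Drop the echoed prompt for causal LMs; seq2seq output is already clean."""
    start_idx = len(prompt) if returns_input else 0
    return generated_text[start_idx:]

# Causal pipelines (gpt2, GPT-J, GPT-Neo) echo the prompt:
assert strip_prompt("Hello, my name is Laurel", "Hello, my name is", True) == " Laurel"
# Seq2seq pipelines (T0pp, T0_3B) do not:
assert strip_prompt(" Laurel", "Hello, my name is", False) == " Laurel"
```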
manifest/api/models/model.py
@@ -1,6 +1,6 @@
 """Model class."""
 from abc import ABC, abstractmethod
-from typing import Any, List
+from typing import Any, Dict, List


 class Model(ABC):

@@ -18,6 +18,11 @@ class Model(ABC):
         """
         raise NotImplementedError()

+    @abstractmethod
+    def get_init_params(self) -> Dict:
+        """Return init params to determine what model is being used."""
+        raise NotImplementedError()
+
     @abstractmethod
     def generate(self, prompt: str, **kwargs: Any) -> List[str]:
         """
manifest/caches/cache.py
@@ -61,19 +61,20 @@ def key_to_response(key: str) -> Dict:
 class Cache(ABC):
     """A cache for request/response pairs."""

-    def __init__(self, connection_str: str, **kwargs: Any):
+    def __init__(self, connection_str: str, cache_args: Dict[str, Any] = {}):
         """
         Initialize client.

-        kwargs are passed to client as default parameters.
+        cache_args are passed to client as default parameters.

         For clients like OpenAI that do not require a connection,
         the connection_str can be None.

         Args:
             connection_str: connection string for client.
+            cache_args: cache arguments.
         """
-        self.connect(connection_str, **kwargs)
+        self.connect(connection_str, cache_args)

     @abstractmethod
     def close(self) -> None:

@@ -81,7 +82,7 @@ class Cache(ABC):
         raise NotImplementedError()

     @abstractmethod
-    def connect(self, connection_str: str, **kwargs: Any) -> None:
+    def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
         """
         Connect to client.
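Swapping `**kwargs` for an explicit dict is what makes unrecognized-argument detection possible: every consumer `pop`s the keys it owns from the shared dict, so whatever is left at the end was never claimed by anyone. A minimal sketch of the pattern, with a hypothetical option name:

```python
from typing import Any, Dict


def configure_cache(cache_args: Dict[str, Any]) -> None:
    """Consume the arguments this component understands."""
    batch_size = cache_args.pop("batch_size", 32)  # hypothetical option
    print("batch_size =", batch_size)


args = {"batch_size": 8, "batch_sze": 8}  # note the typo'd key
configure_cache(args)
if args:  # anything that survived every pop() was never recognized
    raise ValueError(f"{list(args.items())} arguments are not recognized.")
```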
manifest/caches/redis.py
@@ -1,5 +1,5 @@
 """Redis cache."""
-from typing import Any, Union
+from typing import Any, Dict, Union

 import redis

@@ -9,12 +9,13 @@ from manifest.caches import Cache
 class RedisCache(Cache):
     """A Redis cache for request/response pairs."""

-    def connect(self, connection_str: str, **kwargs: Any) -> None:
+    def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
         """
         Connect to client.

         Args:
             connection_str: connection string.
+            cache_args: cache arguments.
         """
         host, port = connection_str.split(":")
         self.redis = redis.Redis(host=host, port=int(port), db=0)
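A usage sketch for the updated signature, assuming a Redis server reachable at `localhost:6379` (the connection string is split on `:` and the db index is fixed at 0):

```python
from manifest.caches.redis import RedisCache

# cache_args is accepted by the new signature; empty here.
cache = RedisCache("localhost:6379", cache_args={})
```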
manifest/caches/sqlite.py
@@ -1,6 +1,6 @@
 """SQLite cache."""
 import logging
-from typing import Any, Union
+from typing import Any, Dict, Union

 from sqlitedict import SqliteDict

@@ -12,12 +12,13 @@ logging.getLogger("sqlitedict").setLevel(logging.WARNING)
 class SQLiteCache(Cache):
     """A SQLite cache for request/response pairs."""

-    def connect(self, connection_str: str, **kwargs: Any) -> None:
+    def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
         """
         Connect to client.

         Args:
             connection_str: connection string.
+            cache_args: cache arguments.
         """
         self.cache_file = connection_str
         self.cache = SqliteDict(self.cache_file, autocommit=True)
manifest/clients/client.py
@@ -6,7 +6,9 @@ from typing import Any, Callable, Dict, Optional, Tuple
 class Client(ABC):
     """Client class."""

-    def __init__(self, connection_str: Optional[str] = None, **kwargs: Any):
+    def __init__(
+        self, connection_str: Optional[str] = None, client_args: Dict[str, Any] = {}
+    ):
         """
         Initialize client.

@@ -17,8 +19,9 @@ class Client(ABC):

         Args:
             connection_str: connection string for client.
+            client_args: client arguments.
         """
-        self.connect(connection_str, **kwargs)
+        self.connect(connection_str, client_args)

     @abstractmethod
     def close(self) -> None:

@@ -26,7 +29,7 @@ class Client(ABC):
         raise NotImplementedError()

     @abstractmethod
-    def connect(self, connection_str: str, **kwargs: Any) -> None:
+    def connect(self, connection_str: str, client_args: Dict[str, Any]) -> None:
         """
         Connect to client.
manifest/clients/dummy.py
@@ -13,15 +13,18 @@ class DummyClient(Client):
     def connect(
         self,
         connection_str: Optional[str] = None,
-        num_results: Optional[int] = 1,
-        **kwargs: Any,
+        client_args: Dict[str, Any] = {},
     ) -> None:
         """
         Connect to dummy server.

         This is a dummy client that returns identity responses. Used for testing.

         Args:
             connection_str: connection string.
+            client_args: client arguments.
         """
-        self.num_results = num_results
+        self.num_results = client_args.pop("num_results", 1)

     def close(self) -> None:
         """Close the client."""
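Under the new signature the dummy client's one knob rides inside `client_args`, mirroring the real clients; a quick usage sketch (this matches the updated tests below):

```python
from manifest.clients.dummy import DummyClient

client = DummyClient(connection_str=None, client_args={"num_results": 2})
assert client.num_results == 2
```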
manifest/clients/huggingface.py
@@ -15,27 +15,41 @@ class HuggingFaceClient(Client):
     def connect(
         self,
         connection_str: Optional[str] = None,
-        temperature: Optional[float] = 1.0,
-        max_tokens: Optional[int] = 10,
-        top_p: Optional[float] = 1.0,
-        top_k: Optional[int] = 0,
-        repetition_penalty: Optional[float] = 1.0,
-        n: Optional[int] = 1,
-        **kwargs: Any,
+        client_args: Dict[str, Any] = {},
     ) -> None:
-        """Connect to the HuggingFace url."""
+        """
+        Connect to the HuggingFace url.
+
+        Args:
+            connection_str: connection string.
+            client_args: client arguments.
+        """
         self.host = connection_str.rstrip("/")
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.top_p = top_p
-        self.top_k = top_k
-        self.repetition_penalty = repetition_penalty
-        self.n = n
+        self.temperature = client_args.pop("temperature", 0.00001)
+        self.max_tokens = client_args.pop("max_tokens", 10)
+        self.top_p = client_args.pop("top_p", 1.0)
+        self.top_k = client_args.pop("top_k", 50)
+        self.repetition_penalty = client_args.pop("repetition_penalty", 1.0)
+        self.n = client_args.pop("n", 1)
+        self.model_params = self.get_model_params()

     def close(self) -> None:
         """Close the client."""
         pass

+    def get_model_params(self) -> Dict:
+        """
+        Get model params.
+
+        By getting model params from the server, we can add to request
+        and make sure cache keys are unique to model.
+
+        Returns:
+            model params.
+        """
+        res = requests.post(self.host + "/params")
+        return res.json()
+
     def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
         """
         Get request string function.

@@ -58,6 +72,7 @@ class HuggingFaceClient(Client):
             ),
             "n": kwargs.get("n", self.n),
         }
+        request_params.update(self.model_params)

         def _run_completion() -> Dict:
             post_str = self.host + "/completions"
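Merging `self.model_params` into every request dict means two servers hosting different models can no longer collide in the cache, assuming cache keys are derived from the request parameters. A sketch of why that matters, with a hypothetical key function (not the repo's actual serialization):

```python
import json


def cache_key(request_params: dict) -> str:
    """Hypothetical: a stable serialization of the request as the cache key."""
    return json.dumps(request_params, sort_keys=True)


req = {"prompt": "hello", "max_tokens": 10}
key_a = cache_key({**req, "model_name": "EleutherAI/gpt-j-6B"})
key_b = cache_key({**req, "model_name": "bigscience/T0_3B"})
assert key_a != key_b  # same prompt, different model -> different cache entry
```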
manifest/clients/openai.py
@@ -24,35 +24,32 @@ class OpenAIClient(Client):
     def connect(
         self,
         connection_str: Optional[str] = None,
-        engine: Optional[str] = "text-ada-001",
-        temperature: Optional[float] = 0.0,
-        max_tokens: Optional[int] = 10,
-        top_p: Optional[float] = 1.0,
-        frequency_penalty: Optional[int] = 0,
-        presence_penalty: Optional[int] = 0,
-        n: Optional[int] = 1,
-        **kwargs: Any,
+        client_args: Dict[str, Any] = {},
     ) -> None:
         """
         Connect to the OpenAI server.

         connection_str is passed as default OPENAI_API_KEY if variable not set.

         Args:
             connection_str: connection string.
+            client_args: client arguments.
         """
         openai.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
         if openai.api_key is None:
             raise ValueError(
-                "OpenAI API key not set. Set OPENAI_API_KEY environment ",
-                "svariable or pass through `connection_str`.",
+                "OpenAI API key not set. Set OPENAI_API_KEY environment "
+                "variable or pass through `connection_str`."
             )
-        self.engine = engine
+        self.engine = client_args.pop("engine", "text-davinci-002")
         if self.engine not in OPENAI_ENGINES:
             raise ValueError(f"Invalid engine {self.engine}. Must be {OPENAI_ENGINES}.")
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.top_p = top_p
-        self.frequency_penalty = frequency_penalty
-        self.presence_penalty = presence_penalty
-        self.n = n
+        self.temperature = client_args.pop("temperature", 0.0)
+        self.max_tokens = client_args.pop("max_tokens", 10)
+        self.top_p = client_args.pop("top_p", 1.0)
+        self.frequency_penalty = client_args.pop("frequency_penalty", 0.0)
+        self.presence_penalty = client_args.pop("presence_penalty", 0.0)
+        self.n = client_args.pop("n", 1)

     def close(self) -> None:
         """Close the client."""
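The error-message change fixes a subtle bug besides the "svariable" typo: separating the fragments with a comma passes a tuple of strings to `ValueError`, whereas adjacent string literals concatenate into one message. A small demonstration:

```python
# Comma: ValueError receives a tuple of two separate strings.
err_tuple = ValueError("OpenAI API key not set. ", "Set OPENAI_API_KEY.")
print(err_tuple.args)  # ('OpenAI API key not set. ', 'Set OPENAI_API_KEY.')

# Adjacent literals: Python concatenates them into a single message.
err_single = ValueError("OpenAI API key not set. " "Set OPENAI_API_KEY.")
print(err_single.args)  # ('OpenAI API key not set. Set OPENAI_API_KEY.',)
```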
manifest/clients/opt.py (new file, 61 lines)
@@ -0,0 +1,61 @@
+"""OPT client."""
+import logging
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import requests
+
+from manifest.clients.client import Client
+
+logger = logging.getLogger(__name__)
+
+
+class OPTClient(Client):
+    """OPT client."""
+
+    def connect(
+        self,
+        connection_str: Optional[str] = None,
+        client_args: Dict[str, Any] = {},
+    ) -> None:
+        """
+        Connect to the OPT url.
+
+        Args:
+            connection_str: connection string.
+            client_args: client arguments.
+        """
+        self.host = connection_str.rstrip("/")
+        self.temperature = client_args.pop("temperature", 1.0)
+        self.max_tokens = client_args.pop("max_tokens", 10)
+        self.top_p = client_args.pop("top_p", 0)
+        self.n = client_args.pop("n", 1)
+
+    def close(self) -> None:
+        """Close the client."""
+        pass
+
+    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
+        """
+        Get request string function.
+
+        Args:
+            query: query string.
+
+        Returns:
+            request function that takes no input.
+            request parameters as dict.
+        """
+        request_params = {
+            "prompt": query,
+            "temperature": kwargs.get("temperature", self.temperature),
+            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
+            "top_p": kwargs.get("top_p", self.top_p),
+            "n": kwargs.get("n", self.n),
+        }
+
+        def _run_completion() -> Dict:
+            post_str = self.host + "/completions"
+            res = requests.post(post_str, json=request_params)
+            return res.json()
+
+        return _run_completion, request_params
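With the client registered under `"opt"` (see `manifest/manifest.py` below), pointing a session at an OPT serving endpoint mirrors the HuggingFace flow; a sketch in which the URL is a placeholder for wherever the OPT completions server runs:

```python
from manifest import Manifest

manifest = Manifest(
    client_name="opt",
    client_connection="http://127.0.0.1:5000",  # placeholder endpoint
    cache_name="sqlite",
    cache_connection="sqlite.cache",
)
```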
manifest/manifest.py
@@ -2,21 +2,24 @@
 import logging
 from typing import Any, Iterable, List, Optional, Union

-from manifest.response import Response
-
-logging.getLogger("openai").setLevel(logging.WARNING)
-logger = logging.getLogger(__name__)
+from tqdm.auto import tqdm

 from manifest.caches.redis import RedisCache
 from manifest.caches.sqlite import SQLiteCache
 from manifest.clients.dummy import DummyClient
 from manifest.clients.huggingface import HuggingFaceClient
 from manifest.clients.openai import OpenAIClient
+from manifest.clients.opt import OPTClient
 from manifest.prompt import Prompt
+from manifest.response import Response
+
+logging.getLogger("openai").setLevel(logging.WARNING)
+logger = logging.getLogger(__name__)

 CLIENT_CONSTRUCTORS = {
     "openai": OpenAIClient,
     "huggingface": HuggingFaceClient,
+    "opt": OPTClient,
     "dummy": DummyClient,
 }

@@ -62,10 +65,14 @@ class Manifest:
                 f"Choices are {list(CACHE_CONSTRUCTORS.keys())}"
             )
         self.client_name = client_name
+        # Must pass kwargs as a dict so the clients' pop() methods remove used arguments
         self.client = CLIENT_CONSTRUCTORS[client_name](  # type: ignore
-            client_connection, **kwargs
+            client_connection, client_args=kwargs
         )
-        self.cache = CACHE_CONSTRUCTORS[cache_name](cache_connection, **kwargs)
+        self.cache = CACHE_CONSTRUCTORS[cache_name](cache_connection, cache_args=kwargs)
+        if len(kwargs) > 0:
+            raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")

         self.stop_token = stop_token

     def close(self) -> None:

@@ -119,6 +126,7 @@ class Manifest:
         overwrite_cache: bool = False,
         stop_token: Optional[str] = None,
         return_response: bool = False,
+        verbose: bool = False,
         **kwargs: Any,
     ) -> Iterable[Union[str, List[str], Response]]:
         """

@@ -141,7 +149,7 @@ class Manifest:
             self.run(
                 prompt, inp, overwrite_cache, stop_token, return_response, **kwargs
             )
-            for inp in input
+            for inp in tqdm(input, desc="Running batch", disable=not verbose)
         ]

     def save_prompt(self, name: str, prompt: Prompt) -> None:
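Two behavior changes land here: leftover keyword arguments now raise instead of being silently dropped, and `batch_run` gains an optional progress bar. A usage sketch under those assumptions:

```python
from manifest import Manifest
from manifest.prompt import Prompt

manifest = Manifest(
    client_name="dummy", cache_name="sqlite", cache_connection="sqlite.cache"
)

prompt = Prompt(lambda x: f"Hello, my name is {x}")
# verbose=True wraps the inputs in tqdm to show batch progress.
results = manifest.batch_run(prompt, ["Laurel", "Avanika"], verbose=True)

# A mistyped argument is no longer swallowed:
# Manifest(client_name="dummy", cache_name="sqlite", sep_tok="")  -> ValueError
```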
poetry.lock (generated, 41 changes)
@@ -984,6 +984,33 @@ torch-speech = ["torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
 torchhub = ["filelock", "huggingface-hub (>=0.1.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.0)", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "tqdm (>=4.27)"]
 vision = ["pillow"]

+[[package]]
+name = "types-redis"
+version = "4.2.6"
+description = "Typing stubs for redis"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "types-requests"
+version = "2.27.29"
+description = "Typing stubs for requests"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+types-urllib3 = "<1.27"
+
+[[package]]
+name = "types-urllib3"
+version = "1.26.15"
+description = "Typing stubs for urllib3"
+category = "main"
+optional = false
+python-versions = "*"
+
 [[package]]
 name = "typing-extensions"
 version = "4.2.0"

@@ -1057,7 +1084,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.8"
-content-hash = "b0f92eed3a0f80cc9976f68098b772a19d19b4742426299bd63cb1c636d8a84a"
+content-hash = "9f1c7010530f8668850294c74a7dd60fdaad2086901a5bd0f4438b870666c1a7"

 [metadata.files]
 alabaster = [

@@ -1687,6 +1714,18 @@ transformers = [
     {file = "transformers-4.19.2-py3-none-any.whl", hash = "sha256:1416315b7c5ff1f56d3915f416b67aa254a9907fbb73ef7f7bffc9210446b5fa"},
     {file = "transformers-4.19.2.tar.gz", hash = "sha256:e19a4ff07458eda143c738e5259caf48449fcf078a63d6b1bd1aa806543440a3"},
 ]
+types-redis = [
+    {file = "types-redis-4.2.6.tar.gz", hash = "sha256:d6adc77185cf40b300816767a64c0ee9ee0b21dc174e8e5c23b7e83d43189cb8"},
+    {file = "types_redis-4.2.6-py3-none-any.whl", hash = "sha256:1136af954ade0be33b487f440c8cbcbee29f089a83e685484ec91f363c6c69fe"},
+]
+types-requests = [
+    {file = "types-requests-2.27.29.tar.gz", hash = "sha256:fb453b3a76a48eca66381cea8004feaaea12835e838196f5c7ac87c75c5c19ef"},
+    {file = "types_requests-2.27.29-py3-none-any.whl", hash = "sha256:014f4f82db7b96c41feea9adaea30e68cd64c230eeab34b70c29bebb26ec74ac"},
+]
+types-urllib3 = [
+    {file = "types-urllib3-1.26.15.tar.gz", hash = "sha256:c89283541ef92e344b7f59f83ea9b5a295b16366ceee3f25ecfc5593c79f794e"},
+    {file = "types_urllib3-1.26.15-py3-none-any.whl", hash = "sha256:6011befa13f901fc934f59bb1fd6973be6f3acf4ebfce427593a27e7f492918f"},
+]
 typing-extensions = [
     {file = "typing_extensions-4.2.0-py3-none-any.whl", hash = "sha256:6657594ee297170d19f67d55c05852a874e7eb634f4f753dbd667855e07c1708"},
     {file = "typing_extensions-4.2.0.tar.gz", hash = "sha256:f1c24655a0da0d1b67f07e17a5e6b2a105894e6824b92096378bb3668ef02376"},
pyproject.toml
@@ -25,6 +25,9 @@ Flask = "^2.1.2"
 transformers = "^4.19.2"
 torch = "^1.11.0"
 requests = "^2.27.1"
+tqdm = "^4.64.0"
+types-redis = "^4.2.6"
+types-requests = "^2.27.29"

 [tool.poetry.dev-dependencies]
 black = "^22.3.0"

@@ -55,6 +58,7 @@ module = [
     "tqdm",
     "sqlitedict",
     "dill",
+    "tqdm.auto",
 ]

 [tool.isort]
tests/test_client.py
@@ -9,13 +9,15 @@ from manifest.clients.dummy import DummyClient

 def test_init():
     """Test client initialization."""
-    client = DummyClient(connection_str=None, num_results=3)
+    args = {"num_results": 3}
+    client = DummyClient(connection_str=None, client_args=args)
     assert client.num_results == 3


 def test_get_request():
     """Test client get request."""
-    client = DummyClient(connection_str=None, num_results=3)
+    args = {"num_results": 3}
+    client = DummyClient(connection_str=None, client_args=args)
     request_func, request_params = client.get_request("hello")
     assert request_params == {"prompt": "hello", "num_results": 3}
     assert request_func() == {"choices": [{"text": "hello"}] * 3}

tests/test_manifest.py
@@ -10,6 +10,15 @@ from manifest.clients.dummy import DummyClient
 @pytest.mark.usefixtures("sqlite_cache")
 def test_init(sqlite_cache):
     """Test manifest initialization."""
+    with pytest.raises(ValueError) as exc_info:
+        Manifest(
+            client_name="dummy",
+            cache_name="sqlite",
+            cache_connection=sqlite_cache,
+            sep_tok="",
+        )
+    assert str(exc_info.value) == "[('sep_tok', '')] arguments are not recognized."
+
     manifest = Manifest(
         client_name="dummy",
         cache_name="sqlite",