Merge pull request #2 from HazyResearch/laurel/clients

[feature] redis DB, flask API, tests
This commit is contained in:
Laurel Orr 2022-05-25 20:54:26 -07:00 committed by GitHub
commit 9dd292f2b1
31 changed files with 1191 additions and 234 deletions

View File

@ -2,9 +2,10 @@
# - E731: do not assign a lambda expression, use a def # - E731: do not assign a lambda expression, use a def
# - E402: module level import not at top of file # - E402: module level import not at top of file
# - W503: line break before binary operator # - W503: line break before binary operator
# - E203: whitespace before :
[flake8] [flake8]
exclude = .git exclude = .git
max-line-length = 88 max-line-length = 88
ignore = E731, E402, W503 ignore = E731, E402, W503, E203
per-file-ignores = __init__.py:F401 per-file-ignores = __init__.py:F401

View File

@ -3,18 +3,17 @@ dev:
poetry run pre-commit install poetry run pre-commit install
test: dev check test: dev check
poetry install
poetry run pytest tests poetry run pytest tests
format: format:
isort --atomic manifest/ tests/ poetry run isort --atomic manifest/ tests/
black manifest/ tests/ poetry run black manifest/ tests/
check: check:
isort -c manifest/ tests/ poetry run isort -c manifest/ tests/
black manifest/ tests/ --check poetry run black manifest/ tests/ --check
flake8 manifest/ tests/ poetry run flake8 manifest/ tests/
mypy manifest/ poetry run mypy manifest/
clean: clean:
pip uninstall -y manifest pip uninstall -y manifest

119
README.md
View File

@ -19,10 +19,127 @@ or
pip install poetry pip install poetry
make dev make dev
``` ```
# Run
Manifest is meant to be a very light weight package to help with prompt iteration. Two key design decisions are
* Prompt are functional -- they can take an input example and dynamically change
* All models are behind API calls (e.g., OpenAI)
* Everything is cached for reuse to both save credits and to explore past results
## Prompts
A Manifest prompt is a function that accepts a single input to generate a string prompt to send to a model.
```
from manifest import Prompt
prompt = Prompt(lambda x: "Hello, my name is {x}")
print(promt("Laurel"))
>>> "Hello, my name is Laurel"
```
We also let you use static strings
```
prompt = Prompt("Hello, my name is static")
print(prompt())
>>> "Hello, my name is static"
```
**Chaining prompts coming soon**
## Sessions
Each Manifest run is a session that connects to a model endpoint and backend database to record prompt queries. To start a Manifest session for OpenAI, make sure you run
```
export OPENAI_API_KEY=<OPENAIKEY>
```
so we can access OpenAI.
Then, in a notebook, run:
```
from manifest import Manifest
manifest = Manifest(
client_name = "openai",
cache_name = "sqlite",
cache_connection = "sqlite.cache"
)
```
This will start a session with OpenAI and save all results to a local file called `sqlite.cache`.
We also support a Redis backend. If you have a Redis database running on port 6379, run
```
manifest = Manifest(
client_name = "openai",
cache_name = "redis",
cache_connection = "localhost:6379"
)
```
As a hint, if you want to get Redis running, see the `docker run` command below under development.
We will explain [below](#huggingface-models) how to use Manifest for a locally hosted HuggingFace model.
Once you have a session open, you can write and develop prompts.
```
prompt = Prompt(lambda x: "Hello, my name is {x}")
result = manifest.run(prompt, "Laurel")
```
You can also run over multiple examples.
```
results = manifest.batch_run(prompt, ["Laurel", "Avanika"])
```
If something doesn't go right, you can also ask to get a raw manifest Response.
```
result_object = manifest.batch_run(prompt, ["Laurel", "Avanika"], return_response=True)
print(result_object.get_request())
print(result_object.is_cached())
print(result_object.get_response())
```
By default, we do not truncate results based on a stop token. You can change this by either passing a new stop token to a Manifest session or to a `run` or `batch_run`. If you set the stop token to `""`, we will not truncate the model output.
```
result = manifest.run(prompt, "Laurel", stop_token="and")
```
If you want to change default parameters to a model, we pass those as `kwargs` to the client.
```
result = manifest.run(prompt, "Laurel", max_tokens=50)
```
# Huggingface Models
To use a HuggingFace generative model, in `manifest/api` we have a Falsk application that hosts the models for you.
In a separate terminal or Tmux/Screen session, run
```
python3 manifest/api/app.py --model_type huggingface --model_name EleutherAI/gpt-j-6B --device 0
```
You will see the Flask session start and output a URL `http://127.0.0.1:5000`. Pass this in to Manifest. If you want to use a different port, set the `FLASK_PORT` environment variable.
```
manifest = Manifest(
client_name = "huggingface",
client_connection = "http://127.0.0.1:5000",
cache_name = "redis",
cache_connection = "localhost:6379"
)
```
**Auto deployment coming soon**
# Development # Development
Before submitting a PR, run Before submitting a PR, run
``` ```
export REDIS_PORT="6379" # or whatever PORT local redis is running for those tests export REDIS_PORT="6380" # or whatever PORT local redis is running for those tests
cd <REDIS_PATH>
docker run -d -p 127.0.0.1:${REDIS_PORT}:6380 -v `pwd`:`pwd` -w `pwd` --name manifest_redis_test redis
make test make test
``` ```
To use our development Redis database, email [Laurel](lorr1@cs.stanford.edu). If you have access to our GCP account, in a separate terminal, run
```
gcloud compute ssh "manifest-connect" --zone "europe-west4-a" --project "hai-gcp-head-models" -- -N -L 6379:10.152.93.107:6379
```
Then if you issue
```
redis-cli ping
```
You should see a `PONG` response from our database.

View File

@ -1,3 +1,4 @@
"""Manifest init.""" """Manifest init."""
from manifest.manifest import Manifest from manifest.manifest import Manifest
from manifest.prompt import Prompt from manifest.prompt import Prompt
from manifest.response import Response

1
manifest/api/__init__.py Normal file
View File

@ -0,0 +1 @@
"""Api init."""

View File

@ -1 +1,93 @@
"""Flask app.""" """Flask app."""
import argparse
import logging
import os
from typing import Dict
import pkg_resources
from flask import Flask, request
from manifest.api.models.huggingface import HuggingFaceModel
from manifest.api.response import OpenAIResponse
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logger = logging.getLogger(__name__)
app = Flask(__name__) # define app using Flask
# Will be global
model = None
PORT = int(os.environ.get("FLASK_PORT", 5000))
MODEL_CONSTRUCTORS = {
"huggingface": HuggingFaceModel,
}
def parse_args() -> argparse.Namespace:
"""Generate args."""
parser = argparse.ArgumentParser(description="Model args")
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type used for finding constructor.",
choices=["huggingface"],
)
parser.add_argument(
"--model_name",
default=None,
type=str,
required=True,
help="Name of model. Used in initialize of model class.",
)
parser.add_argument(
"--cache_dir", default=None, type=str, help="Cache directory for models."
)
parser.add_argument(
"--device", type=int, default=-1, help="Model device. -1 for CPU."
)
args = parser.parse_args()
return args
def main() -> None:
"""Run main."""
kwargs = parse_args()
model_type = kwargs.model_type
model_name = kwargs.model_name
# Global model
global model
model = MODEL_CONSTRUCTORS[model_type](
model_name, cache_dir=kwargs.cache_dir, device=kwargs.device
)
app.run(host="0.0.0.0", port=PORT)
@app.route("/completions", methods=["POST"])
def completions() -> Dict:
"""Get completions for generation."""
prompt = request.json["prompt"]
del request.json["prompt"]
generation_args = request.json
if not isinstance(prompt, str):
raise ValueError("Prompt must be a str")
results = []
for generations in model.generate(prompt, **generation_args):
results.append(generations)
# transform the result into the openai format
return OpenAIResponse(results).__dict__()
@app.route("/")
def index() -> str:
"""Get index completion."""
fn = pkg_resources.resource_filename("metaseq", "service/index.html")
with open(fn) as f:
return f.read()
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
"""Models init."""

View File

@ -0,0 +1,79 @@
"""Huggingface model."""
from typing import Any, List
from transformers import (
AutoTokenizer,
GPT2LMHeadModel,
GPTJForCausalLM,
GPTNeoForCausalLM,
pipeline,
)
from manifest.api.models.model import Model
MODEL_REGISTRY = {
"EleutherAI/gpt-j-6B": GPTJForCausalLM,
"EleutherAI/gpt-neo-125M": GPTNeoForCausalLM,
"EleutherAI/gpt-neo-1.3B": GPTNeoForCausalLM,
"EleutherAI/gpt-neo-2.7B": GPTNeoForCausalLM,
"gpt2": GPT2LMHeadModel,
}
class HuggingFaceModel(Model):
"""Huggingface model."""
def __init__(self, model_name: str, cache_dir: str, device: int):
"""
Initialize model.
All arguments will be passed in the request from Manifest.
Args:
model_name: model name string.
"""
model = MODEL_REGISTRY[model_name].from_pretrained(
model_name, cache_dir=cache_dir
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
self.pipeline = pipeline(
"text-generation", model=model, tokenizer=tokenizer, device=device
)
def generate(self, prompt: str, **kwargs: Any) -> List[str]:
"""
Generate the prompt from model.
Outputs must be generated text, not including prompt.
Args:
prompt: promt to generate from.
Returns:
list of generated text (list of length 1 for 1 generation).
"""
num_return = kwargs.get("n")
final_results = []
encoded_prompt = self.pipeline.tokenizer.encode(
prompt, add_special_tokens=False
)
result = self.pipeline(
prompt,
max_length=kwargs.get("max_tokens") + len(encoded_prompt),
temperature=kwargs.get("temperature"),
repetition_penalty=kwargs.get("repetition_penalty"),
top_k=kwargs.get("top_k"),
top_p=kwargs.get("top_p"),
num_return_sequences=num_return,
)
# Removes tokens removed from tokenization
decoded_prompt = self.pipeline.tokenizer.decode(
encoded_prompt, clean_up_tokenization_spaces=True
)
if num_return == 1:
final_results.append(result[0]["generated_text"][len(decoded_prompt) :])
else:
final_results.append(
[r["generated_text"][len(decoded_prompt) :] for r in result]
)
return final_results

View File

@ -1 +0,0 @@
"""Huggingface model."""

View File

@ -1 +1,34 @@
"""Model class.""" """Model class."""
from abc import ABC, abstractmethod
from typing import Any, List
class Model(ABC):
"""Model class."""
@abstractmethod
def __init__(self, model_name: str, **kwargs: Any):
"""
Initialize model.
kwargs are passed to model as default parameters.
Args:
model_name: model name string.
"""
raise NotImplementedError()
@abstractmethod
def generate(self, prompt: str, **kwargs: Any) -> List[str]:
"""
Generate the prompt from model.
Outputs must be generated text, not including prompt.
Args:
prompt: promt to generate from.
Returns:
list of generated text (list of length 1 for 1 generation).
"""
raise NotImplementedError()

38
manifest/api/response.py Normal file
View File

@ -0,0 +1,38 @@
"""OpenAI response."""
import time
import uuid
from typing import Any, Dict
class OpenAIResponse:
"""OpenAI response."""
def __init__(self, results: list) -> None:
"""Initialize response."""
self.results = results
self.response_id = str(uuid.uuid4())
self.created = int(time.time())
def __dict__(self) -> Dict[str, Any]: # type: ignore
"""Return dictionary representation of response."""
return {
"id": self.response_id,
"object": "text_completion",
"created": self.created,
"model": "flask_model",
"choices": [
{
"text": result,
# TODO: Add in more metadata for HF models
# "logprobs": {
# "tokens": result["tokens"],
# "token_logprobs": result["token_scores"],
# "text_offset": result["text_offset"],
# "top_logprobs": result["top_logprobs"],
# "finish_reason": "length",
# },
}
for result in self.results
],
}

View File

@ -1,9 +1,9 @@
"""Cache for queries and responses.""" """Cache for queries and responses."""
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Tuple, Union from typing import Any, Callable, Dict, Union
from manifest.clients.response import Response from manifest.response import Response
def request_to_key(request: Dict) -> str: def request_to_key(request: Dict) -> str:
@ -32,6 +32,32 @@ def key_to_request(key: str) -> Dict:
return json.loads(key) return json.loads(key)
def response_to_key(response: Dict) -> str:
"""
Normalize a response into a key.
Args:
response: response to normalize.
Returns:
normalized key.
"""
return json.dumps(response, sort_keys=True)
def key_to_response(key: str) -> Dict:
"""
Convert the normalized version to the response.
Args:
key: normalized key to convert.
Returns:
unnormalized response dict.
"""
return json.loads(key)
class Cache(ABC): class Cache(ABC):
"""A cache for request/response pairs.""" """A cache for request/response pairs."""
@ -97,17 +123,17 @@ class Cache(ABC):
raise NotImplementedError() raise NotImplementedError()
def get( def get(
self, request: Dict, overwrite_cache: bool, compute: Callable[[], Response] self, request: Dict, overwrite_cache: bool, compute: Callable[[], Dict]
) -> Tuple[Response, bool]: ) -> Response:
"""Get the result of request (by calling compute as needed).""" """Get the result of request (by calling compute as needed)."""
key = request_to_key(request) key = request_to_key(request)
cached_response = self.get_key(key) cached_response = self.get_key(key)
if cached_response and not overwrite_cache: if cached_response and not overwrite_cache:
cached = True cached = True
response = Response.deserialize(cached_response) response = key_to_response(cached_response)
else: else:
# Type Response # Type Response
response = compute() response = compute()
self.set_key(key, response.serialize()) self.set_key(key, response_to_key(response))
cached = False cached = False
return response, cached return Response(response, cached, request)

View File

@ -17,13 +17,17 @@ class RedisCache(Cache):
connection_str: connection string. connection_str: connection string.
""" """
host, port = connection_str.split(":") host, port = connection_str.split(":")
self.redis = redis.Redis(host=host, port=int(port)) self.redis = redis.Redis(host=host, port=int(port), db=0)
return return
def close(self) -> None: def close(self) -> None:
"""Close the client.""" """Close the client."""
self.redis.close() self.redis.close()
def _normalize_table_key(self, key: str, table: str) -> str:
"""Cast key for prompt key."""
return f"{table}:{key}"
def get_key(self, key: str, table: str = "default") -> Union[str, None]: def get_key(self, key: str, table: str = "default") -> Union[str, None]:
""" """
Get the key for a request. Get the key for a request.
@ -32,8 +36,13 @@ class RedisCache(Cache):
Args: Args:
key: key for cache. key: key for cache.
table: table to get key in.
""" """
pass norm_key = self._normalize_table_key(key, table)
if self.redis.exists(norm_key):
return self.redis.get(norm_key).decode("utf-8")
else:
return None
def set_key(self, key: str, value: str, table: str = "default") -> None: def set_key(self, key: str, value: str, table: str = "default") -> None:
""" """
@ -44,8 +53,10 @@ class RedisCache(Cache):
Args: Args:
key: key for cache. key: key for cache.
value: new value for key. value: new value for key.
table: table to set key in.
""" """
self.redis[key] = value self.redis.set(self._normalize_table_key(key, table), value)
self.commit()
def commit(self) -> None: def commit(self) -> None:
"""Commit any results.""" """Commit any results."""

View File

@ -1,6 +1,5 @@
"""SQLite cache.""" """SQLite cache."""
import logging import logging
from pathlib import Path
from typing import Any, Union from typing import Any, Union
from sqlitedict import SqliteDict from sqlitedict import SqliteDict
@ -20,19 +19,18 @@ class SQLiteCache(Cache):
Args: Args:
connection_str: connection string. connection_str: connection string.
""" """
self.cache_dir = connection_str self.cache_file = connection_str
Path(self.cache_dir).mkdir(parents=True, exist_ok=True) self.cache = SqliteDict(self.cache_file, autocommit=True)
# If more than two tables, switch to full on SQL connection
self.query_file = Path(self.cache_dir, "query.sqlite")
self.prompt_file = Path(self.cache_dir, "prompts.sqlite")
self.cache = SqliteDict(self.query_file, autocommit=False)
self.prompt_cache = SqliteDict(self.prompt_file, autocommit=False)
return return
def close(self) -> None: def close(self) -> None:
"""Close the client.""" """Close the client."""
self.cache.close() self.cache.close()
def _normalize_table_key(self, key: str, table: str) -> str:
"""Cast key for prompt key."""
return f"{table}:{key}"
def get_key(self, key: str, table: str = "default") -> Union[str, None]: def get_key(self, key: str, table: str = "default") -> Union[str, None]:
""" """
Get the key for a request. Get the key for a request.
@ -43,14 +41,7 @@ class SQLiteCache(Cache):
key: key for cache. key: key for cache.
table: table to get key in. table: table to get key in.
""" """
if table == "prompt": return self.cache.get(self._normalize_table_key(key, table))
return self.prompt_cache.get(key)
else:
if table != "default":
raise ValueError(
"SQLiteDict only support table of `default` or `prompt`"
)
return self.cache.get(key)
def set_key(self, key: str, value: str, table: str = "default") -> None: def set_key(self, key: str, value: str, table: str = "default") -> None:
""" """
@ -63,17 +54,9 @@ class SQLiteCache(Cache):
value: new value for key. value: new value for key.
table: table to set key in. table: table to set key in.
""" """
if table == "prompt": self.cache[self._normalize_table_key(key, table)] = value
self.prompt_cache[key] = value
else:
if table != "default":
raise ValueError(
"SQLiteDict only support table of `default` or `prompt`"
)
self.cache[key] = value
self.commit() self.commit()
def commit(self) -> None: def commit(self) -> None:
"""Commit any results.""" """Commit any results."""
self.prompt_cache.commit()
self.cache.commit() self.cache.commit()

View File

@ -1,3 +1,2 @@
"""Client init.""" """Client init."""
from manifest.clients.client import Client from manifest.clients.client import Client
from manifest.clients.response import Response

View File

@ -2,8 +2,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Tuple from typing import Any, Callable, Dict, Optional, Tuple
from manifest.clients.response import Response
class Client(ABC): class Client(ABC):
"""Client class.""" """Client class."""
@ -38,9 +36,7 @@ class Client(ABC):
raise NotImplementedError() raise NotImplementedError()
@abstractmethod @abstractmethod
def get_request( def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
self, query: str, **kwargs: Any
) -> Tuple[Callable[[], Response], Dict]:
""" """
Get request function. Get request function.

View File

@ -3,7 +3,6 @@ import logging
from typing import Any, Callable, Dict, Optional, Tuple from typing import Any, Callable, Dict, Optional, Tuple
from manifest.clients import Client from manifest.clients import Client
from manifest.clients.response import Response
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,9 +27,7 @@ class DummyClient(Client):
"""Close the client.""" """Close the client."""
pass pass
def get_request( def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
self, query: str, **kwargs: Any
) -> Tuple[Callable[[], Response], Dict]:
""" """
Get request string function. Get request string function.
@ -46,7 +43,7 @@ class DummyClient(Client):
"num_results": kwargs.get("num_results", self.num_results), "num_results": kwargs.get("num_results", self.num_results),
} }
def _run_completion() -> Response: def _run_completion() -> Dict:
return Response({"choices": [{"text": "hello"}] * self.num_results}) return {"choices": [{"text": "hello"}] * self.num_results}
return _run_completion, request_params return _run_completion, request_params

View File

@ -0,0 +1,67 @@
"""OpenAI client."""
import logging
from typing import Any, Callable, Dict, Optional, Tuple
import requests
from manifest.clients.client import Client
logger = logging.getLogger(__name__)
class HuggingFaceClient(Client):
"""HuggingFace client."""
def connect(
self,
connection_str: Optional[str] = None,
temperature: Optional[float] = 1.0,
max_tokens: Optional[int] = 10,
top_p: Optional[float] = 1.0,
top_k: Optional[int] = 0,
repetition_penalty: Optional[float] = 1.0,
n: Optional[int] = 1,
**kwargs: Any,
) -> None:
"""Connect to the HuggingFace url."""
self.host = connection_str.rstrip("/")
self.temperature = temperature
self.max_tokens = max_tokens
self.top_p = top_p
self.top_k = top_k
self.repetition_penalty = repetition_penalty
self.n = n
def close(self) -> None:
"""Close the client."""
pass
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
Args:
query: query string.
Returns:
request function that takes no input.
request parameters as dict.
"""
request_params = {
"prompt": query,
"temperature": kwargs.get("temperature", self.temperature),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"top_p": kwargs.get("top_p", self.top_p),
"top_k": kwargs.get("top_k", self.top_k),
"repetition_penalty": kwargs.get(
"repetition_penalty", self.repetition_penalty
),
"n": kwargs.get("n", self.n),
}
def _run_completion() -> Dict:
post_str = self.host + "/completions"
res = requests.post(post_str, json=request_params)
return res.json()
return _run_completion, request_params

View File

@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, Optional, Tuple
import openai import openai
from manifest.clients import Response
from manifest.clients.client import Client from manifest.clients.client import Client
logging.getLogger("openai").setLevel(logging.WARNING) logging.getLogger("openai").setLevel(logging.WARNING)
@ -28,7 +27,7 @@ class OpenAIClient(Client):
engine: Optional[str] = "text-ada-001", engine: Optional[str] = "text-ada-001",
temperature: Optional[float] = 0.0, temperature: Optional[float] = 0.0,
max_tokens: Optional[int] = 10, max_tokens: Optional[int] = 10,
top_p: Optional[int] = 1, top_p: Optional[float] = 1.0,
frequency_penalty: Optional[int] = 0, frequency_penalty: Optional[int] = 0,
presence_penalty: Optional[int] = 0, presence_penalty: Optional[int] = 0,
n: Optional[int] = 1, n: Optional[int] = 1,
@ -59,9 +58,7 @@ class OpenAIClient(Client):
"""Close the client.""" """Close the client."""
pass pass
def get_request( def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
self, query: str, **kwargs: Any
) -> Tuple[Callable[[], Response], Dict]:
""" """
Get request string function. Get request string function.
@ -85,9 +82,9 @@ class OpenAIClient(Client):
"n": kwargs.get("n", self.n), "n": kwargs.get("n", self.n),
} }
def _run_completion() -> Response: def _run_completion() -> Dict:
try: try:
return Response(openai.Completion.create(**request_params)) return openai.Completion.create(**request_params)
except openai.error.OpenAIError as e: except openai.error.OpenAIError as e:
logger.error(e) logger.error(e)
raise e raise e

View File

@ -1,70 +0,0 @@
"""Client response."""
import json
from typing import Dict, List, Union
class Response:
"""Response class."""
def __init__(self, response: Union[str, Dict]):
"""Initialize response."""
if isinstance(response, str):
self.response = json.loads(response)
elif isinstance(response, dict):
self.response = response
else:
raise ValueError("Response must be str or dict")
if ("choices" not in self.response) or (
not isinstance(self.response["choices"], list)
):
raise ValueError(
"Response must be serialized to a dict with a list of choices"
)
if len(self.response["choices"]) > 0:
if "text" not in self.response["choices"][0]:
raise ValueError(
"Response must be serialized to a dict with a "
"list of choices with text field"
)
def __getitem__(self, key: str) -> str:
"""
Return the response given the key.
Args:
key: key to get.
Returns:
value of key.
"""
return self.response[key]
def get_results(self) -> Union[str, List[str]]:
"""Get all text results from response."""
if len(self.response["choices"]) == 0:
return None
if len(self.response["choices"]) == 1:
return self.response["choices"][0]["text"]
return [choice["text"] for choice in self.response["choices"]]
def serialize(self) -> str:
"""
Serialize response to string.
Returns:
serialized response.
"""
return json.dumps(self.response, sort_keys=True)
@classmethod
def deserialize(cls, value: str) -> "Response":
"""
Deserialize string to response.
Args:
value: serialized response.
Returns:
serialized response.
"""
return Response(value)

View File

@ -2,18 +2,21 @@
import logging import logging
from typing import Any, Iterable, List, Optional, Union from typing import Any, Iterable, List, Optional, Union
from manifest.response import Response
logging.getLogger("openai").setLevel(logging.WARNING) logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from manifest.caches.redis import RedisCache from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache from manifest.caches.sqlite import SQLiteCache
from manifest.clients.dummy import DummyClient from manifest.clients.dummy import DummyClient
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.openai import OpenAIClient from manifest.clients.openai import OpenAIClient
from manifest.prompt import Prompt from manifest.prompt import Prompt
CLIENT_CONSTRUCTORS = { CLIENT_CONSTRUCTORS = {
"openai": OpenAIClient, "openai": OpenAIClient,
# "huggingface": manifest.clients.huggingface.HuggingFaceClient, "huggingface": HuggingFaceClient,
"dummy": DummyClient, "dummy": DummyClient,
} }
@ -32,11 +35,20 @@ class Manifest:
client_connection: Optional[str] = None, client_connection: Optional[str] = None,
cache_name: str = "redis", cache_name: str = "redis",
cache_connection: str = "localhost:6379", cache_connection: str = "localhost:6379",
stop_token: str = "",
**kwargs: Any, **kwargs: Any,
): ):
""" """
Initialize manifest. Initialize manifest.
Args:
client_name: name of client.
client_connection: connection string for client.
cache_name: name of cache.
cache_connection: connection string for cache.
stop_token: stop token prompt generation.
Can be overridden in run
Remaining kwargs sent to client and cache. Remaining kwargs sent to client and cache.
""" """
if client_name not in CLIENT_CONSTRUCTORS: if client_name not in CLIENT_CONSTRUCTORS:
@ -50,8 +62,11 @@ class Manifest:
f"Choices are {list(CACHE_CONSTRUCTORS.keys())}" f"Choices are {list(CACHE_CONSTRUCTORS.keys())}"
) )
self.client_name = client_name self.client_name = client_name
self.client = CLIENT_CONSTRUCTORS[client_name](client_connection, **kwargs) self.client = CLIENT_CONSTRUCTORS[client_name]( # type: ignore
client_connection, **kwargs
)
self.cache = CACHE_CONSTRUCTORS[cache_name](cache_connection, **kwargs) self.cache = CACHE_CONSTRUCTORS[cache_name](cache_connection, **kwargs)
self.stop_token = stop_token
def close(self) -> None: def close(self) -> None:
"""Close the client and cache.""" """Close the client and cache."""
@ -63,8 +78,10 @@ class Manifest:
prompt: Prompt, prompt: Prompt,
input: Optional[Any] = None, input: Optional[Any] = None,
overwrite_cache: bool = False, overwrite_cache: bool = False,
stop_token: Optional[str] = None,
return_response: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Union[str, List[str]]: ) -> Union[str, List[str], Response]:
""" """
Run the prompt. Run the prompt.
@ -72,10 +89,14 @@ class Manifest:
prompt: prompt to run. prompt: prompt to run.
input: input to prompt. input: input to prompt.
overwrite_cache: whether to overwrite cache. overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
Returns: Returns:
response from prompt. response from prompt.
""" """
stop_token = stop_token if stop_token is not None else self.stop_token
prompt_str = prompt(input) prompt_str = prompt(input)
possible_request, full_kwargs = self.client.get_request(prompt_str, **kwargs) possible_request, full_kwargs = self.client.get_request(prompt_str, **kwargs)
# Create cacke key # Create cacke key
@ -84,16 +105,22 @@ class Manifest:
cache_key["client_name"] = self.client_name cache_key["client_name"] = self.client_name
# Make query prompt dependent # Make query prompt dependent
cache_key["prompt"] = prompt_str cache_key["prompt"] = prompt_str
response, _ = self.cache.get(cache_key, overwrite_cache, possible_request) response_obj = self.cache.get(cache_key, overwrite_cache, possible_request)
return response.get_results() # Extract text results
if return_response:
return response_obj
else:
return response_obj.get_response(stop_token)
def run_batch( def run_batch(
self, self,
prompt: Prompt, prompt: Prompt,
input: Optional[Iterable[Any]] = None, input: Optional[Iterable[Any]] = None,
overwrite_cache: bool = False, overwrite_cache: bool = False,
stop_token: Optional[str] = None,
return_response: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Iterable[Union[str, List[str]]]: ) -> Iterable[Union[str, List[str], Response]]:
""" """
Run the prompt on a batch of inputs. Run the prompt on a batch of inputs.
@ -101,13 +128,21 @@ class Manifest:
prompt: prompt to run. prompt: prompt to run.
input: batch of inputs. input: batch of inputs.
overwrite_cache: whether to overwrite cache. overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
Returns: Returns:
batch of responses. batch of responses.
""" """
if input is None: if input is None:
input = [None] input = [None]
return [self.run(prompt, inp, overwrite_cache, **kwargs) for inp in input] return [
self.run(
prompt, inp, overwrite_cache, stop_token, return_response, **kwargs
)
for inp in input
]
def save_prompt(self, name: str, prompt: Prompt) -> None: def save_prompt(self, name: str, prompt: Prompt) -> None:
""" """

106
manifest/response.py Normal file
View File

@ -0,0 +1,106 @@
"""Client response."""
import json
from typing import Dict, List, Union
class Response:
"""Response class."""
def __init__(self, response: Dict, cached: bool, request_params: Dict):
"""Initialize response."""
if isinstance(response, dict):
self._response = response
else:
raise ValueError("Response must be str or dict")
if ("choices" not in self._response) or (
not isinstance(self._response["choices"], list)
):
raise ValueError(
"Response must be serialized to a dict with a list of choices"
)
if len(self._response["choices"]) > 0:
if "text" not in self._response["choices"][0]:
raise ValueError(
"Response must be serialized to a dict with a "
"list of choices with text field"
)
self._cached = cached
self._request_params = request_params
def is_cached(self) -> bool:
"""Check if response is cached."""
return self._cached
def get_request(self) -> Dict:
"""Get request parameters."""
return self._request_params
def get_raw_response(self) -> Dict:
"""Get response dict without parsing."""
return self._response
def get_response(self, stop_token: str = "") -> Union[str, List[str]]:
"""
Get all text results from response.
Args:
stop_token: stop token for string generation
"""
process_result = (
lambda x: x.strip().split(stop_token)[0] if stop_token else x.strip()
)
if len(self._response["choices"]) == 0:
return None
if len(self._response["choices"]) == 1:
return process_result(self._response["choices"][0]["text"])
return [process_result(choice["text"]) for choice in self._response["choices"]]
def serialize(self) -> str:
"""
Serialize response to string.
Returns:
serialized response.
"""
to_serialize = {
"response": self._response,
"cached": self._cached,
"request_params": self._request_params,
}
return json.dumps(to_serialize, sort_keys=True)
@classmethod
def deserialize(cls, value: str) -> "Response":
"""
Deserialize string to response.
Args:
value: serialized response.
Returns:
serialized response.
"""
deserialized = json.loads(value)
return cls(
deserialized["response"],
deserialized["cached"],
deserialized["request_params"],
)
def __str__(self) -> str:
"""
Get string representation of response.
Returns:
string representation of response.
"""
return self.serialize()
def __repr__(self) -> str:
"""
Get string representation of response.
Returns:
string representation of response.
"""
return str(self)

356
poetry.lock generated
View File

@ -100,7 +100,7 @@ unicode_backport = ["unicodedata2"]
name = "click" name = "click"
version = "8.1.3" version = "8.1.3"
description = "Composable command line interface toolkit" description = "Composable command line interface toolkit"
category = "dev" category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
@ -193,7 +193,7 @@ python-versions = ">=3.6"
name = "filelock" name = "filelock"
version = "3.7.0" version = "3.7.0"
description = "A platform independent file lock." description = "A platform independent file lock."
category = "dev" category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
@ -237,6 +237,50 @@ python-versions = "*"
[package.dependencies] [package.dependencies]
flake8 = "*" flake8 = "*"
[[package]]
name = "flask"
version = "2.1.2"
description = "A simple framework for building complex web applications."
category = "main"
optional = false
python-versions = ">=3.7"
[package.dependencies]
click = ">=8.0"
importlib-metadata = {version = ">=3.6.0", markers = "python_version < \"3.10\""}
itsdangerous = ">=2.0"
Jinja2 = ">=3.0"
Werkzeug = ">=2.0"
[package.extras]
async = ["asgiref (>=3.2)"]
dotenv = ["python-dotenv"]
[[package]]
name = "huggingface-hub"
version = "0.7.0"
description = "Client library to download and publish models on the huggingface.co hub"
category = "main"
optional = false
python-versions = ">=3.7.0"
[package.dependencies]
filelock = "*"
packaging = ">=20.9"
pyyaml = ">=5.1"
requests = "*"
tqdm = "*"
typing-extensions = ">=3.7.4.3"
[package.extras]
all = ["pytest", "datasets", "soundfile", "black (>=22.0,<23.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"]
dev = ["pytest", "datasets", "soundfile", "black (>=22.0,<23.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"]
fastai = ["toml", "fastai (>=2.4)", "fastcore (>=1.3.27)"]
quality = ["black (>=22.0,<23.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"]
tensorflow = ["tensorflow", "pydot", "graphviz"]
testing = ["pytest", "datasets", "soundfile"]
torch = ["torch"]
[[package]] [[package]]
name = "identify" name = "identify"
version = "2.5.1" version = "2.5.1"
@ -264,6 +308,22 @@ category = "dev"
optional = false optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[[package]]
name = "importlib-metadata"
version = "4.11.4"
description = "Read metadata from Python packages"
category = "main"
optional = false
python-versions = ">=3.7"
[package.dependencies]
zipp = ">=0.5"
[package.extras]
docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)"]
perf = ["ipython"]
testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "importlib-resources (>=1.3)"]
[[package]] [[package]]
name = "iniconfig" name = "iniconfig"
version = "1.1.1" version = "1.1.1"
@ -286,11 +346,19 @@ requirements_deprecated_finder = ["pipreqs", "pip-api"]
colors = ["colorama (>=0.4.3,<0.5.0)"] colors = ["colorama (>=0.4.3,<0.5.0)"]
plugins = ["setuptools"] plugins = ["setuptools"]
[[package]]
name = "itsdangerous"
version = "2.1.2"
description = "Safely pass data to untrusted environments and back."
category = "main"
optional = false
python-versions = ">=3.7"
[[package]] [[package]]
name = "jinja2" name = "jinja2"
version = "3.1.2" version = "3.1.2"
description = "A very fast and expressive template engine." description = "A very fast and expressive template engine."
category = "dev" category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
@ -304,7 +372,7 @@ i18n = ["Babel (>=2.7)"]
name = "markupsafe" name = "markupsafe"
version = "2.1.1" version = "2.1.1"
description = "Safely add untrusted strings to HTML/XML markup." description = "Safely add untrusted strings to HTML/XML markup."
category = "dev" category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
@ -614,7 +682,7 @@ python-versions = "*"
name = "pyyaml" name = "pyyaml"
version = "6.0" version = "6.0"
description = "YAML parser and emitter for Python" description = "YAML parser and emitter for Python"
category = "dev" category = "main"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
@ -648,6 +716,14 @@ packaging = ">=20.4"
hiredis = ["hiredis (>=1.0.0)"] hiredis = ["hiredis (>=1.0.0)"]
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
[[package]]
name = "regex"
version = "2022.4.24"
description = "Alternative regular expression module, to replace re."
category = "main"
optional = false
python-versions = ">=3.6"
[[package]] [[package]]
name = "requests" name = "requests"
version = "2.27.1" version = "2.27.1"
@ -792,6 +868,18 @@ category = "main"
optional = false optional = false
python-versions = "*" python-versions = "*"
[[package]]
name = "tokenizers"
version = "0.12.1"
description = "Fast and Customizable Tokenizers"
category = "main"
optional = false
python-versions = "*"
[package.extras]
docs = ["sphinx", "sphinx-rtd-theme", "setuptools-rust"]
testing = ["pytest", "requests", "numpy", "datasets"]
[[package]] [[package]]
name = "toml" name = "toml"
version = "0.10.2" version = "0.10.2"
@ -808,6 +896,17 @@ category = "dev"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
[[package]]
name = "torch"
version = "1.11.0"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
category = "main"
optional = false
python-versions = ">=3.7.0"
[package.dependencies]
typing-extensions = "*"
[[package]] [[package]]
name = "tqdm" name = "tqdm"
version = "4.64.0" version = "4.64.0"
@ -825,11 +924,71 @@ notebook = ["ipywidgets (>=6)"]
slack = ["slack-sdk"] slack = ["slack-sdk"]
telegram = ["requests"] telegram = ["requests"]
[[package]]
name = "transformers"
version = "4.19.2"
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
category = "main"
optional = false
python-versions = ">=3.7.0"
[package.dependencies]
filelock = "*"
huggingface-hub = ">=0.1.0,<1.0"
numpy = ">=1.17"
packaging = ">=20.0"
pyyaml = ">=5.1"
regex = "!=2019.12.17"
requests = "*"
tokenizers = ">=0.11.1,<0.11.3 || >0.11.3,<0.13"
tqdm = ">=4.27"
[package.extras]
all = ["tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx", "torch (>=1.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.3.5)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer", "pillow", "optuna", "ray", "sigopt", "timm", "codecarbon (==1.2.0)"]
audio = ["librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
codecarbon = ["codecarbon (==1.2.0)"]
deepspeed = ["deepspeed (>=0.6.4)"]
deepspeed-testing = ["deepspeed (>=0.6.4)", "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black (>=22.0,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)", "optuna"]
dev = ["tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx", "torch (>=1.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.3.5)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer", "pillow", "optuna", "ray", "sigopt", "timm", "codecarbon (==1.2.0)", "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black (>=22.0,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "hf-doc-builder", "scikit-learn"]
dev-tensorflow = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black (>=22.0,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)", "tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "pillow", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "hf-doc-builder", "scikit-learn", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
dev-torch = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black (>=22.0,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)", "torch (>=1.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer", "pillow", "optuna", "ray", "sigopt", "timm", "codecarbon (==1.2.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "hf-doc-builder", "scikit-learn", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
docs = ["tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx", "torch (>=1.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.3.5)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer", "pillow", "optuna", "ray", "sigopt", "timm", "codecarbon (==1.2.0)", "hf-doc-builder"]
docs_specific = ["hf-doc-builder"]
fairscale = ["fairscale (>0.3)"]
flax = ["jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.3.5)", "optax (>=0.0.8)"]
flax-speech = ["librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
ftfy = ["ftfy"]
integrations = ["optuna", "ray", "sigopt"]
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)"]
modelcreation = ["cookiecutter (==1.7.3)"]
onnx = ["onnxconverter-common", "tf2onnx", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
optuna = ["optuna"]
quality = ["black (>=22.0,<23.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)"]
ray = ["ray"]
retrieval = ["faiss-cpu", "datasets"]
sagemaker = ["sagemaker (>=2.31.0)"]
sentencepiece = ["sentencepiece (>=0.1.91,!=0.1.92)", "protobuf"]
serving = ["pydantic", "uvicorn", "fastapi", "starlette"]
sigopt = ["sigopt"]
sklearn = ["scikit-learn"]
speech = ["torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
testing = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black (>=22.0,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)"]
tf = ["tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx"]
tf-cpu = ["tensorflow-cpu (>=2.3)", "onnxconverter-common", "tf2onnx"]
tf-speech = ["librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
timm = ["timm"]
tokenizers = ["tokenizers (>=0.11.1,!=0.11.3,<0.13)"]
torch = ["torch (>=1.0)"]
torch-speech = ["torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
torchhub = ["filelock", "huggingface-hub (>=0.1.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.0)", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "tqdm (>=4.27)"]
vision = ["pillow"]
[[package]] [[package]]
name = "typing-extensions" name = "typing-extensions"
version = "4.2.0" version = "4.2.0"
description = "Backported and Experimental Type Hints for Python 3.7+" description = "Backported and Experimental Type Hints for Python 3.7+"
category = "dev" category = "main"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
@ -864,6 +1023,17 @@ six = ">=1.9.0,<2"
docs = ["proselint (>=0.10.2)", "sphinx (>=3)", "sphinx-argparse (>=0.2.5)", "sphinx-rtd-theme (>=0.4.3)", "towncrier (>=21.3)"] docs = ["proselint (>=0.10.2)", "sphinx (>=3)", "sphinx-argparse (>=0.2.5)", "sphinx-rtd-theme (>=0.4.3)", "towncrier (>=21.3)"]
testing = ["coverage (>=4)", "coverage-enable-subprocess (>=1)", "flaky (>=3)", "pytest (>=4)", "pytest-env (>=0.6.2)", "pytest-freezegun (>=0.4.1)", "pytest-mock (>=2)", "pytest-randomly (>=1)", "pytest-timeout (>=1)", "packaging (>=20.0)"] testing = ["coverage (>=4)", "coverage-enable-subprocess (>=1)", "flaky (>=3)", "pytest (>=4)", "pytest-env (>=0.6.2)", "pytest-freezegun (>=0.4.1)", "pytest-mock (>=2)", "pytest-randomly (>=1)", "pytest-timeout (>=1)", "packaging (>=20.0)"]
[[package]]
name = "werkzeug"
version = "2.1.2"
description = "The comprehensive WSGI web application library."
category = "main"
optional = false
python-versions = ">=3.7"
[package.extras]
watchdog = ["watchdog"]
[[package]] [[package]]
name = "wrapt" name = "wrapt"
version = "1.14.1" version = "1.14.1"
@ -872,10 +1042,22 @@ category = "main"
optional = false optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
[[package]]
name = "zipp"
version = "3.8.0"
description = "Backport of pathlib-compatible object wrapper for zip files"
category = "main"
optional = false
python-versions = ">=3.7"
[package.extras]
docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)"]
testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)"]
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.8" python-versions = "^3.8"
content-hash = "cfa01a852b186ab6ddd1b610dba12c7d4ac6f10c7bdb10ca7bac6a0d25568d00" content-hash = "b0f92eed3a0f80cc9976f68098b772a19d19b4742426299bd63cb1c636d8a84a"
[metadata.files] [metadata.files]
alabaster = [ alabaster = [
@ -1026,6 +1208,14 @@ flake8-polyfill = [
{file = "flake8-polyfill-1.0.2.tar.gz", hash = "sha256:e44b087597f6da52ec6393a709e7108b2905317d0c0b744cdca6208e670d8eda"}, {file = "flake8-polyfill-1.0.2.tar.gz", hash = "sha256:e44b087597f6da52ec6393a709e7108b2905317d0c0b744cdca6208e670d8eda"},
{file = "flake8_polyfill-1.0.2-py2.py3-none-any.whl", hash = "sha256:12be6a34ee3ab795b19ca73505e7b55826d5f6ad7230d31b18e106400169b9e9"}, {file = "flake8_polyfill-1.0.2-py2.py3-none-any.whl", hash = "sha256:12be6a34ee3ab795b19ca73505e7b55826d5f6ad7230d31b18e106400169b9e9"},
] ]
flask = [
{file = "Flask-2.1.2-py3-none-any.whl", hash = "sha256:fad5b446feb0d6db6aec0c3184d16a8c1f6c3e464b511649c8918a9be100b4fe"},
{file = "Flask-2.1.2.tar.gz", hash = "sha256:315ded2ddf8a6281567edb27393010fe3406188bafbfe65a3339d5787d89e477"},
]
huggingface-hub = [
{file = "huggingface_hub-0.7.0-py3-none-any.whl", hash = "sha256:fd448fd0b738d803411c79bdf9f12f0ba171fecd24a59edf88c1391b473bc2c0"},
{file = "huggingface_hub-0.7.0.tar.gz", hash = "sha256:8154dc2fad84b32a4bca18372a647d9381ed8550a80b11050758357b8fcea639"},
]
identify = [ identify = [
{file = "identify-2.5.1-py2.py3-none-any.whl", hash = "sha256:0dca2ea3e4381c435ef9c33ba100a78a9b40c0bab11189c7cf121f75815efeaa"}, {file = "identify-2.5.1-py2.py3-none-any.whl", hash = "sha256:0dca2ea3e4381c435ef9c33ba100a78a9b40c0bab11189c7cf121f75815efeaa"},
{file = "identify-2.5.1.tar.gz", hash = "sha256:3d11b16f3fe19f52039fb7e39c9c884b21cb1b586988114fbe42671f03de3e82"}, {file = "identify-2.5.1.tar.gz", hash = "sha256:3d11b16f3fe19f52039fb7e39c9c884b21cb1b586988114fbe42671f03de3e82"},
@ -1038,6 +1228,10 @@ imagesize = [
{file = "imagesize-1.3.0-py2.py3-none-any.whl", hash = "sha256:1db2f82529e53c3e929e8926a1fa9235aa82d0bd0c580359c67ec31b2fddaa8c"}, {file = "imagesize-1.3.0-py2.py3-none-any.whl", hash = "sha256:1db2f82529e53c3e929e8926a1fa9235aa82d0bd0c580359c67ec31b2fddaa8c"},
{file = "imagesize-1.3.0.tar.gz", hash = "sha256:cd1750d452385ca327479d45b64d9c7729ecf0b3969a58148298c77092261f9d"}, {file = "imagesize-1.3.0.tar.gz", hash = "sha256:cd1750d452385ca327479d45b64d9c7729ecf0b3969a58148298c77092261f9d"},
] ]
importlib-metadata = [
{file = "importlib_metadata-4.11.4-py3-none-any.whl", hash = "sha256:c58c8eb8a762858f49e18436ff552e83914778e50e9d2f1660535ffb364552ec"},
{file = "importlib_metadata-4.11.4.tar.gz", hash = "sha256:5d26852efe48c0a32b0509ffbc583fda1a2266545a78d104a6f4aff3db17d700"},
]
iniconfig = [ iniconfig = [
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
{file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
@ -1046,6 +1240,10 @@ isort = [
{file = "isort-5.10.1-py3-none-any.whl", hash = "sha256:6f62d78e2f89b4500b080fe3a81690850cd254227f27f75c3a0c491a1f351ba7"}, {file = "isort-5.10.1-py3-none-any.whl", hash = "sha256:6f62d78e2f89b4500b080fe3a81690850cd254227f27f75c3a0c491a1f351ba7"},
{file = "isort-5.10.1.tar.gz", hash = "sha256:e8443a5e7a020e9d7f97f1d7d9cd17c88bcb3bc7e218bf9cf5095fe550be2951"}, {file = "isort-5.10.1.tar.gz", hash = "sha256:e8443a5e7a020e9d7f97f1d7d9cd17c88bcb3bc7e218bf9cf5095fe550be2951"},
] ]
itsdangerous = [
{file = "itsdangerous-2.1.2-py3-none-any.whl", hash = "sha256:2c2349112351b88699d8d4b6b075022c0808887cb7ad10069318a8b0bc88db44"},
{file = "itsdangerous-2.1.2.tar.gz", hash = "sha256:5dbbc68b317e5e42f327f9021763545dc3fc3bfe22e6deb96aaf1fc38874156a"},
]
jinja2 = [ jinja2 = [
{file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"}, {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"},
{file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"}, {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"},
@ -1298,6 +1496,82 @@ redis = [
{file = "redis-4.3.1-py3-none-any.whl", hash = "sha256:84316970995a7adb907a56754d2b92d88fc2d252963dc5ac34c88f0f1a22c25d"}, {file = "redis-4.3.1-py3-none-any.whl", hash = "sha256:84316970995a7adb907a56754d2b92d88fc2d252963dc5ac34c88f0f1a22c25d"},
{file = "redis-4.3.1.tar.gz", hash = "sha256:94b617b4cd296e94991146f66fc5559756fbefe9493604f0312e4d3298ac63e9"}, {file = "redis-4.3.1.tar.gz", hash = "sha256:94b617b4cd296e94991146f66fc5559756fbefe9493604f0312e4d3298ac63e9"},
] ]
regex = [
{file = "regex-2022.4.24-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f86aef546add4ff1202e1f31e9bb54f9268f17d996b2428877283146bf9bc013"},
{file = "regex-2022.4.24-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e944268445b5694f5d41292c9228f0ca46d5a32a67f195d5f8547c1f1d91f4bc"},
{file = "regex-2022.4.24-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8da3145f4b72f7ce6181c804eaa44cdcea313c8998cdade3d9e20a8717a9cb"},
{file = "regex-2022.4.24-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0fd464e547dbabf4652ca5fe9d88d75ec30182981e737c07b3410235a44b9939"},
{file = "regex-2022.4.24-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:071bcb625e890f28b7c4573124a6512ea65107152b1d3ca101ce33a52dad4593"},
{file = "regex-2022.4.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c2de7f32fa87d04d40f54bce3843af430697aba51c3a114aa62837a0772f219"},
{file = "regex-2022.4.24-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a07e8366115069f26822c47732122ab61598830a69f5629a37ea8881487c107"},
{file = "regex-2022.4.24-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:036d1c1fbe69eba3ee253c107e71749cdbb4776db93d674bc0d5e28f30300734"},
{file = "regex-2022.4.24-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:af1e687ffab18a75409e5e5d6215b6ccd41a5a1a0ea6ce9665e01253f737a0d3"},
{file = "regex-2022.4.24-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:165cc75cfa5aa0f12adb2ac6286330e7229a06dc0e6c004ec35da682b5b89579"},
{file = "regex-2022.4.24-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:3e35c50b27f36176c792738cb9b858523053bc495044d2c2b44db24376b266f1"},
{file = "regex-2022.4.24-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:43ee0df35925ae4b0cc6ee3f60b73369e559dd2ac40945044da9394dd9d3a51d"},
{file = "regex-2022.4.24-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:58521abdab76583bd41ef47e5e2ddd93b32501aee4ee8cee71dee10a45ba46b1"},
{file = "regex-2022.4.24-cp310-cp310-win32.whl", hash = "sha256:275afc7352982ee947fc88f67a034b52c78395977b5fc7c9be15f7dc95b76f06"},
{file = "regex-2022.4.24-cp310-cp310-win_amd64.whl", hash = "sha256:253f858a0255cd91a0424a4b15c2eedb12f20274f85731b0d861c8137e843065"},
{file = "regex-2022.4.24-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:85b7ee4d0c7a46296d884f6b489af8b960c4291d76aea4b22fd4fbe05e6ec08e"},
{file = "regex-2022.4.24-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e0da7ef160d4f3eb3d4d3e39a02c3c42f7dbcfce62c81f784cc99fc7059765f"},
{file = "regex-2022.4.24-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4f2e2cef324ca9355049ee1e712f68e2e92716eba24275e6767b9bfa15f1f478"},
{file = "regex-2022.4.24-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6165e737acb3bea3271372e8aa5ebe7226c8a8e8da1b94af2d6547c5a09d689d"},
{file = "regex-2022.4.24-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f6bd8178cce5bb56336722d5569d19c50bba5915a69a2050c497fb921e7cb0f"},
{file = "regex-2022.4.24-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:45b761406777a681db0c24686178532134c937d24448d9e085279b69e9eb7da4"},
{file = "regex-2022.4.24-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3dfbadb7b74d95f72f9f9dbf9778f7de92722ab520a109ceaf7927461fa85b10"},
{file = "regex-2022.4.24-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:9913bcf730eb6e9b441fb176832eea9acbebab6035542c7c89d90c803f5cd3be"},
{file = "regex-2022.4.24-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:68aed3fb0c61296bd6d234f558f78c51671f79ccb069cbcd428c2eea6fee7a5b"},
{file = "regex-2022.4.24-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:8e7d33f93cdd01868327d834d0f5bb029241cd293b47d51b96814dec27fc9b4b"},
{file = "regex-2022.4.24-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:82b7fc67e49fdce671bdbec1127189fc979badf062ce6e79dc95ef5e07a8bf92"},
{file = "regex-2022.4.24-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:c36906a7855ec33a9083608e6cd595e4729dab18aeb9aad0dd0b039240266239"},
{file = "regex-2022.4.24-cp36-cp36m-win32.whl", hash = "sha256:b2df3ede85d778c949d9bd2a50237072cee3df0a423c91f5514f78f8035bde87"},
{file = "regex-2022.4.24-cp36-cp36m-win_amd64.whl", hash = "sha256:dffd9114ade73137ab2b79a8faf864683dbd2dbbb6b23a305fbbd4cbaeeb2187"},
{file = "regex-2022.4.24-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6a0ef57cccd8089b4249eebad95065390e56c04d4a92c51316eab4131bca96a9"},
{file = "regex-2022.4.24-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12af15b6edb00e425f713160cfd361126e624ec0de86e74f7cad4b97b7f169b3"},
{file = "regex-2022.4.24-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7f271d0831d8ebc56e17b37f9fa1824b0379221d1238ae77c18a6e8c47f1fdce"},
{file = "regex-2022.4.24-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:37903d5ca11fa47577e8952d2e2c6de28553b11c70defee827afb941ab2c6729"},
{file = "regex-2022.4.24-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b747cef8e5dcdaf394192d43a0c02f5825aeb0ecd3d43e63ae500332ab830b0"},
{file = "regex-2022.4.24-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:582ea06079a03750b5f71e20a87cd99e646d796638b5894ff85987ebf5e04924"},
{file = "regex-2022.4.24-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aa6daa189db9104787ff1fd7a7623ce017077aa59eaac609d0d25ba95ed251a0"},
{file = "regex-2022.4.24-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7dbc96419ef0fb6ac56626014e6d3a345aeb8b17a3df8830235a88626ffc8d84"},
{file = "regex-2022.4.24-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:0fb6cb16518ac7eff29d1e0b0cce90275dfae0f17154165491058c31d58bdd1d"},
{file = "regex-2022.4.24-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bea61de0c688198e3d9479344228c7accaa22a78b58ec408e41750ebafee6c08"},
{file = "regex-2022.4.24-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:46cbc5b23f85e94161b093dba1b49035697cf44c7db3c930adabfc0e6d861b95"},
{file = "regex-2022.4.24-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:50b77622016f03989cd06ecf6b602c7a6b4ed2e3ce04133876b041d109c934ee"},
{file = "regex-2022.4.24-cp37-cp37m-win32.whl", hash = "sha256:2bde99f2cdfd6db1ec7e02d68cadd384ffe7413831373ea7cc68c5415a0cb577"},
{file = "regex-2022.4.24-cp37-cp37m-win_amd64.whl", hash = "sha256:66fb765b2173d90389384708e3e1d3e4be1148bd8d4d50476b1469da5a2f0229"},
{file = "regex-2022.4.24-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:709396c0c95b95045fac89b94f997410ff39b81a09863fe21002f390d48cc7d3"},
{file = "regex-2022.4.24-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7a608022f4593fc67518c6c599ae5abdb03bb8acd75993c82cd7a4c8100eff81"},
{file = "regex-2022.4.24-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb7107faf0168de087f62a2f2ed00f9e9da12e0b801582b516ddac236b871cda"},
{file = "regex-2022.4.24-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aabc28f7599f781ddaeac168d0b566d0db82182cc3dcf62129f0a4fc2927b811"},
{file = "regex-2022.4.24-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92ad03f928675ca05b79d3b1d3dfc149e2226d57ed9d57808f82105d511d0212"},
{file = "regex-2022.4.24-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7ba3c304a4a5d8112dbd30df8b3e4ef59b4b07807957d3c410d9713abaee9a8"},
{file = "regex-2022.4.24-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2acf5c66fbb62b5fe4c40978ddebafa50818f00bf79d60569d9762f6356336e"},
{file = "regex-2022.4.24-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7c4d9770e579eb11b582b2e2fd19fa204a15cb1589ae73cd4dcbb63b64f3e828"},
{file = "regex-2022.4.24-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:02543d6d5c32d361b7cc468079ba4cddaaf4a6544f655901ba1ff9d8e3f18755"},
{file = "regex-2022.4.24-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:73ed1b06abadbf6b61f6033a07c06f36ec0ddca117e41ef2ac37056705e46458"},
{file = "regex-2022.4.24-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3241db067a7f69da57fba8bca543ac8a7ca415d91e77315690202749b9fdaba1"},
{file = "regex-2022.4.24-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:d128e278e5e554c5c022c7bed410ca851e00bacebbb4460de546a73bc53f8de4"},
{file = "regex-2022.4.24-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b1d53835922cd0f9b74b2742453a444865a70abae38d12eb41c59271da66f38d"},
{file = "regex-2022.4.24-cp38-cp38-win32.whl", hash = "sha256:f2a5d9f612091812dee18375a45d046526452142e7b78c4e21ab192db15453d5"},
{file = "regex-2022.4.24-cp38-cp38-win_amd64.whl", hash = "sha256:a850f5f369f1e3b6239da7fb43d1d029c1e178263df671819889c47caf7e4ff3"},
{file = "regex-2022.4.24-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bedb3d01ad35ea1745bdb1d57f3ee0f996f988c98f5bbae9d068c3bb3065d210"},
{file = "regex-2022.4.24-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8bf867ba71856414a482e4b683500f946c300c4896e472e51d3db8dfa8dc8f32"},
{file = "regex-2022.4.24-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b415b82e5be7389ec5ee7ee35431e4a549ea327caacf73b697c6b3538cb5c87f"},
{file = "regex-2022.4.24-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9dae5affbb66178dad6c6fd5b02221ca9917e016c75ee3945e9a9563eb1fbb6f"},
{file = "regex-2022.4.24-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e65580ae3137bce712f505ec7c2d700aef0014a3878c4767b74aff5895fc454f"},
{file = "regex-2022.4.24-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e9e983fc8e0d4d5ded7caa5aed39ca2cf6026d7e39801ef6f0af0b1b6cd9276"},
{file = "regex-2022.4.24-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cfad3a770839aa456ff9a9aa0e253d98b628d005a3ccb37da1ff9be7c84fee16"},
{file = "regex-2022.4.24-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ed625205f5f26984382b68e4cbcbc08e6603c9e84c14b38457170b0cc71c823b"},
{file = "regex-2022.4.24-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c4fdf837666f7793a5c3cfa2f2f39f03eb6c7e92e831bc64486c2f547580c2b3"},
{file = "regex-2022.4.24-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ed26c3d2d62c6588e0dad175b8d8cc0942a638f32d07b80f92043e5d73b7db67"},
{file = "regex-2022.4.24-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:f89d26e50a4c7453cb8c415acd09e72fbade2610606a9c500a1e48c43210a42d"},
{file = "regex-2022.4.24-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:97af238389cb029d63d5f2d931a7e8f5954ad96e812de5faaed373b68e74df86"},
{file = "regex-2022.4.24-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:be392d9cd5309509175a9d7660dc17bf57084501108dbff0c5a8bfc3646048c3"},
{file = "regex-2022.4.24-cp39-cp39-win32.whl", hash = "sha256:bcc6f7a3a95119c3568c572ca167ada75f8319890706283b9ba59b3489c9bcb3"},
{file = "regex-2022.4.24-cp39-cp39-win_amd64.whl", hash = "sha256:5b9c7b6895a01204296e9523b3e12b43e013835a9de035a783907c2c1bc447f0"},
{file = "regex-2022.4.24.tar.gz", hash = "sha256:92183e9180c392371079262879c6532ccf55f808e6900df5d9f03c9ca8807255"},
]
requests = [ requests = [
{file = "requests-2.27.1-py2.py3-none-any.whl", hash = "sha256:f22fa1e554c9ddfd16e6e41ac79759e17be9e492b3587efa038054674760e72d"}, {file = "requests-2.27.1-py2.py3-none-any.whl", hash = "sha256:f22fa1e554c9ddfd16e6e41ac79759e17be9e492b3587efa038054674760e72d"},
{file = "requests-2.27.1.tar.gz", hash = "sha256:68d7c56fd5a8999887728ef304a6d12edc7be74f1cfa47714fc8b414525c9a61"}, {file = "requests-2.27.1.tar.gz", hash = "sha256:68d7c56fd5a8999887728ef304a6d12edc7be74f1cfa47714fc8b414525c9a61"},
@ -1341,6 +1615,41 @@ sphinxcontrib-serializinghtml = [
sqlitedict = [ sqlitedict = [
{file = "sqlitedict-2.0.0.tar.gz", hash = "sha256:23a370416f4e1e962daa293382f3a8dbc4127e6a0abc06a5d4e58e6902f05d17"}, {file = "sqlitedict-2.0.0.tar.gz", hash = "sha256:23a370416f4e1e962daa293382f3a8dbc4127e6a0abc06a5d4e58e6902f05d17"},
] ]
tokenizers = [
{file = "tokenizers-0.12.1-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:d737df0f8f26e093a82bfb106b6cfb510a0e9302d35834568e5b20b73ddc5a9c"},
{file = "tokenizers-0.12.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f1271224acafb27639c432e1ce4e7d38eab40305ba1c546e871d5c8a32f4f195"},
{file = "tokenizers-0.12.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cdeba37c2fb44e1aec8a72af4cb369655b59ba313181b1b4b8183f08e759c49c"},
{file = "tokenizers-0.12.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:53b5f4012ce3ffddd5b00827441b80dc7a0f6b41f4fc5248ae6d36e7d3920c6d"},
{file = "tokenizers-0.12.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5188e13fc09edfe05712ca3ae5a44e7f2b0137927b1ca210d0fad90d3e58315a"},
{file = "tokenizers-0.12.1-cp310-cp310-win32.whl", hash = "sha256:eff5ff411f18a201eec137b7b32fcb55e0c48b372d370bd24f965f5bad471fa4"},
{file = "tokenizers-0.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:bdbca79726fe883c696088ea163715b2f902aec638a8e24bcf9790ff8fa45019"},
{file = "tokenizers-0.12.1-cp36-cp36m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:28825dade9e52ad464164020758f9d49eb7251c32b6ae146601c506a23c67c0e"},
{file = "tokenizers-0.12.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91906d725cb84d8ee71ce05fbb155d39d494849622b4f9349e5176a8eb01c49b"},
{file = "tokenizers-0.12.1-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:230f51a0a82ca7b90077eaca2415f12ff9bd144607888b9c50c2ee543452322e"},
{file = "tokenizers-0.12.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d4339c376b695de2ad8ccaebffa75e4dc1d7857be1103d80e7925b34af8cf78"},
{file = "tokenizers-0.12.1-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:27d93b712aa2d4346aa506ecd4ec9e94edeebeaf2d484357b482cdeffc02b5f5"},
{file = "tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7f4cb68dc538b52240d1986d2034eb0a6373be2ab5f0787d1be3ad1444ce71b7"},
{file = "tokenizers-0.12.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ae6c04b629ac2cd2f695739988cb70b9bd8d5e7f849f5b14c4510e942bee5770"},
{file = "tokenizers-0.12.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6a38b2019d4807d42afeff603a119094ee00f63bea2921136524c8814e9003f8"},
{file = "tokenizers-0.12.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fde8dccb9033fa344ffce3ee1837939a50e7a210a768f1cf2059beeafa755481"},
{file = "tokenizers-0.12.1-cp37-cp37m-win32.whl", hash = "sha256:38625595b2fd37bfcce64ff9bfb6868c07e9a7b7f205c909d94a615ce9472287"},
{file = "tokenizers-0.12.1-cp37-cp37m-win_amd64.whl", hash = "sha256:01abe6fbfe55e4131ca0c4c3d1a9d7ef5df424a8d536e998d2a4fc0bc57935f4"},
{file = "tokenizers-0.12.1-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:7c5c54080a7d5c89c990e0d478e0882dbac88926d43323a3aa236492a3c9455f"},
{file = "tokenizers-0.12.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:419d113e3bcc4fe20a313afc47af81e62906306b08fe1601e1443d747d46af1f"},
{file = "tokenizers-0.12.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9779944559cb7ace6a8516e402895f239b0d9d3c833c67dbaec496310e7e206"},
{file = "tokenizers-0.12.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7d43de14b4469b57490dbaf136a31c266cb676fa22320f01f230af9219ae9034"},
{file = "tokenizers-0.12.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:258873634406bd1d438c799993a5e44bbc0132ff055985c03c4fe30f702e9a33"},
{file = "tokenizers-0.12.1-cp38-cp38-win32.whl", hash = "sha256:3f2647cc256d6a53d18b9dcd71d377828e9f8991fbcbd6fcd8ca2ceb174552b0"},
{file = "tokenizers-0.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:62a723bd4b18bc55121f5c34cd8efd6c651f2d3b81f81dd50e5351fb65b8a617"},
{file = "tokenizers-0.12.1-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:411ebc89228f30218ffa9d9c49d414864b0df5026a47c24820431821c4360460"},
{file = "tokenizers-0.12.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:619728df2551bdfe6f96ff177f9ded958e7ed9e2af94c8d5ac2834d1eb06d112"},
{file = "tokenizers-0.12.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8cea98f3f9577d1541b7bb0f7a3308a911751067e1d83e01485c9d3411bbf087"},
{file = "tokenizers-0.12.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:664f36f0a0d409c24f2201d495161fec4d8bc93e091fbb78814eb426f29905a3"},
{file = "tokenizers-0.12.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0bf2380ad59c50222959a9b6f231339200a826fc5cb2be09ff96d8a59f65fc5e"},
{file = "tokenizers-0.12.1-cp39-cp39-win32.whl", hash = "sha256:6a7a106d04154c2159db6cd7d042af2e2e0e53aee432f872fe6c8be45100436a"},
{file = "tokenizers-0.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:2158baf80cbc09259bfd6e0e0fc4597b611e7a72ad5443dad63918a90f1dd304"},
{file = "tokenizers-0.12.1.tar.gz", hash = "sha256:070746f86efa6c873db341e55cf17bb5e7bdd5450330ca8eca542f5c3dab2c66"},
]
toml = [ toml = [
{file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
@ -1349,10 +1658,35 @@ tomli = [
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
] ]
torch = [
{file = "torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:62052b50fffc29ca7afc0c04ef8206b6f1ca9d10629cb543077e12967e8d0398"},
{file = "torch-1.11.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:866bfba29ac98dec35d893d8e17eaec149d0ac7a53be7baae5c98069897db667"},
{file = "torch-1.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:951640fb8db308a59d9b510e7d1ad910aff92913323bbe4bc75435347ddd346d"},
{file = "torch-1.11.0-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:5d77b5ece78fdafa5c7f42995ff9474399d22571cd6b2de21a5d666306a2ff8c"},
{file = "torch-1.11.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:b5a38682769b544c875ecc34bcb81fbad5c922139b61319aacffcfd8a32f528c"},
{file = "torch-1.11.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:f82d77695a60626f2b7382d85bc566de8a6b3e50d32080755abc040db802e419"},
{file = "torch-1.11.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:b96654d42566080a134e784705f33f8536b3b95b5dcde357ed7879b1692a5f78"},
{file = "torch-1.11.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8ee7c2e8d7f7020d5bfbc1bb91b9591044c26bbd0cee5e4f694cfd7ed8649260"},
{file = "torch-1.11.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:6860b1d1bf0bb0b67a6bd47f85a0e4c825b518eea13b5d6101999dbbcbd5bc0c"},
{file = "torch-1.11.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:4322aa29f50da7f404db06cdf30896ea67b09f673af4a985afc7162bc897864d"},
{file = "torch-1.11.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e4d2e0ddd652f30e94cff750220324ec45705d4ecc69658f773b3cb1c7a28dd0"},
{file = "torch-1.11.0-cp38-cp38-win_amd64.whl", hash = "sha256:34ce5ea4d8d85da32cdbadb50d4585106901e9f8a3527991daa70c13a09de1f7"},
{file = "torch-1.11.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:0ccc85cd06227a3edf809e2c795fd5762c3d4e8a38b5c9f744c6e7cf841361bb"},
{file = "torch-1.11.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:c1554e49d74f1b2c3e7202d77056ba2dd7465437585bac64062b580f714a44e9"},
{file = "torch-1.11.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:58c7814502b1c129a650d7092033bbb0bbd64faf1a7941631aaa1aeaddc37570"},
{file = "torch-1.11.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:831cf588f01dda9409e75576741d2823453990dee2983d670f2584b37a01adf7"},
{file = "torch-1.11.0-cp39-cp39-win_amd64.whl", hash = "sha256:44a1d02fd20f827f0f36dc26fdcfc45e793806a6ad52769a22260655a77a4369"},
{file = "torch-1.11.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:50fd9bf85c578c871c28f1cb0ace9dfc6024401c7f399b174fb0f370899f4454"},
{file = "torch-1.11.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:0e48af66ad755f0f9c5f2664028a414f57c49d6adc37e77e06fe0004da4edb61"},
]
tqdm = [ tqdm = [
{file = "tqdm-4.64.0-py2.py3-none-any.whl", hash = "sha256:74a2cdefe14d11442cedf3ba4e21a3b84ff9a2dbdc6cfae2c34addb2a14a5ea6"}, {file = "tqdm-4.64.0-py2.py3-none-any.whl", hash = "sha256:74a2cdefe14d11442cedf3ba4e21a3b84ff9a2dbdc6cfae2c34addb2a14a5ea6"},
{file = "tqdm-4.64.0.tar.gz", hash = "sha256:40be55d30e200777a307a7585aee69e4eabb46b4ec6a4b4a5f2d9f11e7d5408d"}, {file = "tqdm-4.64.0.tar.gz", hash = "sha256:40be55d30e200777a307a7585aee69e4eabb46b4ec6a4b4a5f2d9f11e7d5408d"},
] ]
transformers = [
{file = "transformers-4.19.2-py3-none-any.whl", hash = "sha256:1416315b7c5ff1f56d3915f416b67aa254a9907fbb73ef7f7bffc9210446b5fa"},
{file = "transformers-4.19.2.tar.gz", hash = "sha256:e19a4ff07458eda143c738e5259caf48449fcf078a63d6b1bd1aa806543440a3"},
]
typing-extensions = [ typing-extensions = [
{file = "typing_extensions-4.2.0-py3-none-any.whl", hash = "sha256:6657594ee297170d19f67d55c05852a874e7eb634f4f753dbd667855e07c1708"}, {file = "typing_extensions-4.2.0-py3-none-any.whl", hash = "sha256:6657594ee297170d19f67d55c05852a874e7eb634f4f753dbd667855e07c1708"},
{file = "typing_extensions-4.2.0.tar.gz", hash = "sha256:f1c24655a0da0d1b67f07e17a5e6b2a105894e6824b92096378bb3668ef02376"}, {file = "typing_extensions-4.2.0.tar.gz", hash = "sha256:f1c24655a0da0d1b67f07e17a5e6b2a105894e6824b92096378bb3668ef02376"},
@ -1365,6 +1699,10 @@ virtualenv = [
{file = "virtualenv-20.14.1-py2.py3-none-any.whl", hash = "sha256:e617f16e25b42eb4f6e74096b9c9e37713cf10bf30168fb4a739f3fa8f898a3a"}, {file = "virtualenv-20.14.1-py2.py3-none-any.whl", hash = "sha256:e617f16e25b42eb4f6e74096b9c9e37713cf10bf30168fb4a739f3fa8f898a3a"},
{file = "virtualenv-20.14.1.tar.gz", hash = "sha256:ef589a79795589aada0c1c5b319486797c03b67ac3984c48c669c0e4f50df3a5"}, {file = "virtualenv-20.14.1.tar.gz", hash = "sha256:ef589a79795589aada0c1c5b319486797c03b67ac3984c48c669c0e4f50df3a5"},
] ]
werkzeug = [
{file = "Werkzeug-2.1.2-py3-none-any.whl", hash = "sha256:72a4b735692dd3135217911cbeaa1be5fa3f62bffb8745c5215420a03dc55255"},
{file = "Werkzeug-2.1.2.tar.gz", hash = "sha256:1ce08e8093ed67d638d63879fd1ba3735817f7a80de3674d293f5984f25fb6e6"},
]
wrapt = [ wrapt = [
{file = "wrapt-1.14.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3"}, {file = "wrapt-1.14.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3"},
{file = "wrapt-1.14.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef"}, {file = "wrapt-1.14.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef"},
@ -1431,3 +1769,7 @@ wrapt = [
{file = "wrapt-1.14.1-cp39-cp39-win_amd64.whl", hash = "sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb"}, {file = "wrapt-1.14.1-cp39-cp39-win_amd64.whl", hash = "sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb"},
{file = "wrapt-1.14.1.tar.gz", hash = "sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d"}, {file = "wrapt-1.14.1.tar.gz", hash = "sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d"},
] ]
zipp = [
{file = "zipp-3.8.0-py3-none-any.whl", hash = "sha256:c4f6e5bbf48e74f7a38e7cc5b0480ff42b0ae5178957d564d18932525d5cf099"},
{file = "zipp-3.8.0.tar.gz", hash = "sha256:56bf8aadb83c24db6c4b577e13de374ccfb67da2078beba1d037c17980bf43ad"},
]

View File

@ -21,6 +21,10 @@ sqlitedict = "^2.0.0"
openai = "^0.18.1" openai = "^0.18.1"
redis = "^4.3.1" redis = "^4.3.1"
dill = "^0.3.5" dill = "^0.3.5"
Flask = "^2.1.2"
transformers = "^4.19.2"
torch = "^1.11.0"
requests = "^2.27.1"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
black = "^22.3.0" black = "^22.3.0"

View File

@ -1 +0,0 @@
"""Test client."""

View File

@ -1,59 +0,0 @@
"""Response test."""
import json
import pytest
from manifest.clients import Response
def test_init():
"""Test response initialization."""
with pytest.raises(ValueError) as exc_info:
response = Response(4)
assert str(exc_info.value) == "Response must be str or dict"
with pytest.raises(ValueError) as exc_info:
response = Response({"test": "hello"})
assert (
str(exc_info.value)
== "Response must be serialized to a dict with a list of choices"
)
with pytest.raises(ValueError) as exc_info:
response = Response({"choices": [{"blah": "hello"}]})
assert str(exc_info.value) == (
"Response must be serialized to a dict "
"with a list of choices with text field"
)
response = Response({"choices": [{"text": "hello"}]})
assert response.response == {"choices": [{"text": "hello"}]}
response = Response(json.dumps({"choices": [{"text": "hello"}]}))
assert response.response == {"choices": [{"text": "hello"}]}
def test_getitem():
"""Test response getitem."""
response = Response({"choices": [{"text": "hello"}]})
assert response["choices"] == [{"text": "hello"}]
def test_serialize():
"""Test response serialization."""
response = Response({"choices": [{"text": "hello"}]})
assert Response.deserialize(response.serialize()).response == {
"choices": [{"text": "hello"}]
}
def test_get_results():
"""Test response get results."""
response = Response({"choices": []})
assert response.get_results() is None
response = Response({"choices": [{"text": "hello"}]})
assert response.get_results() == "hello"
response = Response(
{"choices": [{"text": "hello"}, {"text": "my"}, {"text": "name"}]}
)
assert response.get_results() == ["hello", "my", "name"]

View File

@ -3,6 +3,7 @@ import os
import shutil import shutil
import pytest import pytest
import redis
@pytest.fixture @pytest.fixture
@ -32,5 +33,5 @@ def redis_cache():
port = os.environ.get("REDIS_PORT", 6379) port = os.environ.get("REDIS_PORT", 6379)
yield f"{host}:{port}" yield f"{host}:{port}"
# Clear out the database # Clear out the database
# db = redis.Redis(host=host, port=port) db = redis.Redis(host=host, port=port)
# db.flushdb() db.flushdb()

View File

@ -1,28 +1,28 @@
"""Cache test.""" """Cache test."""
import pytest import pytest
from redis import Redis
from sqlitedict import SqliteDict from sqlitedict import SqliteDict
from manifest.caches.redis import RedisCache from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache from manifest.caches.sqlite import SQLiteCache
from manifest.clients import Response
@pytest.mark.usefixtures("sqlite_cache") @pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache") @pytest.mark.usefixtures("redis_cache")
@pytest.mark.parametrize("cache_type", ["sqlite"]) @pytest.mark.parametrize("cache_type", ["sqlite", "redis"])
def test_init(sqlite_cache, redis_cache, cache_type): def test_init(sqlite_cache, redis_cache, cache_type):
"""Test cache initialization.""" """Test cache initialization."""
if cache_type == "sqlite": if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache) cache = SQLiteCache(sqlite_cache)
assert isinstance(cache.cache, SqliteDict) assert isinstance(cache.cache, SqliteDict)
assert isinstance(cache.prompt_cache, SqliteDict)
else: else:
cache = RedisCache(redis_cache) cache = RedisCache(redis_cache)
assert isinstance(cache.redis, Redis)
@pytest.mark.usefixtures("sqlite_cache") @pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache") @pytest.mark.usefixtures("redis_cache")
@pytest.mark.parametrize("cache_type", ["sqlite"]) @pytest.mark.parametrize("cache_type", ["sqlite", "redis"])
def test_key_get_and_set(sqlite_cache, redis_cache, cache_type): def test_key_get_and_set(sqlite_cache, redis_cache, cache_type):
"""Test cache key get and set.""" """Test cache key get and set."""
if cache_type == "sqlite": if cache_type == "sqlite":
@ -32,7 +32,6 @@ def test_key_get_and_set(sqlite_cache, redis_cache, cache_type):
cache.set_key("test", "valueA") cache.set_key("test", "valueA")
cache.set_key("testA", "valueB") cache.set_key("testA", "valueB")
assert cache.get_key("test") == "valueA" assert cache.get_key("test") == "valueA"
assert cache.get_key("testA") == "valueB" assert cache.get_key("testA") == "valueB"
@ -46,7 +45,7 @@ def test_key_get_and_set(sqlite_cache, redis_cache, cache_type):
@pytest.mark.usefixtures("sqlite_cache") @pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache") @pytest.mark.usefixtures("redis_cache")
@pytest.mark.parametrize("cache_type", ["sqlite"]) @pytest.mark.parametrize("cache_type", ["sqlite", "redis"])
def test_get(sqlite_cache, redis_cache, cache_type): def test_get(sqlite_cache, redis_cache, cache_type):
"""Test cache save prompt.""" """Test cache save prompt."""
if cache_type == "sqlite": if cache_type == "sqlite":
@ -54,16 +53,19 @@ def test_get(sqlite_cache, redis_cache, cache_type):
else: else:
cache = RedisCache(redis_cache) cache = RedisCache(redis_cache)
test_request = {"test": "hello", "testA": "world"} test_request = {"test": "hello", "testA": "world"}
compute = lambda: Response({"choices": [{"text": "hello"}]}) compute = lambda: {"choices": [{"text": "hello"}]}
response, cached = cache.get(test_request, overwrite_cache=False, compute=compute) response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert response.get_results() == "hello" assert response.get_response() == "hello"
assert not cached assert not response.is_cached()
assert response.get_request() == test_request
response, cached = cache.get(test_request, overwrite_cache=False, compute=compute) response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert response.get_results() == "hello" assert response.get_response() == "hello"
assert cached assert response.is_cached()
assert response.get_request() == test_request
response, cached = cache.get(test_request, overwrite_cache=True, compute=compute) response = cache.get(test_request, overwrite_cache=True, compute=compute)
assert response.get_results() == "hello" assert response.get_response() == "hello"
assert not cached assert not response.is_cached()
assert response.get_request() == test_request

21
tests/test_client.py Normal file
View File

@ -0,0 +1,21 @@
"""
Test client.
We just test the dummy client as we don't want to load a model or use OpenAI tokens.
"""
from manifest.clients.dummy import DummyClient
def test_init():
"""Test client initialization."""
client = DummyClient(connection_str=None, num_results=3)
assert client.num_results == 3
def test_get_request():
"""Test client get request."""
client = DummyClient(connection_str=None, num_results=3)
request_func, request_params = client.get_request("hello")
assert request_params == {"prompt": "hello", "num_results": 3}
assert request_func() == {"choices": [{"text": "hello"}] * 3}

View File

@ -1,7 +1,7 @@
"""Manifest test.""" """Manifest test."""
import pytest import pytest
from manifest import Manifest, Prompt from manifest import Manifest, Prompt, Response
from manifest.caches.cache import request_to_key from manifest.caches.cache import request_to_key
from manifest.caches.sqlite import SQLiteCache from manifest.caches.sqlite import SQLiteCache
from manifest.clients.dummy import DummyClient from manifest.clients.dummy import DummyClient
@ -18,22 +18,27 @@ def test_init(sqlite_cache):
assert manifest.client_name == "dummy" assert manifest.client_name == "dummy"
assert isinstance(manifest.client, DummyClient) assert isinstance(manifest.client, DummyClient)
assert isinstance(manifest.cache, SQLiteCache) assert isinstance(manifest.cache, SQLiteCache)
assert manifest.client.num_results == 1
assert manifest.stop_token == ""
manifest = Manifest( manifest = Manifest(
client_name="dummy", client_name="dummy",
cache_name="sqlite", cache_name="sqlite",
cache_connection=sqlite_cache, cache_connection=sqlite_cache,
num_results=3, num_results=3,
stop_token="\n",
) )
assert manifest.client_name == "dummy" assert manifest.client_name == "dummy"
assert isinstance(manifest.client, DummyClient) assert isinstance(manifest.client, DummyClient)
assert isinstance(manifest.cache, SQLiteCache) assert isinstance(manifest.cache, SQLiteCache)
assert manifest.client.num_results == 3 assert manifest.client.num_results == 3
assert manifest.stop_token == "\n"
@pytest.mark.usefixtures("sqlite_cache") @pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.parametrize("num_results", [1, 2]) @pytest.mark.parametrize("num_results", [1, 2])
def test_run(sqlite_cache, num_results): @pytest.mark.parametrize("return_response", [True, False])
def test_run(sqlite_cache, num_results, return_response):
"""Test manifest run.""" """Test manifest run."""
manifest = Manifest( manifest = Manifest(
client_name="dummy", client_name="dummy",
@ -42,7 +47,12 @@ def test_run(sqlite_cache, num_results):
num_results=num_results, num_results=num_results,
) )
prompt = Prompt("This is a prompt") prompt = Prompt("This is a prompt")
res = manifest.run(prompt) result = manifest.run(prompt, return_response=return_response)
if return_response:
assert isinstance(result, Response)
res = result.get_response(manifest.stop_token)
else:
res = result
assert ( assert (
manifest.cache.get_key( manifest.cache.get_key(
request_to_key( request_to_key(
@ -61,7 +71,12 @@ def test_run(sqlite_cache, num_results):
assert res == ["hello", "hello"] assert res == ["hello", "hello"]
prompt = Prompt(lambda x: f"{x} is a prompt") prompt = Prompt(lambda x: f"{x} is a prompt")
res = manifest.run(prompt, "Hello") result = manifest.run(prompt, "Hello", return_response=return_response)
if return_response:
assert isinstance(result, Response)
res = result.get_response(manifest.stop_token)
else:
res = result
assert ( assert (
manifest.cache.get_key( manifest.cache.get_key(
request_to_key( request_to_key(
@ -79,10 +94,37 @@ def test_run(sqlite_cache, num_results):
else: else:
assert res == ["hello", "hello"] assert res == ["hello", "hello"]
prompt = Prompt(lambda x: f"{x} is a prompt")
result = manifest.run(
prompt, "Hello", stop_token="ll", return_response=return_response
)
if return_response:
assert isinstance(result, Response)
res = result.get_response(stop_token="ll")
else:
res = result
assert (
manifest.cache.get_key(
request_to_key(
{
"prompt": "Hello is a prompt",
"client_name": "dummy",
"num_results": num_results,
}
)
)
is not None
)
if num_results == 1:
assert res == "he"
else:
assert res == ["he", "he"]
@pytest.mark.usefixtures("sqlite_cache") @pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.parametrize("num_results", [1, 2]) @pytest.mark.parametrize("num_results", [1, 2])
def test_batch_run(sqlite_cache, num_results): @pytest.mark.parametrize("return_response", [True, False])
def test_batch_run(sqlite_cache, num_results, return_response):
"""Test manifest run.""" """Test manifest run."""
manifest = Manifest( manifest = Manifest(
client_name="dummy", client_name="dummy",
@ -91,15 +133,38 @@ def test_batch_run(sqlite_cache, num_results):
num_results=num_results, num_results=num_results,
) )
prompt = Prompt("This is a prompt") prompt = Prompt("This is a prompt")
res = manifest.run_batch(prompt) result = manifest.run_batch(prompt, return_response=return_response)
if return_response:
res = [r.get_response(manifest.stop_token) for r in result]
else:
res = result
if num_results == 1: if num_results == 1:
assert res == ["hello"] assert res == ["hello"]
else: else:
assert res == [["hello", "hello"]] assert res == [["hello", "hello"]]
prompt = Prompt(lambda x: f"{x} is a prompt") prompt = Prompt(lambda x: f"{x} is a prompt")
res = manifest.run_batch(prompt, ["Hello", "Hello"]) result = manifest.run_batch(
prompt, ["Hello", "Hello"], return_response=return_response
)
if return_response:
res = [r.get_response(manifest.stop_token) for r in result]
else:
res = result
if num_results == 1: if num_results == 1:
assert res == ["hello", "hello"] assert res == ["hello", "hello"]
else: else:
assert res == [["hello", "hello"], ["hello", "hello"]] assert res == [["hello", "hello"], ["hello", "hello"]]
prompt = Prompt(lambda x: f"{x} is a prompt")
result = manifest.run_batch(
prompt, ["Hello", "Hello"], stop_token="ll", return_response=return_response
)
if return_response:
res = [r.get_response(stop_token="ll") for r in result]
else:
res = result
if num_results == 1:
assert res == ["he", "he"]
else:
assert res == [["he", "he"], ["he", "he"]]

74
tests/test_response.py Normal file
View File

@ -0,0 +1,74 @@
"""Response test."""
import pytest
from manifest import Response
def test_init():
"""Test response initialization."""
with pytest.raises(ValueError) as exc_info:
response = Response(4, False, {})
assert str(exc_info.value) == "Response must be str or dict"
with pytest.raises(ValueError) as exc_info:
response = Response({"test": "hello"}, False, {})
assert (
str(exc_info.value)
== "Response must be serialized to a dict with a list of choices"
)
with pytest.raises(ValueError) as exc_info:
response = Response({"choices": [{"blah": "hello"}]}, False, {})
assert str(exc_info.value) == (
"Response must be serialized to a dict "
"with a list of choices with text field"
)
response = Response({"choices": [{"text": "hello"}]}, False, {})
assert response._response == {"choices": [{"text": "hello"}]}
assert response._cached is False
assert response._request_params == {}
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
assert response._response == {"choices": [{"text": "hello"}]}
assert response._cached is True
assert response._request_params == {"request": "yoyo"}
def test_getters():
"""Test response cached."""
response = Response({"choices": [{"text": "hello"}]}, False, {})
assert response.get_raw_response() == {"choices": [{"text": "hello"}]}
assert response.is_cached() is False
assert response.get_request() == {}
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
assert response.get_raw_response() == {"choices": [{"text": "hello"}]}
assert response.is_cached() is True
assert response.get_request() == {"request": "yoyo"}
def test_serialize():
"""Test response serialization."""
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
deserialized_response = Response.deserialize(response.serialize())
assert deserialized_response._response == {"choices": [{"text": "hello"}]}
assert deserialized_response.is_cached() is True
assert deserialized_response._request_params == {"request": "yoyo"}
def test_get_results():
"""Test response get results."""
response = Response({"choices": []}, True, {"request": "yoyo"})
assert response.get_response() is None
assert response.get_response(stop_token="ll") is None
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
assert response.get_response() == "hello"
assert response.get_response(stop_token="ll") == "he"
response = Response(
{"choices": [{"text": "hello"}, {"text": "my"}, {"text": "name"}]},
True,
{"request": "yoyo"},
)
assert response.get_response() == ["hello", "my", "name"]
assert response.get_response(stop_token="m") == ["hello", "", "na"]