Mirror of https://github.com/HazyResearch/manifest (synced 2024-11-02 09:40:58 +00:00)

Merge pull request #2 from HazyResearch/laurel/clients
[feature] redis DB, flask API, tests

Commit 9dd292f2b1
.flake8 (5 changed lines)

@@ -2,9 +2,10 @@
 # - E731: do not assign a lambda expression, use a def
 # - E402: module level import not at top of file
 # - W503: line break before binary operator
+# - E203: whitespace before :

 [flake8]
 exclude = .git
 max-line-length = 88
-ignore = E731, E402, W503
+ignore = E731, E402, W503, E203
 per-file-ignores = __init__.py:F401
Makefile (13 changed lines)

@@ -3,18 +3,17 @@ dev:
 	poetry run pre-commit install

 test: dev check
-	poetry install
 	poetry run pytest tests

 format:
-	isort --atomic manifest/ tests/
-	black manifest/ tests/
+	poetry run isort --atomic manifest/ tests/
+	poetry run black manifest/ tests/

 check:
-	isort -c manifest/ tests/
-	black manifest/ tests/ --check
-	flake8 manifest/ tests/
-	mypy manifest/
+	poetry run isort -c manifest/ tests/
+	poetry run black manifest/ tests/ --check
+	poetry run flake8 manifest/ tests/
+	poetry run mypy manifest/

 clean:
 	pip uninstall -y manifest
README.md (119 changed lines)

@@ -19,10 +19,127 @@ or
 pip install poetry
 make dev
 ```
+
+# Run
+
+Manifest is meant to be a very lightweight package to help with prompt iteration. Key design decisions are:
+
+* Prompts are functional -- they can take an input example and dynamically change
+* All models are behind API calls (e.g., OpenAI)
+* Everything is cached for reuse, both to save credits and to explore past results
+
+## Prompts
+A Manifest prompt is a function that accepts a single input and generates a string prompt to send to a model.
+```
+from manifest import Prompt
+prompt = Prompt(lambda x: f"Hello, my name is {x}")
+print(prompt("Laurel"))
+>>> "Hello, my name is Laurel"
+```
+We also let you use static strings.
+```
+prompt = Prompt("Hello, my name is static")
+print(prompt())
+>>> "Hello, my name is static"
+```
+
+**Chaining prompts coming soon**
+
+## Sessions
+
+Each Manifest run is a session that connects to a model endpoint and a backend database that records prompt queries. To start a Manifest session for OpenAI, make sure you run
+```
+export OPENAI_API_KEY=<OPENAIKEY>
+```
+so we can access OpenAI.
+
+Then, in a notebook, run:
+```
+from manifest import Manifest
+
+manifest = Manifest(
+    client_name = "openai",
+    cache_name = "sqlite",
+    cache_connection = "sqlite.cache"
+)
+```
+This will start a session with OpenAI and save all results to a local file called `sqlite.cache`.
+
+We also support a Redis backend. If you have a Redis database running on port 6379, run
+```
+manifest = Manifest(
+    client_name = "openai",
+    cache_name = "redis",
+    cache_connection = "localhost:6379"
+)
+```
+As a hint, if you want to get Redis running, see the `docker run` command below under Development.
+
+We explain [below](#huggingface-models) how to use Manifest with a locally hosted HuggingFace model.
+
+Once you have a session open, you can write and develop prompts.
+```
+prompt = Prompt(lambda x: f"Hello, my name is {x}")
+result = manifest.run(prompt, "Laurel")
+```
+
+You can also run over multiple examples.
+```
+results = manifest.run_batch(prompt, ["Laurel", "Avanika"])
+```
+
+If something doesn't go right, you can also ask for the raw manifest Response.
+```
+responses = manifest.run_batch(prompt, ["Laurel", "Avanika"], return_response=True)
+print(responses[0].get_request())
+print(responses[0].is_cached())
+print(responses[0].get_response())
+```
+
+By default, we do not truncate results based on a stop token. You can change this by passing a new stop token either to the Manifest session or to `run` or `run_batch`. If you set the stop token to `""`, we will not truncate the model output.
+```
+result = manifest.run(prompt, "Laurel", stop_token="and")
+```
+
+If you want to change a model's default parameters, pass them as `kwargs` to the client.
+```
+result = manifest.run(prompt, "Laurel", max_tokens=50)
+```
+
+# HuggingFace Models
+To use a HuggingFace generative model, we provide a Flask application in `manifest/api` that hosts the models for you.
+
+In a separate terminal or Tmux/Screen session, run
+```
+python3 manifest/api/app.py --model_type huggingface --model_name EleutherAI/gpt-j-6B --device 0
+```
+You will see the Flask session start and output a URL `http://127.0.0.1:5000`. Pass this URL to Manifest. If you want to use a different port, set the `FLASK_PORT` environment variable.
+
+```
+manifest = Manifest(
+    client_name = "huggingface",
+    client_connection = "http://127.0.0.1:5000",
+    cache_name = "redis",
+    cache_connection = "localhost:6379"
+)
+```
+
+**Auto deployment coming soon**
+
 # Development
 Before submitting a PR, run
 ```
-export REDIS_PORT="6379" # or whatever port your local Redis is running on for these tests
+export REDIS_PORT="6380" # or whatever port your local Redis is running on for these tests
+cd <REDIS_PATH>
+docker run -d -p 127.0.0.1:${REDIS_PORT}:6379 -v `pwd`:`pwd` -w `pwd` --name manifest_redis_test redis
 make test
 ```
+
+To use our development Redis database, email [Laurel](mailto:lorr1@cs.stanford.edu). If you have access to our GCP account, run, in a separate terminal:
+```
+gcloud compute ssh "manifest-connect" --zone "europe-west4-a" --project "hai-gcp-head-models" -- -N -L 6379:10.152.93.107:6379
+```
+
+Then, if you issue
+```
+redis-cli ping
+```
+you should see a `PONG` response from our database.
manifest/__init__.py

@@ -1,3 +1,4 @@
 """Manifest init."""
 from manifest.manifest import Manifest
 from manifest.prompt import Prompt
+from manifest.response import Response
manifest/api/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
+"""Api init."""
manifest/api/app.py

@@ -1 +1,93 @@
 """Flask app."""
+import argparse
+import logging
+import os
+from typing import Dict
+
+import pkg_resources
+from flask import Flask, request
+
+from manifest.api.models.huggingface import HuggingFaceModel
+from manifest.api.response import OpenAIResponse
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+logger = logging.getLogger(__name__)
+app = Flask(__name__)  # define app using Flask
+# Will be global
+model = None
+PORT = int(os.environ.get("FLASK_PORT", 5000))
+MODEL_CONSTRUCTORS = {
+    "huggingface": HuggingFaceModel,
+}
+
+
+def parse_args() -> argparse.Namespace:
+    """Generate args."""
+    parser = argparse.ArgumentParser(description="Model args")
+    parser.add_argument(
+        "--model_type",
+        default=None,
+        type=str,
+        required=True,
+        help="Model type used for finding constructor.",
+        choices=["huggingface"],
+    )
+    parser.add_argument(
+        "--model_name",
+        default=None,
+        type=str,
+        required=True,
+        help="Name of model. Used in initialize of model class.",
+    )
+    parser.add_argument(
+        "--cache_dir", default=None, type=str, help="Cache directory for models."
+    )
+    parser.add_argument(
+        "--device", type=int, default=-1, help="Model device. -1 for CPU."
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main() -> None:
+    """Run main."""
+    kwargs = parse_args()
+    model_type = kwargs.model_type
+    model_name = kwargs.model_name
+
+    # Global model
+    global model
+    model = MODEL_CONSTRUCTORS[model_type](
+        model_name, cache_dir=kwargs.cache_dir, device=kwargs.device
+    )
+    app.run(host="0.0.0.0", port=PORT)
+
+
+@app.route("/completions", methods=["POST"])
+def completions() -> Dict:
+    """Get completions for generation."""
+    prompt = request.json["prompt"]
+    del request.json["prompt"]
+    generation_args = request.json
+
+    if not isinstance(prompt, str):
+        raise ValueError("Prompt must be a str")
+
+    results = []
+    for generations in model.generate(prompt, **generation_args):
+        results.append(generations)
+    # Transform the result into the OpenAI format
+    return OpenAIResponse(results).__dict__()
+
+
+@app.route("/")
+def index() -> str:
+    """Get index completion."""
+    fn = pkg_resources.resource_filename("metaseq", "service/index.html")
+    with open(fn) as f:
+        return f.read()
+
+
+if __name__ == "__main__":
+    main()
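For orientation, a minimal sketch of calling the `/completions` route above directly (this assumes the server is running locally on the default `FLASK_PORT` of 5000; everything in the JSON body except `prompt` is forwarded to `model.generate` as generation arguments):
```
import requests

# Hypothetical local query against the Flask app above. "max_tokens" and "n"
# are forwarded to HuggingFaceModel.generate, which expects both to be set.
payload = {
    "prompt": "Hello, my name is",
    "max_tokens": 10,
    "n": 1,
}
res = requests.post("http://127.0.0.1:5000/completions", json=payload)
# The body is the OpenAI-style dict built by OpenAIResponse.__dict__().
print(res.json()["choices"][0]["text"])
```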
manifest/api/models/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
+"""Models init."""
manifest/api/models/huggingface.py (new file, 79 lines)

@@ -0,0 +1,79 @@
+"""Huggingface model."""
+from typing import Any, List
+
+from transformers import (
+    AutoTokenizer,
+    GPT2LMHeadModel,
+    GPTJForCausalLM,
+    GPTNeoForCausalLM,
+    pipeline,
+)
+
+from manifest.api.models.model import Model
+
+MODEL_REGISTRY = {
+    "EleutherAI/gpt-j-6B": GPTJForCausalLM,
+    "EleutherAI/gpt-neo-125M": GPTNeoForCausalLM,
+    "EleutherAI/gpt-neo-1.3B": GPTNeoForCausalLM,
+    "EleutherAI/gpt-neo-2.7B": GPTNeoForCausalLM,
+    "gpt2": GPT2LMHeadModel,
+}
+
+
+class HuggingFaceModel(Model):
+    """Huggingface model."""
+
+    def __init__(self, model_name: str, cache_dir: str, device: int):
+        """
+        Initialize model.
+
+        All arguments will be passed in the request from Manifest.
+
+        Args:
+            model_name: model name string.
+        """
+        model = MODEL_REGISTRY[model_name].from_pretrained(
+            model_name, cache_dir=cache_dir
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.pipeline = pipeline(
+            "text-generation", model=model, tokenizer=tokenizer, device=device
+        )
+
+    def generate(self, prompt: str, **kwargs: Any) -> List[str]:
+        """
+        Generate the prompt from model.
+
+        Outputs must be generated text, not including prompt.
+
+        Args:
+            prompt: prompt to generate from.
+
+        Returns:
+            list of generated text (list of length 1 for 1 generation).
+        """
+        num_return = kwargs.get("n")
+        final_results = []
+        encoded_prompt = self.pipeline.tokenizer.encode(
+            prompt, add_special_tokens=False
+        )
+        result = self.pipeline(
+            prompt,
+            max_length=kwargs.get("max_tokens") + len(encoded_prompt),
+            temperature=kwargs.get("temperature"),
+            repetition_penalty=kwargs.get("repetition_penalty"),
+            top_k=kwargs.get("top_k"),
+            top_p=kwargs.get("top_p"),
+            num_return_sequences=num_return,
+        )
+        # Re-decode the prompt so the exact prompt prefix can be stripped below
+        decoded_prompt = self.pipeline.tokenizer.decode(
+            encoded_prompt, clean_up_tokenization_spaces=True
+        )
+        if num_return == 1:
+            final_results.append(result[0]["generated_text"][len(decoded_prompt) :])
+        else:
+            final_results.append(
+                [r["generated_text"][len(decoded_prompt) :] for r in result]
+            )
+        return final_results
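A hedged usage sketch of the class above, using the small `gpt2` entry from `MODEL_REGISTRY`. Note that `generate` reads `max_tokens` and `n` from kwargs, so both must be supplied for the length arithmetic and branching to work:
```
# Small model on CPU; cache_dir=None falls through to the transformers default.
model = HuggingFaceModel("gpt2", cache_dir=None, device=-1)
# n=1 takes the single-string branch; n>1 would append a list of strings instead.
results = model.generate("Hello, my name is", max_tokens=10, n=1)
print(results[0])  # generated continuation with the prompt prefix stripped
```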
(deleted file)

@@ -1 +0,0 @@
-"""Huggingface model."""
manifest/api/models/model.py

@@ -1 +1,34 @@
 """Model class."""
+from abc import ABC, abstractmethod
+from typing import Any, List
+
+
+class Model(ABC):
+    """Model class."""
+
+    @abstractmethod
+    def __init__(self, model_name: str, **kwargs: Any):
+        """
+        Initialize model.
+
+        kwargs are passed to model as default parameters.
+
+        Args:
+            model_name: model name string.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def generate(self, prompt: str, **kwargs: Any) -> List[str]:
+        """
+        Generate the prompt from model.
+
+        Outputs must be generated text, not including prompt.
+
+        Args:
+            prompt: prompt to generate from.
+
+        Returns:
+            list of generated text (list of length 1 for 1 generation).
+        """
+        raise NotImplementedError()
manifest/api/response.py (new file, 38 lines)

@@ -0,0 +1,38 @@
+"""OpenAI response."""
+
+import time
+import uuid
+from typing import Any, Dict
+
+
+class OpenAIResponse:
+    """OpenAI response."""
+
+    def __init__(self, results: list) -> None:
+        """Initialize response."""
+        self.results = results
+        self.response_id = str(uuid.uuid4())
+        self.created = int(time.time())
+
+    def __dict__(self) -> Dict[str, Any]:  # type: ignore
+        """Return dictionary representation of response."""
+        return {
+            "id": self.response_id,
+            "object": "text_completion",
+            "created": self.created,
+            "model": "flask_model",
+            "choices": [
+                {
+                    "text": result,
+                    # TODO: Add in more metadata for HF models
+                    # "logprobs": {
+                    #     "tokens": result["tokens"],
+                    #     "token_logprobs": result["token_scores"],
+                    #     "text_offset": result["text_offset"],
+                    #     "top_logprobs": result["top_logprobs"],
+                    #     "finish_reason": "length",
+                    # },
+                }
+                for result in self.results
+            ],
+        }
manifest/caches/cache.py

@@ -1,9 +1,9 @@
 """Cache for queries and responses."""
 import json
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Tuple, Union
+from typing import Any, Callable, Dict, Union

-from manifest.clients.response import Response
+from manifest.response import Response


 def request_to_key(request: Dict) -> str:
@@ -32,6 +32,32 @@ def key_to_request(key: str) -> Dict:
     return json.loads(key)


+def response_to_key(response: Dict) -> str:
+    """
+    Normalize a response into a key.
+
+    Args:
+        response: response to normalize.
+
+    Returns:
+        normalized key.
+    """
+    return json.dumps(response, sort_keys=True)
+
+
+def key_to_response(key: str) -> Dict:
+    """
+    Convert the normalized version to the response.
+
+    Args:
+        key: normalized key to convert.
+
+    Returns:
+        unnormalized response dict.
+    """
+    return json.loads(key)
+
+
 class Cache(ABC):
     """A cache for request/response pairs."""

@@ -97,17 +123,17 @@ class Cache(ABC):
         raise NotImplementedError()

     def get(
-        self, request: Dict, overwrite_cache: bool, compute: Callable[[], Response]
-    ) -> Tuple[Response, bool]:
+        self, request: Dict, overwrite_cache: bool, compute: Callable[[], Dict]
+    ) -> Response:
         """Get the result of request (by calling compute as needed)."""
         key = request_to_key(request)
         cached_response = self.get_key(key)
         if cached_response and not overwrite_cache:
             cached = True
-            response = Response.deserialize(cached_response)
+            response = key_to_response(cached_response)
         else:
             # Type Response
             response = compute()
-            self.set_key(key, response.serialize())
+            self.set_key(key, response_to_key(response))
             cached = False
-        return response, cached
+        return Response(response, cached, request)
manifest/caches/redis.py

@@ -17,13 +17,17 @@ class RedisCache(Cache):
             connection_str: connection string.
         """
         host, port = connection_str.split(":")
-        self.redis = redis.Redis(host=host, port=int(port))
+        self.redis = redis.Redis(host=host, port=int(port), db=0)
         return

     def close(self) -> None:
         """Close the client."""
         self.redis.close()

+    def _normalize_table_key(self, key: str, table: str) -> str:
+        """Prefix the key with its table."""
+        return f"{table}:{key}"
+
     def get_key(self, key: str, table: str = "default") -> Union[str, None]:
         """
         Get the key for a request.
@@ -32,8 +36,13 @@ class RedisCache(Cache):

         Args:
             key: key for cache.
+            table: table to get key in.
         """
-        pass
+        norm_key = self._normalize_table_key(key, table)
+        if self.redis.exists(norm_key):
+            return self.redis.get(norm_key).decode("utf-8")
+        else:
+            return None

     def set_key(self, key: str, value: str, table: str = "default") -> None:
         """
@@ -44,8 +53,10 @@ class RedisCache(Cache):
         Args:
             key: key for cache.
             value: new value for key.
+            table: table to set key in.
         """
-        self.redis[key] = value
+        self.redis.set(self._normalize_table_key(key, table), value)
+        self.commit()

     def commit(self) -> None:
         """Commit any results."""
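A sketch of the resulting key layout, assuming a local Redis on 6379 and that the `Cache` base constructor forwards the connection string to `connect` (the base `__init__` is not shown in this diff):
```
cache = RedisCache("localhost:6379")  # assumes Cache.__init__ calls connect()
cache.set_key("some-request-key", "some-response-value", table="prompt")
# Stored under the table-prefixed Redis key "prompt:some-request-key".
print(cache.get_key("some-request-key", table="prompt"))  # "some-response-value"
```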
manifest/caches/sqlite.py

@@ -1,6 +1,5 @@
 """SQLite cache."""
 import logging
-from pathlib import Path
 from typing import Any, Union

 from sqlitedict import SqliteDict
@@ -20,19 +19,18 @@ class SQLiteCache(Cache):
         Args:
             connection_str: connection string.
         """
-        self.cache_dir = connection_str
-        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
-        # If more than two tables, switch to full on SQL connection
-        self.query_file = Path(self.cache_dir, "query.sqlite")
-        self.prompt_file = Path(self.cache_dir, "prompts.sqlite")
-        self.cache = SqliteDict(self.query_file, autocommit=False)
-        self.prompt_cache = SqliteDict(self.prompt_file, autocommit=False)
+        self.cache_file = connection_str
+        self.cache = SqliteDict(self.cache_file, autocommit=True)
         return

     def close(self) -> None:
         """Close the client."""
         self.cache.close()

+    def _normalize_table_key(self, key: str, table: str) -> str:
+        """Prefix the key with its table."""
+        return f"{table}:{key}"
+
     def get_key(self, key: str, table: str = "default") -> Union[str, None]:
         """
         Get the key for a request.
@@ -43,14 +41,7 @@ class SQLiteCache(Cache):
             key: key for cache.
             table: table to get key in.
         """
-        if table == "prompt":
-            return self.prompt_cache.get(key)
-        else:
-            if table != "default":
-                raise ValueError(
-                    "SQLiteDict only support table of `default` or `prompt`"
-                )
-            return self.cache.get(key)
+        return self.cache.get(self._normalize_table_key(key, table))

     def set_key(self, key: str, value: str, table: str = "default") -> None:
         """
@@ -63,17 +54,9 @@ class SQLiteCache(Cache):
             value: new value for key.
             table: table to set key in.
         """
-        if table == "prompt":
-            self.prompt_cache[key] = value
-        else:
-            if table != "default":
-                raise ValueError(
-                    "SQLiteDict only support table of `default` or `prompt`"
-                )
-            self.cache[key] = value
+        self.cache[self._normalize_table_key(key, table)] = value
         self.commit()

     def commit(self) -> None:
         """Commit any results."""
-        self.prompt_cache.commit()
         self.cache.commit()
manifest/clients/__init__.py

@@ -1,3 +1,2 @@
 """Client init."""
 from manifest.clients.client import Client
-from manifest.clients.response import Response
manifest/clients/client.py

@@ -2,8 +2,6 @@
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, Optional, Tuple

-from manifest.clients.response import Response
-

 class Client(ABC):
     """Client class."""
@@ -38,9 +36,7 @@ class Client(ABC):
         raise NotImplementedError()

     @abstractmethod
-    def get_request(
-        self, query: str, **kwargs: Any
-    ) -> Tuple[Callable[[], Response], Dict]:
+    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
         """
         Get request function.

manifest/clients/dummy.py

@@ -3,7 +3,6 @@ import logging
 from typing import Any, Callable, Dict, Optional, Tuple

 from manifest.clients import Client
-from manifest.clients.response import Response

 logger = logging.getLogger(__name__)

@@ -28,9 +27,7 @@ class DummyClient(Client):
         """Close the client."""
         pass

-    def get_request(
-        self, query: str, **kwargs: Any
-    ) -> Tuple[Callable[[], Response], Dict]:
+    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
         """
         Get request string function.

@@ -46,7 +43,7 @@ class DummyClient(Client):
             "num_results": kwargs.get("num_results", self.num_results),
         }

-        def _run_completion() -> Response:
-            return Response({"choices": [{"text": "hello"}] * self.num_results})
+        def _run_completion() -> Dict:
+            return {"choices": [{"text": "hello"}] * self.num_results}

         return _run_completion, request_params
manifest/clients/huggingface.py (new file, 67 lines)

@@ -0,0 +1,67 @@
+"""HuggingFace client."""
+import logging
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import requests
+
+from manifest.clients.client import Client
+
+logger = logging.getLogger(__name__)
+
+
+class HuggingFaceClient(Client):
+    """HuggingFace client."""
+
+    def connect(
+        self,
+        connection_str: Optional[str] = None,
+        temperature: Optional[float] = 1.0,
+        max_tokens: Optional[int] = 10,
+        top_p: Optional[float] = 1.0,
+        top_k: Optional[int] = 0,
+        repetition_penalty: Optional[float] = 1.0,
+        n: Optional[int] = 1,
+        **kwargs: Any,
+    ) -> None:
+        """Connect to the HuggingFace url."""
+        self.host = connection_str.rstrip("/")
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.top_p = top_p
+        self.top_k = top_k
+        self.repetition_penalty = repetition_penalty
+        self.n = n
+
+    def close(self) -> None:
+        """Close the client."""
+        pass
+
+    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
+        """
+        Get request string function.
+
+        Args:
+            query: query string.
+
+        Returns:
+            request function that takes no input.
+            request parameters as dict.
+        """
+        request_params = {
+            "prompt": query,
+            "temperature": kwargs.get("temperature", self.temperature),
+            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
+            "top_p": kwargs.get("top_p", self.top_p),
+            "top_k": kwargs.get("top_k", self.top_k),
+            "repetition_penalty": kwargs.get(
+                "repetition_penalty", self.repetition_penalty
+            ),
+            "n": kwargs.get("n", self.n),
+        }
+
+        def _run_completion() -> Dict:
+            post_str = self.host + "/completions"
+            res = requests.post(post_str, json=request_params)
+            return res.json()
+
+        return _run_completion, request_params
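As a sketch, this client can also be exercised by hand, assuming the base `Client` constructor simply forwards the connection string to `connect` (not shown in this diff) and the Flask app from `manifest/api/app.py` is serving on the URL below:
```
client = HuggingFaceClient("http://127.0.0.1:5000")  # assumes Client.__init__ calls connect()
run_fn, params = client.get_request("Hello, my name is", max_tokens=5)
print(params["max_tokens"])  # 5: the kwarg overrides the connect-time default
result = run_fn()  # POSTs to <host>/completions and returns the raw JSON dict
print(result["choices"][0]["text"])
```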
manifest/clients/openai.py

@@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, Optional, Tuple

 import openai

-from manifest.clients import Response
 from manifest.clients.client import Client

 logging.getLogger("openai").setLevel(logging.WARNING)
@@ -28,7 +27,7 @@ class OpenAIClient(Client):
         engine: Optional[str] = "text-ada-001",
         temperature: Optional[float] = 0.0,
         max_tokens: Optional[int] = 10,
-        top_p: Optional[int] = 1,
+        top_p: Optional[float] = 1.0,
         frequency_penalty: Optional[int] = 0,
         presence_penalty: Optional[int] = 0,
         n: Optional[int] = 1,
@@ -59,9 +58,7 @@ class OpenAIClient(Client):
         """Close the client."""
         pass

-    def get_request(
-        self, query: str, **kwargs: Any
-    ) -> Tuple[Callable[[], Response], Dict]:
+    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
         """
         Get request string function.

@@ -85,9 +82,9 @@ class OpenAIClient(Client):
             "n": kwargs.get("n", self.n),
         }

-        def _run_completion() -> Response:
+        def _run_completion() -> Dict:
             try:
-                return Response(openai.Completion.create(**request_params))
+                return openai.Completion.create(**request_params)
             except openai.error.OpenAIError as e:
                 logger.error(e)
                 raise e
manifest/clients/response.py (deleted, 70 lines)

@@ -1,70 +0,0 @@
-"""Client response."""
-import json
-from typing import Dict, List, Union
-
-
-class Response:
-    """Response class."""
-
-    def __init__(self, response: Union[str, Dict]):
-        """Initialize response."""
-        if isinstance(response, str):
-            self.response = json.loads(response)
-        elif isinstance(response, dict):
-            self.response = response
-        else:
-            raise ValueError("Response must be str or dict")
-        if ("choices" not in self.response) or (
-            not isinstance(self.response["choices"], list)
-        ):
-            raise ValueError(
-                "Response must be serialized to a dict with a list of choices"
-            )
-        if len(self.response["choices"]) > 0:
-            if "text" not in self.response["choices"][0]:
-                raise ValueError(
-                    "Response must be serialized to a dict with a "
-                    "list of choices with text field"
-                )
-
-    def __getitem__(self, key: str) -> str:
-        """
-        Return the response given the key.
-
-        Args:
-            key: key to get.
-
-        Returns:
-            value of key.
-        """
-        return self.response[key]
-
-    def get_results(self) -> Union[str, List[str]]:
-        """Get all text results from response."""
-        if len(self.response["choices"]) == 0:
-            return None
-        if len(self.response["choices"]) == 1:
-            return self.response["choices"][0]["text"]
-        return [choice["text"] for choice in self.response["choices"]]
-
-    def serialize(self) -> str:
-        """
-        Serialize response to string.
-
-        Returns:
-            serialized response.
-        """
-        return json.dumps(self.response, sort_keys=True)
-
-    @classmethod
-    def deserialize(cls, value: str) -> "Response":
-        """
-        Deserialize string to response.
-
-        Args:
-            value: serialized response.
-
-        Returns:
-            deserialized response.
-        """
-        return Response(value)
manifest/manifest.py

@@ -2,18 +2,21 @@
 import logging
 from typing import Any, Iterable, List, Optional, Union

+from manifest.response import Response
+
 logging.getLogger("openai").setLevel(logging.WARNING)
 logger = logging.getLogger(__name__)

 from manifest.caches.redis import RedisCache
 from manifest.caches.sqlite import SQLiteCache
 from manifest.clients.dummy import DummyClient
+from manifest.clients.huggingface import HuggingFaceClient
 from manifest.clients.openai import OpenAIClient
 from manifest.prompt import Prompt

 CLIENT_CONSTRUCTORS = {
     "openai": OpenAIClient,
-    # "huggingface": manifest.clients.huggingface.HuggingFaceClient,
+    "huggingface": HuggingFaceClient,
     "dummy": DummyClient,
 }

@@ -32,11 +35,20 @@ class Manifest:
         client_connection: Optional[str] = None,
         cache_name: str = "redis",
         cache_connection: str = "localhost:6379",
+        stop_token: str = "",
         **kwargs: Any,
     ):
         """
         Initialize manifest.

+        Args:
+            client_name: name of client.
+            client_connection: connection string for client.
+            cache_name: name of cache.
+            cache_connection: connection string for cache.
+            stop_token: stop token for prompt generation.
+                Can be overridden in run.
+
         Remaining kwargs sent to client and cache.
         """
         if client_name not in CLIENT_CONSTRUCTORS:
@@ -50,8 +62,11 @@ class Manifest:
                 f"Choices are {list(CACHE_CONSTRUCTORS.keys())}"
             )
         self.client_name = client_name
-        self.client = CLIENT_CONSTRUCTORS[client_name](client_connection, **kwargs)
+        self.client = CLIENT_CONSTRUCTORS[client_name](  # type: ignore
+            client_connection, **kwargs
+        )
         self.cache = CACHE_CONSTRUCTORS[cache_name](cache_connection, **kwargs)
+        self.stop_token = stop_token

     def close(self) -> None:
         """Close the client and cache."""
@@ -63,8 +78,10 @@ class Manifest:
         prompt: Prompt,
         input: Optional[Any] = None,
         overwrite_cache: bool = False,
+        stop_token: Optional[str] = None,
+        return_response: bool = False,
         **kwargs: Any,
-    ) -> Union[str, List[str]]:
+    ) -> Union[str, List[str], Response]:
         """
         Run the prompt.

@@ -72,10 +89,14 @@ class Manifest:
             prompt: prompt to run.
             input: input to prompt.
             overwrite_cache: whether to overwrite cache.
+            stop_token: stop token for prompt generation.
+                Default is self.stop_token.
+                "" for no stop token.

         Returns:
             response from prompt.
         """
+        stop_token = stop_token if stop_token is not None else self.stop_token
         prompt_str = prompt(input)
         possible_request, full_kwargs = self.client.get_request(prompt_str, **kwargs)
         # Create cache key
@@ -84,16 +105,22 @@ class Manifest:
         cache_key["client_name"] = self.client_name
         # Make query prompt dependent
         cache_key["prompt"] = prompt_str
-        response, _ = self.cache.get(cache_key, overwrite_cache, possible_request)
-        return response.get_results()
+        response_obj = self.cache.get(cache_key, overwrite_cache, possible_request)
+        # Extract text results
+        if return_response:
+            return response_obj
+        else:
+            return response_obj.get_response(stop_token)

     def run_batch(
         self,
         prompt: Prompt,
         input: Optional[Iterable[Any]] = None,
         overwrite_cache: bool = False,
+        stop_token: Optional[str] = None,
+        return_response: bool = False,
         **kwargs: Any,
-    ) -> Iterable[Union[str, List[str]]]:
+    ) -> Iterable[Union[str, List[str], Response]]:
         """
         Run the prompt on a batch of inputs.

@@ -101,13 +128,21 @@ class Manifest:
             prompt: prompt to run.
             input: batch of inputs.
             overwrite_cache: whether to overwrite cache.
+            stop_token: stop token for prompt generation.
+                Default is self.stop_token.
+                "" for no stop token.

         Returns:
             batch of responses.
         """
         if input is None:
             input = [None]
-        return [self.run(prompt, inp, overwrite_cache, **kwargs) for inp in input]
+        return [
+            self.run(
+                prompt, inp, overwrite_cache, stop_token, return_response, **kwargs
+            )
+            for inp in input
+        ]

     def save_prompt(self, name: str, prompt: Prompt) -> None:
         """
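Putting the pieces of this diff together, a hedged end-to-end sketch using the dummy client (no API key or model server needed) and the SQLite cache. This assumes the dummy client defaults to a single result and that a static `Prompt` ignores the unused `None` input:
```
from manifest import Manifest, Prompt

# The dummy client always returns "hello", which makes the caching and
# stop-token plumbing easy to observe without external services.
manifest = Manifest(
    client_name="dummy",
    cache_name="sqlite",
    cache_connection="sqlite.cache",
)
prompt = Prompt("Say hello")
print(manifest.run(prompt))                        # "hello"
resp = manifest.run(prompt, return_response=True)  # full Response object
print(resp.is_cached())                            # True on this second call
```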
manifest/response.py (new file, 106 lines)

@@ -0,0 +1,106 @@
+"""Client response."""
+import json
+from typing import Dict, List, Union
+
+
+class Response:
+    """Response class."""
+
+    def __init__(self, response: Dict, cached: bool, request_params: Dict):
+        """Initialize response."""
+        if isinstance(response, dict):
+            self._response = response
+        else:
+            raise ValueError("Response must be a dict")
+        if ("choices" not in self._response) or (
+            not isinstance(self._response["choices"], list)
+        ):
+            raise ValueError(
+                "Response must be serialized to a dict with a list of choices"
+            )
+        if len(self._response["choices"]) > 0:
+            if "text" not in self._response["choices"][0]:
+                raise ValueError(
+                    "Response must be serialized to a dict with a "
+                    "list of choices with text field"
+                )
+        self._cached = cached
+        self._request_params = request_params
+
+    def is_cached(self) -> bool:
+        """Check if response is cached."""
+        return self._cached
+
+    def get_request(self) -> Dict:
+        """Get request parameters."""
+        return self._request_params
+
+    def get_raw_response(self) -> Dict:
+        """Get response dict without parsing."""
+        return self._response
+
+    def get_response(self, stop_token: str = "") -> Union[str, List[str]]:
+        """
+        Get all text results from response.
+
+        Args:
+            stop_token: stop token for string generation.
+        """
+        process_result = (
+            lambda x: x.strip().split(stop_token)[0] if stop_token else x.strip()
+        )
+        if len(self._response["choices"]) == 0:
+            return None
+        if len(self._response["choices"]) == 1:
+            return process_result(self._response["choices"][0]["text"])
+        return [process_result(choice["text"]) for choice in self._response["choices"]]
+
+    def serialize(self) -> str:
+        """
+        Serialize response to string.
+
+        Returns:
+            serialized response.
+        """
+        to_serialize = {
+            "response": self._response,
+            "cached": self._cached,
+            "request_params": self._request_params,
+        }
+        return json.dumps(to_serialize, sort_keys=True)
+
+    @classmethod
+    def deserialize(cls, value: str) -> "Response":
+        """
+        Deserialize string to response.
+
+        Args:
+            value: serialized response.
+
+        Returns:
+            deserialized response.
+        """
+        deserialized = json.loads(value)
+        return cls(
+            deserialized["response"],
+            deserialized["cached"],
+            deserialized["request_params"],
+        )
+
+    def __str__(self) -> str:
+        """
+        Get string representation of response.
+
+        Returns:
+            string representation of response.
+        """
+        return self.serialize()
+
+    def __repr__(self) -> str:
+        """
+        Get string representation of response.
+
+        Returns:
+            string representation of response.
+        """
+        return str(self)
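A short sketch of the stop-token handling and the serialize/deserialize round trip defined above (the response dict is a made-up example in the OpenAI format the class expects):
```
resp = Response(
    {"choices": [{"text": "Laurel and Avanika went hiking. "}]},
    cached=False,
    request_params={"prompt": "Hello"},
)
print(resp.get_response())                  # "Laurel and Avanika went hiking."
print(resp.get_response(stop_token="and"))  # "Laurel " (text before the first "and")
restored = Response.deserialize(resp.serialize())
assert restored.get_request() == {"prompt": "Hello"}
```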
356
poetry.lock
generated
356
poetry.lock
generated
@ -100,7 +100,7 @@ unicode_backport = ["unicodedata2"]
|
|||||||
name = "click"
|
name = "click"
|
||||||
version = "8.1.3"
|
version = "8.1.3"
|
||||||
description = "Composable command line interface toolkit"
|
description = "Composable command line interface toolkit"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
@ -193,7 +193,7 @@ python-versions = ">=3.6"
|
|||||||
name = "filelock"
|
name = "filelock"
|
||||||
version = "3.7.0"
|
version = "3.7.0"
|
||||||
description = "A platform independent file lock."
|
description = "A platform independent file lock."
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
@ -237,6 +237,50 @@ python-versions = "*"
|
|||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
flake8 = "*"
|
flake8 = "*"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "flask"
|
||||||
|
version = "2.1.2"
|
||||||
|
description = "A simple framework for building complex web applications."
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
click = ">=8.0"
|
||||||
|
importlib-metadata = {version = ">=3.6.0", markers = "python_version < \"3.10\""}
|
||||||
|
itsdangerous = ">=2.0"
|
||||||
|
Jinja2 = ">=3.0"
|
||||||
|
Werkzeug = ">=2.0"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
async = ["asgiref (>=3.2)"]
|
||||||
|
dotenv = ["python-dotenv"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "huggingface-hub"
|
||||||
|
version = "0.7.0"
|
||||||
|
description = "Client library to download and publish models on the huggingface.co hub"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7.0"
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
filelock = "*"
|
||||||
|
packaging = ">=20.9"
|
||||||
|
pyyaml = ">=5.1"
|
||||||
|
requests = "*"
|
||||||
|
tqdm = "*"
|
||||||
|
typing-extensions = ">=3.7.4.3"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
all = ["pytest", "datasets", "soundfile", "black (>=22.0,<23.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"]
|
||||||
|
dev = ["pytest", "datasets", "soundfile", "black (>=22.0,<23.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"]
|
||||||
|
fastai = ["toml", "fastai (>=2.4)", "fastcore (>=1.3.27)"]
|
||||||
|
quality = ["black (>=22.0,<23.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"]
|
||||||
|
tensorflow = ["tensorflow", "pydot", "graphviz"]
|
||||||
|
testing = ["pytest", "datasets", "soundfile"]
|
||||||
|
torch = ["torch"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "identify"
|
name = "identify"
|
||||||
version = "2.5.1"
|
version = "2.5.1"
|
||||||
@ -264,6 +308,22 @@ category = "dev"
|
|||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "importlib-metadata"
|
||||||
|
version = "4.11.4"
|
||||||
|
description = "Read metadata from Python packages"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
zipp = ">=0.5"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)"]
|
||||||
|
perf = ["ipython"]
|
||||||
|
testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "importlib-resources (>=1.3)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "iniconfig"
|
name = "iniconfig"
|
||||||
version = "1.1.1"
|
version = "1.1.1"
|
||||||
@ -286,11 +346,19 @@ requirements_deprecated_finder = ["pipreqs", "pip-api"]
|
|||||||
colors = ["colorama (>=0.4.3,<0.5.0)"]
|
colors = ["colorama (>=0.4.3,<0.5.0)"]
|
||||||
plugins = ["setuptools"]
|
plugins = ["setuptools"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "itsdangerous"
|
||||||
|
version = "2.1.2"
|
||||||
|
description = "Safely pass data to untrusted environments and back."
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "jinja2"
|
name = "jinja2"
|
||||||
version = "3.1.2"
|
version = "3.1.2"
|
||||||
description = "A very fast and expressive template engine."
|
description = "A very fast and expressive template engine."
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
@ -304,7 +372,7 @@ i18n = ["Babel (>=2.7)"]
|
|||||||
name = "markupsafe"
|
name = "markupsafe"
|
||||||
version = "2.1.1"
|
version = "2.1.1"
|
||||||
description = "Safely add untrusted strings to HTML/XML markup."
|
description = "Safely add untrusted strings to HTML/XML markup."
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
@ -614,7 +682,7 @@ python-versions = "*"
|
|||||||
name = "pyyaml"
|
name = "pyyaml"
|
||||||
version = "6.0"
|
version = "6.0"
|
||||||
description = "YAML parser and emitter for Python"
|
description = "YAML parser and emitter for Python"
|
||||||
category = "dev"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.6"
|
python-versions = ">=3.6"
|
||||||
|
|
||||||
@ -648,6 +716,14 @@ packaging = ">=20.4"
|
|||||||
hiredis = ["hiredis (>=1.0.0)"]
|
hiredis = ["hiredis (>=1.0.0)"]
|
||||||
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
|
ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex"
|
||||||
|
version = "2022.4.24"
|
||||||
|
description = "Alternative regular expression module, to replace re."
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.6"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "requests"
|
name = "requests"
|
||||||
version = "2.27.1"
|
version = "2.27.1"
|
||||||
@ -792,6 +868,18 @@ category = "main"
|
|||||||
optional = false
|
optional = false
|
||||||
python-versions = "*"
|
python-versions = "*"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokenizers"
|
||||||
|
version = "0.12.1"
|
||||||
|
description = "Fast and Customizable Tokenizers"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
docs = ["sphinx", "sphinx-rtd-theme", "setuptools-rust"]
|
||||||
|
testing = ["pytest", "requests", "numpy", "datasets"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "toml"
|
name = "toml"
|
||||||
version = "0.10.2"
|
version = "0.10.2"
|
||||||
@ -808,6 +896,17 @@ category = "dev"
|
|||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "torch"
|
||||||
|
version = "1.11.0"
|
||||||
|
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7.0"
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
typing-extensions = "*"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tqdm"
|
name = "tqdm"
|
||||||
version = "4.64.0"
|
version = "4.64.0"
|
||||||
@ -825,11 +924,71 @@ notebook = ["ipywidgets (>=6)"]
|
|||||||
slack = ["slack-sdk"]
|
slack = ["slack-sdk"]
|
||||||
telegram = ["requests"]
|
telegram = ["requests"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "transformers"
|
||||||
|
version = "4.19.2"
|
||||||
|
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7.0"
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
filelock = "*"
|
||||||
|
huggingface-hub = ">=0.1.0,<1.0"
|
||||||
|
numpy = ">=1.17"
|
||||||
|
packaging = ">=20.0"
|
||||||
|
pyyaml = ">=5.1"
|
||||||
|
regex = "!=2019.12.17"
|
||||||
|
requests = "*"
|
||||||
|
tokenizers = ">=0.11.1,<0.11.3 || >0.11.3,<0.13"
|
||||||
|
tqdm = ">=4.27"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
all = ["tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx", "torch (>=1.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.3.5)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer", "pillow", "optuna", "ray", "sigopt", "timm", "codecarbon (==1.2.0)"]
|
||||||
|
audio = ["librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
|
||||||
|
codecarbon = ["codecarbon (==1.2.0)"]
|
||||||
|
deepspeed = ["deepspeed (>=0.6.4)"]
|
||||||
|
deepspeed-testing = ["deepspeed (>=0.6.4)", "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black (>=22.0,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)", "optuna"]
|
||||||
|
dev = ["tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx", "torch (>=1.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.3.5)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer", "pillow", "optuna", "ray", "sigopt", "timm", "codecarbon (==1.2.0)", "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black (>=22.0,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "hf-doc-builder", "scikit-learn"]
|
||||||
|
dev-tensorflow = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black (>=22.0,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)", "tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "pillow", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "hf-doc-builder", "scikit-learn", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
|
||||||
|
dev-torch = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black (>=22.0,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)", "torch (>=1.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer", "pillow", "optuna", "ray", "sigopt", "timm", "codecarbon (==1.2.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "hf-doc-builder", "scikit-learn", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
|
||||||
|
docs = ["tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx", "torch (>=1.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.3.5)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer", "pillow", "optuna", "ray", "sigopt", "timm", "codecarbon (==1.2.0)", "hf-doc-builder"]
|
||||||
|
docs_specific = ["hf-doc-builder"]
|
||||||
|
fairscale = ["fairscale (>0.3)"]
|
||||||
|
flax = ["jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.3.5)", "optax (>=0.0.8)"]
|
||||||
|
flax-speech = ["librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
|
||||||
|
ftfy = ["ftfy"]
|
||||||
|
integrations = ["optuna", "ray", "sigopt"]
|
||||||
|
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)"]
|
||||||
|
modelcreation = ["cookiecutter (==1.7.3)"]
|
||||||
|
onnx = ["onnxconverter-common", "tf2onnx", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
|
||||||
|
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
|
||||||
|
optuna = ["optuna"]
|
||||||
|
quality = ["black (>=22.0,<23.0)", "isort (>=5.5.4)", "flake8 (>=3.8.3)", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)"]
|
||||||
|
ray = ["ray"]
|
||||||
|
retrieval = ["faiss-cpu", "datasets"]
|
||||||
|
sagemaker = ["sagemaker (>=2.31.0)"]
|
||||||
|
sentencepiece = ["sentencepiece (>=0.1.91,!=0.1.92)", "protobuf"]
|
||||||
|
serving = ["pydantic", "uvicorn", "fastapi", "starlette"]
|
||||||
|
sigopt = ["sigopt"]
|
||||||
|
sklearn = ["scikit-learn"]
|
||||||
|
speech = ["torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
|
||||||
|
testing = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black (>=22.0,<23.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "sacremoses", "rjieba", "faiss-cpu", "cookiecutter (==1.7.3)"]
|
||||||
|
tf = ["tensorflow (>=2.3)", "onnxconverter-common", "tf2onnx"]
|
||||||
|
tf-cpu = ["tensorflow-cpu (>=2.3)", "onnxconverter-common", "tf2onnx"]
|
||||||
|
tf-speech = ["librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
|
||||||
|
timm = ["timm"]
|
||||||
|
tokenizers = ["tokenizers (>=0.11.1,!=0.11.3,<0.13)"]
|
||||||
|
torch = ["torch (>=1.0)"]
|
||||||
|
torch-speech = ["torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
|
||||||
|
torchhub = ["filelock", "huggingface-hub (>=0.1.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.0)", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "tqdm (>=4.27)"]
|
||||||
|
vision = ["pillow"]
|
||||||
|
|
||||||
 [[package]]
 name = "typing-extensions"
 version = "4.2.0"
 description = "Backported and Experimental Type Hints for Python 3.7+"
-category = "dev"
+category = "main"
 optional = false
 python-versions = ">=3.7"
 
@@ -864,6 +1023,17 @@ six = ">=1.9.0,<2"
 docs = ["proselint (>=0.10.2)", "sphinx (>=3)", "sphinx-argparse (>=0.2.5)", "sphinx-rtd-theme (>=0.4.3)", "towncrier (>=21.3)"]
 testing = ["coverage (>=4)", "coverage-enable-subprocess (>=1)", "flaky (>=3)", "pytest (>=4)", "pytest-env (>=0.6.2)", "pytest-freezegun (>=0.4.1)", "pytest-mock (>=2)", "pytest-randomly (>=1)", "pytest-timeout (>=1)", "packaging (>=20.0)"]
 
+[[package]]
+name = "werkzeug"
+version = "2.1.2"
+description = "The comprehensive WSGI web application library."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.extras]
+watchdog = ["watchdog"]
+
 [[package]]
 name = "wrapt"
 version = "1.14.1"
@@ -872,10 +1042,22 @@ category = "main"
 optional = false
 python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
 
+[[package]]
+name = "zipp"
+version = "3.8.0"
+description = "Backport of pathlib-compatible object wrapper for zip files"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
+[package.extras]
+docs = ["sphinx", "jaraco.packaging (>=9)", "rst.linker (>=1.9)"]
+testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)"]
+
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.8"
-content-hash = "cfa01a852b186ab6ddd1b610dba12c7d4ac6f10c7bdb10ca7bac6a0d25568d00"
+content-hash = "b0f92eed3a0f80cc9976f68098b772a19d19b4742426299bd63cb1c636d8a84a"
 
 [metadata.files]
 alabaster = [
@@ -1026,6 +1208,14 @@ flake8-polyfill = [
     {file = "flake8-polyfill-1.0.2.tar.gz", hash = "sha256:e44b087597f6da52ec6393a709e7108b2905317d0c0b744cdca6208e670d8eda"},
     {file = "flake8_polyfill-1.0.2-py2.py3-none-any.whl", hash = "sha256:12be6a34ee3ab795b19ca73505e7b55826d5f6ad7230d31b18e106400169b9e9"},
 ]
+flask = [
+    {file = "Flask-2.1.2-py3-none-any.whl", hash = "sha256:fad5b446feb0d6db6aec0c3184d16a8c1f6c3e464b511649c8918a9be100b4fe"},
+    {file = "Flask-2.1.2.tar.gz", hash = "sha256:315ded2ddf8a6281567edb27393010fe3406188bafbfe65a3339d5787d89e477"},
+]
+huggingface-hub = [
+    {file = "huggingface_hub-0.7.0-py3-none-any.whl", hash = "sha256:fd448fd0b738d803411c79bdf9f12f0ba171fecd24a59edf88c1391b473bc2c0"},
+    {file = "huggingface_hub-0.7.0.tar.gz", hash = "sha256:8154dc2fad84b32a4bca18372a647d9381ed8550a80b11050758357b8fcea639"},
+]
 identify = [
     {file = "identify-2.5.1-py2.py3-none-any.whl", hash = "sha256:0dca2ea3e4381c435ef9c33ba100a78a9b40c0bab11189c7cf121f75815efeaa"},
     {file = "identify-2.5.1.tar.gz", hash = "sha256:3d11b16f3fe19f52039fb7e39c9c884b21cb1b586988114fbe42671f03de3e82"},
@@ -1038,6 +1228,10 @@ imagesize = [
     {file = "imagesize-1.3.0-py2.py3-none-any.whl", hash = "sha256:1db2f82529e53c3e929e8926a1fa9235aa82d0bd0c580359c67ec31b2fddaa8c"},
     {file = "imagesize-1.3.0.tar.gz", hash = "sha256:cd1750d452385ca327479d45b64d9c7729ecf0b3969a58148298c77092261f9d"},
 ]
+importlib-metadata = [
+    {file = "importlib_metadata-4.11.4-py3-none-any.whl", hash = "sha256:c58c8eb8a762858f49e18436ff552e83914778e50e9d2f1660535ffb364552ec"},
+    {file = "importlib_metadata-4.11.4.tar.gz", hash = "sha256:5d26852efe48c0a32b0509ffbc583fda1a2266545a78d104a6f4aff3db17d700"},
+]
 iniconfig = [
     {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
     {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
@@ -1046,6 +1240,10 @@ isort = [
     {file = "isort-5.10.1-py3-none-any.whl", hash = "sha256:6f62d78e2f89b4500b080fe3a81690850cd254227f27f75c3a0c491a1f351ba7"},
     {file = "isort-5.10.1.tar.gz", hash = "sha256:e8443a5e7a020e9d7f97f1d7d9cd17c88bcb3bc7e218bf9cf5095fe550be2951"},
 ]
+itsdangerous = [
+    {file = "itsdangerous-2.1.2-py3-none-any.whl", hash = "sha256:2c2349112351b88699d8d4b6b075022c0808887cb7ad10069318a8b0bc88db44"},
+    {file = "itsdangerous-2.1.2.tar.gz", hash = "sha256:5dbbc68b317e5e42f327f9021763545dc3fc3bfe22e6deb96aaf1fc38874156a"},
+]
 jinja2 = [
     {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"},
     {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"},
@@ -1298,6 +1496,82 @@ redis = [
     {file = "redis-4.3.1-py3-none-any.whl", hash = "sha256:84316970995a7adb907a56754d2b92d88fc2d252963dc5ac34c88f0f1a22c25d"},
     {file = "redis-4.3.1.tar.gz", hash = "sha256:94b617b4cd296e94991146f66fc5559756fbefe9493604f0312e4d3298ac63e9"},
 ]
+regex = [
+    {file = "regex-2022.4.24-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f86aef546add4ff1202e1f31e9bb54f9268f17d996b2428877283146bf9bc013"},
+    {file = "regex-2022.4.24-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e944268445b5694f5d41292c9228f0ca46d5a32a67f195d5f8547c1f1d91f4bc"},
+    {file = "regex-2022.4.24-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8da3145f4b72f7ce6181c804eaa44cdcea313c8998cdade3d9e20a8717a9cb"},
+    {file = "regex-2022.4.24-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0fd464e547dbabf4652ca5fe9d88d75ec30182981e737c07b3410235a44b9939"},
+    {file = "regex-2022.4.24-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:071bcb625e890f28b7c4573124a6512ea65107152b1d3ca101ce33a52dad4593"},
+    {file = "regex-2022.4.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c2de7f32fa87d04d40f54bce3843af430697aba51c3a114aa62837a0772f219"},
+    {file = "regex-2022.4.24-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a07e8366115069f26822c47732122ab61598830a69f5629a37ea8881487c107"},
+    {file = "regex-2022.4.24-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:036d1c1fbe69eba3ee253c107e71749cdbb4776db93d674bc0d5e28f30300734"},
+    {file = "regex-2022.4.24-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:af1e687ffab18a75409e5e5d6215b6ccd41a5a1a0ea6ce9665e01253f737a0d3"},
+    {file = "regex-2022.4.24-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:165cc75cfa5aa0f12adb2ac6286330e7229a06dc0e6c004ec35da682b5b89579"},
+    {file = "regex-2022.4.24-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:3e35c50b27f36176c792738cb9b858523053bc495044d2c2b44db24376b266f1"},
+    {file = "regex-2022.4.24-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:43ee0df35925ae4b0cc6ee3f60b73369e559dd2ac40945044da9394dd9d3a51d"},
+    {file = "regex-2022.4.24-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:58521abdab76583bd41ef47e5e2ddd93b32501aee4ee8cee71dee10a45ba46b1"},
+    {file = "regex-2022.4.24-cp310-cp310-win32.whl", hash = "sha256:275afc7352982ee947fc88f67a034b52c78395977b5fc7c9be15f7dc95b76f06"},
+    {file = "regex-2022.4.24-cp310-cp310-win_amd64.whl", hash = "sha256:253f858a0255cd91a0424a4b15c2eedb12f20274f85731b0d861c8137e843065"},
+    {file = "regex-2022.4.24-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:85b7ee4d0c7a46296d884f6b489af8b960c4291d76aea4b22fd4fbe05e6ec08e"},
+    {file = "regex-2022.4.24-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e0da7ef160d4f3eb3d4d3e39a02c3c42f7dbcfce62c81f784cc99fc7059765f"},
+    {file = "regex-2022.4.24-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4f2e2cef324ca9355049ee1e712f68e2e92716eba24275e6767b9bfa15f1f478"},
+    {file = "regex-2022.4.24-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6165e737acb3bea3271372e8aa5ebe7226c8a8e8da1b94af2d6547c5a09d689d"},
+    {file = "regex-2022.4.24-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f6bd8178cce5bb56336722d5569d19c50bba5915a69a2050c497fb921e7cb0f"},
+    {file = "regex-2022.4.24-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:45b761406777a681db0c24686178532134c937d24448d9e085279b69e9eb7da4"},
+    {file = "regex-2022.4.24-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3dfbadb7b74d95f72f9f9dbf9778f7de92722ab520a109ceaf7927461fa85b10"},
+    {file = "regex-2022.4.24-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:9913bcf730eb6e9b441fb176832eea9acbebab6035542c7c89d90c803f5cd3be"},
+    {file = "regex-2022.4.24-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:68aed3fb0c61296bd6d234f558f78c51671f79ccb069cbcd428c2eea6fee7a5b"},
+    {file = "regex-2022.4.24-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:8e7d33f93cdd01868327d834d0f5bb029241cd293b47d51b96814dec27fc9b4b"},
+    {file = "regex-2022.4.24-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:82b7fc67e49fdce671bdbec1127189fc979badf062ce6e79dc95ef5e07a8bf92"},
+    {file = "regex-2022.4.24-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:c36906a7855ec33a9083608e6cd595e4729dab18aeb9aad0dd0b039240266239"},
+    {file = "regex-2022.4.24-cp36-cp36m-win32.whl", hash = "sha256:b2df3ede85d778c949d9bd2a50237072cee3df0a423c91f5514f78f8035bde87"},
+    {file = "regex-2022.4.24-cp36-cp36m-win_amd64.whl", hash = "sha256:dffd9114ade73137ab2b79a8faf864683dbd2dbbb6b23a305fbbd4cbaeeb2187"},
+    {file = "regex-2022.4.24-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6a0ef57cccd8089b4249eebad95065390e56c04d4a92c51316eab4131bca96a9"},
+    {file = "regex-2022.4.24-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12af15b6edb00e425f713160cfd361126e624ec0de86e74f7cad4b97b7f169b3"},
+    {file = "regex-2022.4.24-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7f271d0831d8ebc56e17b37f9fa1824b0379221d1238ae77c18a6e8c47f1fdce"},
+    {file = "regex-2022.4.24-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:37903d5ca11fa47577e8952d2e2c6de28553b11c70defee827afb941ab2c6729"},
+    {file = "regex-2022.4.24-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b747cef8e5dcdaf394192d43a0c02f5825aeb0ecd3d43e63ae500332ab830b0"},
+    {file = "regex-2022.4.24-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:582ea06079a03750b5f71e20a87cd99e646d796638b5894ff85987ebf5e04924"},
+    {file = "regex-2022.4.24-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aa6daa189db9104787ff1fd7a7623ce017077aa59eaac609d0d25ba95ed251a0"},
+    {file = "regex-2022.4.24-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7dbc96419ef0fb6ac56626014e6d3a345aeb8b17a3df8830235a88626ffc8d84"},
+    {file = "regex-2022.4.24-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:0fb6cb16518ac7eff29d1e0b0cce90275dfae0f17154165491058c31d58bdd1d"},
+    {file = "regex-2022.4.24-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bea61de0c688198e3d9479344228c7accaa22a78b58ec408e41750ebafee6c08"},
+    {file = "regex-2022.4.24-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:46cbc5b23f85e94161b093dba1b49035697cf44c7db3c930adabfc0e6d861b95"},
+    {file = "regex-2022.4.24-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:50b77622016f03989cd06ecf6b602c7a6b4ed2e3ce04133876b041d109c934ee"},
+    {file = "regex-2022.4.24-cp37-cp37m-win32.whl", hash = "sha256:2bde99f2cdfd6db1ec7e02d68cadd384ffe7413831373ea7cc68c5415a0cb577"},
+    {file = "regex-2022.4.24-cp37-cp37m-win_amd64.whl", hash = "sha256:66fb765b2173d90389384708e3e1d3e4be1148bd8d4d50476b1469da5a2f0229"},
+    {file = "regex-2022.4.24-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:709396c0c95b95045fac89b94f997410ff39b81a09863fe21002f390d48cc7d3"},
+    {file = "regex-2022.4.24-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7a608022f4593fc67518c6c599ae5abdb03bb8acd75993c82cd7a4c8100eff81"},
+    {file = "regex-2022.4.24-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb7107faf0168de087f62a2f2ed00f9e9da12e0b801582b516ddac236b871cda"},
+    {file = "regex-2022.4.24-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aabc28f7599f781ddaeac168d0b566d0db82182cc3dcf62129f0a4fc2927b811"},
+    {file = "regex-2022.4.24-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92ad03f928675ca05b79d3b1d3dfc149e2226d57ed9d57808f82105d511d0212"},
+    {file = "regex-2022.4.24-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7ba3c304a4a5d8112dbd30df8b3e4ef59b4b07807957d3c410d9713abaee9a8"},
+    {file = "regex-2022.4.24-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2acf5c66fbb62b5fe4c40978ddebafa50818f00bf79d60569d9762f6356336e"},
+    {file = "regex-2022.4.24-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7c4d9770e579eb11b582b2e2fd19fa204a15cb1589ae73cd4dcbb63b64f3e828"},
+    {file = "regex-2022.4.24-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:02543d6d5c32d361b7cc468079ba4cddaaf4a6544f655901ba1ff9d8e3f18755"},
+    {file = "regex-2022.4.24-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:73ed1b06abadbf6b61f6033a07c06f36ec0ddca117e41ef2ac37056705e46458"},
+    {file = "regex-2022.4.24-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3241db067a7f69da57fba8bca543ac8a7ca415d91e77315690202749b9fdaba1"},
+    {file = "regex-2022.4.24-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:d128e278e5e554c5c022c7bed410ca851e00bacebbb4460de546a73bc53f8de4"},
+    {file = "regex-2022.4.24-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b1d53835922cd0f9b74b2742453a444865a70abae38d12eb41c59271da66f38d"},
+    {file = "regex-2022.4.24-cp38-cp38-win32.whl", hash = "sha256:f2a5d9f612091812dee18375a45d046526452142e7b78c4e21ab192db15453d5"},
+    {file = "regex-2022.4.24-cp38-cp38-win_amd64.whl", hash = "sha256:a850f5f369f1e3b6239da7fb43d1d029c1e178263df671819889c47caf7e4ff3"},
+    {file = "regex-2022.4.24-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bedb3d01ad35ea1745bdb1d57f3ee0f996f988c98f5bbae9d068c3bb3065d210"},
+    {file = "regex-2022.4.24-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8bf867ba71856414a482e4b683500f946c300c4896e472e51d3db8dfa8dc8f32"},
+    {file = "regex-2022.4.24-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b415b82e5be7389ec5ee7ee35431e4a549ea327caacf73b697c6b3538cb5c87f"},
+    {file = "regex-2022.4.24-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9dae5affbb66178dad6c6fd5b02221ca9917e016c75ee3945e9a9563eb1fbb6f"},
+    {file = "regex-2022.4.24-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e65580ae3137bce712f505ec7c2d700aef0014a3878c4767b74aff5895fc454f"},
+    {file = "regex-2022.4.24-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e9e983fc8e0d4d5ded7caa5aed39ca2cf6026d7e39801ef6f0af0b1b6cd9276"},
+    {file = "regex-2022.4.24-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cfad3a770839aa456ff9a9aa0e253d98b628d005a3ccb37da1ff9be7c84fee16"},
+    {file = "regex-2022.4.24-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ed625205f5f26984382b68e4cbcbc08e6603c9e84c14b38457170b0cc71c823b"},
+    {file = "regex-2022.4.24-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c4fdf837666f7793a5c3cfa2f2f39f03eb6c7e92e831bc64486c2f547580c2b3"},
+    {file = "regex-2022.4.24-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ed26c3d2d62c6588e0dad175b8d8cc0942a638f32d07b80f92043e5d73b7db67"},
+    {file = "regex-2022.4.24-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:f89d26e50a4c7453cb8c415acd09e72fbade2610606a9c500a1e48c43210a42d"},
+    {file = "regex-2022.4.24-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:97af238389cb029d63d5f2d931a7e8f5954ad96e812de5faaed373b68e74df86"},
+    {file = "regex-2022.4.24-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:be392d9cd5309509175a9d7660dc17bf57084501108dbff0c5a8bfc3646048c3"},
+    {file = "regex-2022.4.24-cp39-cp39-win32.whl", hash = "sha256:bcc6f7a3a95119c3568c572ca167ada75f8319890706283b9ba59b3489c9bcb3"},
+    {file = "regex-2022.4.24-cp39-cp39-win_amd64.whl", hash = "sha256:5b9c7b6895a01204296e9523b3e12b43e013835a9de035a783907c2c1bc447f0"},
+    {file = "regex-2022.4.24.tar.gz", hash = "sha256:92183e9180c392371079262879c6532ccf55f808e6900df5d9f03c9ca8807255"},
+]
 requests = [
     {file = "requests-2.27.1-py2.py3-none-any.whl", hash = "sha256:f22fa1e554c9ddfd16e6e41ac79759e17be9e492b3587efa038054674760e72d"},
     {file = "requests-2.27.1.tar.gz", hash = "sha256:68d7c56fd5a8999887728ef304a6d12edc7be74f1cfa47714fc8b414525c9a61"},
@@ -1341,6 +1615,41 @@ sphinxcontrib-serializinghtml = [
 sqlitedict = [
     {file = "sqlitedict-2.0.0.tar.gz", hash = "sha256:23a370416f4e1e962daa293382f3a8dbc4127e6a0abc06a5d4e58e6902f05d17"},
 ]
+tokenizers = [
+    {file = "tokenizers-0.12.1-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:d737df0f8f26e093a82bfb106b6cfb510a0e9302d35834568e5b20b73ddc5a9c"},
+    {file = "tokenizers-0.12.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f1271224acafb27639c432e1ce4e7d38eab40305ba1c546e871d5c8a32f4f195"},
+    {file = "tokenizers-0.12.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cdeba37c2fb44e1aec8a72af4cb369655b59ba313181b1b4b8183f08e759c49c"},
+    {file = "tokenizers-0.12.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:53b5f4012ce3ffddd5b00827441b80dc7a0f6b41f4fc5248ae6d36e7d3920c6d"},
+    {file = "tokenizers-0.12.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5188e13fc09edfe05712ca3ae5a44e7f2b0137927b1ca210d0fad90d3e58315a"},
+    {file = "tokenizers-0.12.1-cp310-cp310-win32.whl", hash = "sha256:eff5ff411f18a201eec137b7b32fcb55e0c48b372d370bd24f965f5bad471fa4"},
+    {file = "tokenizers-0.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:bdbca79726fe883c696088ea163715b2f902aec638a8e24bcf9790ff8fa45019"},
+    {file = "tokenizers-0.12.1-cp36-cp36m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:28825dade9e52ad464164020758f9d49eb7251c32b6ae146601c506a23c67c0e"},
+    {file = "tokenizers-0.12.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91906d725cb84d8ee71ce05fbb155d39d494849622b4f9349e5176a8eb01c49b"},
+    {file = "tokenizers-0.12.1-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:230f51a0a82ca7b90077eaca2415f12ff9bd144607888b9c50c2ee543452322e"},
+    {file = "tokenizers-0.12.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d4339c376b695de2ad8ccaebffa75e4dc1d7857be1103d80e7925b34af8cf78"},
+    {file = "tokenizers-0.12.1-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:27d93b712aa2d4346aa506ecd4ec9e94edeebeaf2d484357b482cdeffc02b5f5"},
+    {file = "tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7f4cb68dc538b52240d1986d2034eb0a6373be2ab5f0787d1be3ad1444ce71b7"},
+    {file = "tokenizers-0.12.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ae6c04b629ac2cd2f695739988cb70b9bd8d5e7f849f5b14c4510e942bee5770"},
+    {file = "tokenizers-0.12.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6a38b2019d4807d42afeff603a119094ee00f63bea2921136524c8814e9003f8"},
+    {file = "tokenizers-0.12.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fde8dccb9033fa344ffce3ee1837939a50e7a210a768f1cf2059beeafa755481"},
+    {file = "tokenizers-0.12.1-cp37-cp37m-win32.whl", hash = "sha256:38625595b2fd37bfcce64ff9bfb6868c07e9a7b7f205c909d94a615ce9472287"},
+    {file = "tokenizers-0.12.1-cp37-cp37m-win_amd64.whl", hash = "sha256:01abe6fbfe55e4131ca0c4c3d1a9d7ef5df424a8d536e998d2a4fc0bc57935f4"},
+    {file = "tokenizers-0.12.1-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:7c5c54080a7d5c89c990e0d478e0882dbac88926d43323a3aa236492a3c9455f"},
+    {file = "tokenizers-0.12.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:419d113e3bcc4fe20a313afc47af81e62906306b08fe1601e1443d747d46af1f"},
+    {file = "tokenizers-0.12.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9779944559cb7ace6a8516e402895f239b0d9d3c833c67dbaec496310e7e206"},
+    {file = "tokenizers-0.12.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7d43de14b4469b57490dbaf136a31c266cb676fa22320f01f230af9219ae9034"},
+    {file = "tokenizers-0.12.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:258873634406bd1d438c799993a5e44bbc0132ff055985c03c4fe30f702e9a33"},
+    {file = "tokenizers-0.12.1-cp38-cp38-win32.whl", hash = "sha256:3f2647cc256d6a53d18b9dcd71d377828e9f8991fbcbd6fcd8ca2ceb174552b0"},
+    {file = "tokenizers-0.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:62a723bd4b18bc55121f5c34cd8efd6c651f2d3b81f81dd50e5351fb65b8a617"},
+    {file = "tokenizers-0.12.1-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:411ebc89228f30218ffa9d9c49d414864b0df5026a47c24820431821c4360460"},
+    {file = "tokenizers-0.12.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:619728df2551bdfe6f96ff177f9ded958e7ed9e2af94c8d5ac2834d1eb06d112"},
+    {file = "tokenizers-0.12.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8cea98f3f9577d1541b7bb0f7a3308a911751067e1d83e01485c9d3411bbf087"},
+    {file = "tokenizers-0.12.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:664f36f0a0d409c24f2201d495161fec4d8bc93e091fbb78814eb426f29905a3"},
+    {file = "tokenizers-0.12.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0bf2380ad59c50222959a9b6f231339200a826fc5cb2be09ff96d8a59f65fc5e"},
+    {file = "tokenizers-0.12.1-cp39-cp39-win32.whl", hash = "sha256:6a7a106d04154c2159db6cd7d042af2e2e0e53aee432f872fe6c8be45100436a"},
+    {file = "tokenizers-0.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:2158baf80cbc09259bfd6e0e0fc4597b611e7a72ad5443dad63918a90f1dd304"},
+    {file = "tokenizers-0.12.1.tar.gz", hash = "sha256:070746f86efa6c873db341e55cf17bb5e7bdd5450330ca8eca542f5c3dab2c66"},
+]
 toml = [
     {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
     {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
@@ -1349,10 +1658,35 @@ tomli = [
     {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
     {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
 ]
+torch = [
+    {file = "torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:62052b50fffc29ca7afc0c04ef8206b6f1ca9d10629cb543077e12967e8d0398"},
+    {file = "torch-1.11.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:866bfba29ac98dec35d893d8e17eaec149d0ac7a53be7baae5c98069897db667"},
+    {file = "torch-1.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:951640fb8db308a59d9b510e7d1ad910aff92913323bbe4bc75435347ddd346d"},
+    {file = "torch-1.11.0-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:5d77b5ece78fdafa5c7f42995ff9474399d22571cd6b2de21a5d666306a2ff8c"},
+    {file = "torch-1.11.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:b5a38682769b544c875ecc34bcb81fbad5c922139b61319aacffcfd8a32f528c"},
+    {file = "torch-1.11.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:f82d77695a60626f2b7382d85bc566de8a6b3e50d32080755abc040db802e419"},
+    {file = "torch-1.11.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:b96654d42566080a134e784705f33f8536b3b95b5dcde357ed7879b1692a5f78"},
+    {file = "torch-1.11.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8ee7c2e8d7f7020d5bfbc1bb91b9591044c26bbd0cee5e4f694cfd7ed8649260"},
+    {file = "torch-1.11.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:6860b1d1bf0bb0b67a6bd47f85a0e4c825b518eea13b5d6101999dbbcbd5bc0c"},
+    {file = "torch-1.11.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:4322aa29f50da7f404db06cdf30896ea67b09f673af4a985afc7162bc897864d"},
+    {file = "torch-1.11.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e4d2e0ddd652f30e94cff750220324ec45705d4ecc69658f773b3cb1c7a28dd0"},
+    {file = "torch-1.11.0-cp38-cp38-win_amd64.whl", hash = "sha256:34ce5ea4d8d85da32cdbadb50d4585106901e9f8a3527991daa70c13a09de1f7"},
+    {file = "torch-1.11.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:0ccc85cd06227a3edf809e2c795fd5762c3d4e8a38b5c9f744c6e7cf841361bb"},
+    {file = "torch-1.11.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:c1554e49d74f1b2c3e7202d77056ba2dd7465437585bac64062b580f714a44e9"},
+    {file = "torch-1.11.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:58c7814502b1c129a650d7092033bbb0bbd64faf1a7941631aaa1aeaddc37570"},
+    {file = "torch-1.11.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:831cf588f01dda9409e75576741d2823453990dee2983d670f2584b37a01adf7"},
+    {file = "torch-1.11.0-cp39-cp39-win_amd64.whl", hash = "sha256:44a1d02fd20f827f0f36dc26fdcfc45e793806a6ad52769a22260655a77a4369"},
+    {file = "torch-1.11.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:50fd9bf85c578c871c28f1cb0ace9dfc6024401c7f399b174fb0f370899f4454"},
+    {file = "torch-1.11.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:0e48af66ad755f0f9c5f2664028a414f57c49d6adc37e77e06fe0004da4edb61"},
+]
 tqdm = [
     {file = "tqdm-4.64.0-py2.py3-none-any.whl", hash = "sha256:74a2cdefe14d11442cedf3ba4e21a3b84ff9a2dbdc6cfae2c34addb2a14a5ea6"},
     {file = "tqdm-4.64.0.tar.gz", hash = "sha256:40be55d30e200777a307a7585aee69e4eabb46b4ec6a4b4a5f2d9f11e7d5408d"},
 ]
+transformers = [
+    {file = "transformers-4.19.2-py3-none-any.whl", hash = "sha256:1416315b7c5ff1f56d3915f416b67aa254a9907fbb73ef7f7bffc9210446b5fa"},
+    {file = "transformers-4.19.2.tar.gz", hash = "sha256:e19a4ff07458eda143c738e5259caf48449fcf078a63d6b1bd1aa806543440a3"},
+]
 typing-extensions = [
     {file = "typing_extensions-4.2.0-py3-none-any.whl", hash = "sha256:6657594ee297170d19f67d55c05852a874e7eb634f4f753dbd667855e07c1708"},
     {file = "typing_extensions-4.2.0.tar.gz", hash = "sha256:f1c24655a0da0d1b67f07e17a5e6b2a105894e6824b92096378bb3668ef02376"},
@@ -1365,6 +1699,10 @@ virtualenv = [
     {file = "virtualenv-20.14.1-py2.py3-none-any.whl", hash = "sha256:e617f16e25b42eb4f6e74096b9c9e37713cf10bf30168fb4a739f3fa8f898a3a"},
     {file = "virtualenv-20.14.1.tar.gz", hash = "sha256:ef589a79795589aada0c1c5b319486797c03b67ac3984c48c669c0e4f50df3a5"},
 ]
+werkzeug = [
+    {file = "Werkzeug-2.1.2-py3-none-any.whl", hash = "sha256:72a4b735692dd3135217911cbeaa1be5fa3f62bffb8745c5215420a03dc55255"},
+    {file = "Werkzeug-2.1.2.tar.gz", hash = "sha256:1ce08e8093ed67d638d63879fd1ba3735817f7a80de3674d293f5984f25fb6e6"},
+]
 wrapt = [
     {file = "wrapt-1.14.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:1b376b3f4896e7930f1f772ac4b064ac12598d1c38d04907e696cc4d794b43d3"},
     {file = "wrapt-1.14.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:903500616422a40a98a5a3c4ff4ed9d0066f3b4c951fa286018ecdf0750194ef"},
@@ -1431,3 +1769,7 @@ wrapt = [
     {file = "wrapt-1.14.1-cp39-cp39-win_amd64.whl", hash = "sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb"},
     {file = "wrapt-1.14.1.tar.gz", hash = "sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d"},
 ]
+zipp = [
+    {file = "zipp-3.8.0-py3-none-any.whl", hash = "sha256:c4f6e5bbf48e74f7a38e7cc5b0480ff42b0ae5178957d564d18932525d5cf099"},
+    {file = "zipp-3.8.0.tar.gz", hash = "sha256:56bf8aadb83c24db6c4b577e13de374ccfb67da2078beba1d037c17980bf43ad"},
+]
@@ -21,6 +21,10 @@ sqlitedict = "^2.0.0"
 openai = "^0.18.1"
 redis = "^4.3.1"
 dill = "^0.3.5"
+Flask = "^2.1.2"
+transformers = "^4.19.2"
+torch = "^1.11.0"
+requests = "^2.27.1"
 
 [tool.poetry.dev-dependencies]
 black = "^22.3.0"
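With Flask, transformers, torch, and requests joining redis as runtime dependencies, a session can now point its cache at Redis instead of SQLite. A minimal sketch only: the "redis" cache name is an assumption based on the new manifest.caches.redis module, and the "host:port" connection string mirrors what the redis_cache test fixture below yields.
```
from manifest import Manifest

# Sketch under stated assumptions; the dummy client avoids real API calls.
manifest = Manifest(
    client_name="dummy",
    cache_name="redis",               # assumed name for the new Redis cache backend
    cache_connection="localhost:6379",  # assumed "host:port" shape, as in the tests
)
```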
@@ -1 +0,0 @@
-"""Test client."""
@@ -1,59 +0,0 @@
-"""Response test."""
-import json
-
-import pytest
-
-from manifest.clients import Response
-
-
-def test_init():
-    """Test response initialization."""
-    with pytest.raises(ValueError) as exc_info:
-        response = Response(4)
-    assert str(exc_info.value) == "Response must be str or dict"
-    with pytest.raises(ValueError) as exc_info:
-        response = Response({"test": "hello"})
-    assert (
-        str(exc_info.value)
-        == "Response must be serialized to a dict with a list of choices"
-    )
-    with pytest.raises(ValueError) as exc_info:
-        response = Response({"choices": [{"blah": "hello"}]})
-    assert str(exc_info.value) == (
-        "Response must be serialized to a dict "
-        "with a list of choices with text field"
-    )
-
-    response = Response({"choices": [{"text": "hello"}]})
-    assert response.response == {"choices": [{"text": "hello"}]}
-
-    response = Response(json.dumps({"choices": [{"text": "hello"}]}))
-    assert response.response == {"choices": [{"text": "hello"}]}
-
-
-def test_getitem():
-    """Test response getitem."""
-    response = Response({"choices": [{"text": "hello"}]})
-    assert response["choices"] == [{"text": "hello"}]
-
-
-def test_serialize():
-    """Test response serialization."""
-    response = Response({"choices": [{"text": "hello"}]})
-    assert Response.deserialize(response.serialize()).response == {
-        "choices": [{"text": "hello"}]
-    }
-
-
-def test_get_results():
-    """Test response get results."""
-    response = Response({"choices": []})
-    assert response.get_results() is None
-
-    response = Response({"choices": [{"text": "hello"}]})
-    assert response.get_results() == "hello"
-
-    response = Response(
-        {"choices": [{"text": "hello"}, {"text": "my"}, {"text": "name"}]}
-    )
-    assert response.get_results() == ["hello", "my", "name"]
@@ -3,6 +3,7 @@ import os
 import shutil
 
 import pytest
+import redis
 
 
 @pytest.fixture
@@ -32,5 +33,5 @@ def redis_cache():
     port = os.environ.get("REDIS_PORT", 6379)
     yield f"{host}:{port}"
     # Clear out the database
-    # db = redis.Redis(host=host, port=port)
-    # db.flushdb()
+    db = redis.Redis(host=host, port=port)
+    db.flushdb()
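The fixture now actually flushes the Redis database after each test instead of leaving the cleanup commented out. A sketch of how a test might lean on that, with a hypothetical test function name and using only the redis-py calls that appear in the fixture:
```
import redis

def test_roundtrip(redis_cache):  # hypothetical consumer of the fixture
    host, port = redis_cache.split(":")
    db = redis.Redis(host=host, port=int(port))
    db.set(b"k", b"v")
    assert db.get(b"k") == b"v"
    # the fixture's flushdb() wipes this key once the test finishes
```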
@@ -1,28 +1,28 @@
 """Cache test."""
 import pytest
+from redis import Redis
 from sqlitedict import SqliteDict
 
 from manifest.caches.redis import RedisCache
 from manifest.caches.sqlite import SQLiteCache
-from manifest.clients import Response
 
 
 @pytest.mark.usefixtures("sqlite_cache")
 @pytest.mark.usefixtures("redis_cache")
-@pytest.mark.parametrize("cache_type", ["sqlite"])
+@pytest.mark.parametrize("cache_type", ["sqlite", "redis"])
 def test_init(sqlite_cache, redis_cache, cache_type):
     """Test cache initialization."""
     if cache_type == "sqlite":
         cache = SQLiteCache(sqlite_cache)
         assert isinstance(cache.cache, SqliteDict)
-        assert isinstance(cache.prompt_cache, SqliteDict)
     else:
         cache = RedisCache(redis_cache)
+        assert isinstance(cache.redis, Redis)
 
 
 @pytest.mark.usefixtures("sqlite_cache")
 @pytest.mark.usefixtures("redis_cache")
-@pytest.mark.parametrize("cache_type", ["sqlite"])
+@pytest.mark.parametrize("cache_type", ["sqlite", "redis"])
 def test_key_get_and_set(sqlite_cache, redis_cache, cache_type):
     """Test cache key get and set."""
     if cache_type == "sqlite":
@@ -32,7 +32,6 @@ def test_key_get_and_set(sqlite_cache, redis_cache, cache_type):
 
     cache.set_key("test", "valueA")
     cache.set_key("testA", "valueB")
-
     assert cache.get_key("test") == "valueA"
     assert cache.get_key("testA") == "valueB"
 
@@ -46,7 +45,7 @@
 
 @pytest.mark.usefixtures("sqlite_cache")
 @pytest.mark.usefixtures("redis_cache")
-@pytest.mark.parametrize("cache_type", ["sqlite"])
+@pytest.mark.parametrize("cache_type", ["sqlite", "redis"])
 def test_get(sqlite_cache, redis_cache, cache_type):
     """Test cache save prompt."""
     if cache_type == "sqlite":
@@ -54,16 +53,19 @@ def test_get(sqlite_cache, redis_cache, cache_type):
     else:
         cache = RedisCache(redis_cache)
     test_request = {"test": "hello", "testA": "world"}
-    compute = lambda: Response({"choices": [{"text": "hello"}]})
+    compute = lambda: {"choices": [{"text": "hello"}]}
 
-    response, cached = cache.get(test_request, overwrite_cache=False, compute=compute)
-    assert response.get_results() == "hello"
-    assert not cached
+    response = cache.get(test_request, overwrite_cache=False, compute=compute)
+    assert response.get_response() == "hello"
+    assert not response.is_cached()
+    assert response.get_request() == test_request
 
-    response, cached = cache.get(test_request, overwrite_cache=False, compute=compute)
-    assert response.get_results() == "hello"
-    assert cached
+    response = cache.get(test_request, overwrite_cache=False, compute=compute)
+    assert response.get_response() == "hello"
+    assert response.is_cached()
+    assert response.get_request() == test_request
 
-    response, cached = cache.get(test_request, overwrite_cache=True, compute=compute)
-    assert response.get_results() == "hello"
-    assert not cached
+    response = cache.get(test_request, overwrite_cache=True, compute=compute)
+    assert response.get_response() == "hello"
+    assert not response.is_cached()
+    assert response.get_request() == test_request
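The reworked contract is visible in the assertions above: cache.get no longer returns a (response, cached) tuple but a single Response object that knows whether it was a cache hit and what request produced it. A sketch of that flow, assuming a reachable Redis at the hypothetical address localhost:6379:
```
from manifest.caches.redis import RedisCache

cache = RedisCache("localhost:6379")  # assumed "host:port" connection string
request = {"test": "hello", "testA": "world"}
compute = lambda: {"choices": [{"text": "hello"}]}  # raw dict now, not a Response

first = cache.get(request, overwrite_cache=False, compute=compute)
assert first.get_response() == "hello" and not first.is_cached()

second = cache.get(request, overwrite_cache=False, compute=compute)
assert second.is_cached()  # served from Redis on the repeat query
```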
tests/test_client.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+"""
+Test client.
+
+We just test the dummy client as we don't want to load a model or use OpenAI tokens.
+"""
+
+from manifest.clients.dummy import DummyClient
+
+
+def test_init():
+    """Test client initialization."""
+    client = DummyClient(connection_str=None, num_results=3)
+    assert client.num_results == 3
+
+
+def test_get_request():
+    """Test client get request."""
+    client = DummyClient(connection_str=None, num_results=3)
+    request_func, request_params = client.get_request("hello")
+    assert request_params == {"prompt": "hello", "num_results": 3}
+    assert request_func() == {"choices": [{"text": "hello"}] * 3}
@@ -1,7 +1,7 @@
 """Manifest test."""
 import pytest
 
-from manifest import Manifest, Prompt
+from manifest import Manifest, Prompt, Response
 from manifest.caches.cache import request_to_key
 from manifest.caches.sqlite import SQLiteCache
 from manifest.clients.dummy import DummyClient
@@ -18,22 +18,27 @@ def test_init(sqlite_cache):
     assert manifest.client_name == "dummy"
     assert isinstance(manifest.client, DummyClient)
     assert isinstance(manifest.cache, SQLiteCache)
+    assert manifest.client.num_results == 1
+    assert manifest.stop_token == ""
 
     manifest = Manifest(
         client_name="dummy",
         cache_name="sqlite",
         cache_connection=sqlite_cache,
         num_results=3,
+        stop_token="\n",
     )
     assert manifest.client_name == "dummy"
     assert isinstance(manifest.client, DummyClient)
     assert isinstance(manifest.cache, SQLiteCache)
     assert manifest.client.num_results == 3
+    assert manifest.stop_token == "\n"
 
 
 @pytest.mark.usefixtures("sqlite_cache")
 @pytest.mark.parametrize("num_results", [1, 2])
-def test_run(sqlite_cache, num_results):
+@pytest.mark.parametrize("return_response", [True, False])
+def test_run(sqlite_cache, num_results, return_response):
     """Test manifest run."""
     manifest = Manifest(
         client_name="dummy",
@@ -42,7 +47,12 @@
         num_results=num_results,
     )
     prompt = Prompt("This is a prompt")
-    res = manifest.run(prompt)
+    result = manifest.run(prompt, return_response=return_response)
+    if return_response:
+        assert isinstance(result, Response)
+        res = result.get_response(manifest.stop_token)
+    else:
+        res = result
     assert (
         manifest.cache.get_key(
             request_to_key(
@@ -61,7 +71,12 @@
         assert res == ["hello", "hello"]
 
     prompt = Prompt(lambda x: f"{x} is a prompt")
-    res = manifest.run(prompt, "Hello")
+    result = manifest.run(prompt, "Hello", return_response=return_response)
+    if return_response:
+        assert isinstance(result, Response)
+        res = result.get_response(manifest.stop_token)
+    else:
+        res = result
     assert (
         manifest.cache.get_key(
             request_to_key(
@@ -79,10 +94,37 @@
     else:
         assert res == ["hello", "hello"]
 
+    prompt = Prompt(lambda x: f"{x} is a prompt")
+    result = manifest.run(
+        prompt, "Hello", stop_token="ll", return_response=return_response
+    )
+    if return_response:
+        assert isinstance(result, Response)
+        res = result.get_response(stop_token="ll")
+    else:
+        res = result
+    assert (
+        manifest.cache.get_key(
+            request_to_key(
+                {
+                    "prompt": "Hello is a prompt",
+                    "client_name": "dummy",
+                    "num_results": num_results,
+                }
+            )
+        )
+        is not None
+    )
+    if num_results == 1:
+        assert res == "he"
+    else:
+        assert res == ["he", "he"]
+
 
 @pytest.mark.usefixtures("sqlite_cache")
 @pytest.mark.parametrize("num_results", [1, 2])
-def test_batch_run(sqlite_cache, num_results):
+@pytest.mark.parametrize("return_response", [True, False])
+def test_batch_run(sqlite_cache, num_results, return_response):
     """Test manifest run."""
     manifest = Manifest(
         client_name="dummy",
@@ -91,15 +133,38 @@
         num_results=num_results,
    )
     prompt = Prompt("This is a prompt")
-    res = manifest.run_batch(prompt)
+    result = manifest.run_batch(prompt, return_response=return_response)
+    if return_response:
+        res = [r.get_response(manifest.stop_token) for r in result]
+    else:
+        res = result
     if num_results == 1:
         assert res == ["hello"]
     else:
         assert res == [["hello", "hello"]]
 
     prompt = Prompt(lambda x: f"{x} is a prompt")
-    res = manifest.run_batch(prompt, ["Hello", "Hello"])
+    result = manifest.run_batch(
+        prompt, ["Hello", "Hello"], return_response=return_response
+    )
+    if return_response:
+        res = [r.get_response(manifest.stop_token) for r in result]
+    else:
+        res = result
     if num_results == 1:
         assert res == ["hello", "hello"]
     else:
         assert res == [["hello", "hello"], ["hello", "hello"]]
+
+    prompt = Prompt(lambda x: f"{x} is a prompt")
+    result = manifest.run_batch(
+        prompt, ["Hello", "Hello"], stop_token="ll", return_response=return_response
+    )
+    if return_response:
+        res = [r.get_response(stop_token="ll") for r in result]
+    else:
+        res = result
+    if num_results == 1:
+        assert res == ["he", "he"]
+    else:
+        assert res == [["he", "he"], ["he", "he"]]
tests/test_response.py (new file, 74 lines)
@@ -0,0 +1,74 @@
+"""Response test."""
+import pytest
+
+from manifest import Response
+
+
+def test_init():
+    """Test response initialization."""
+    with pytest.raises(ValueError) as exc_info:
+        response = Response(4, False, {})
+    assert str(exc_info.value) == "Response must be str or dict"
+    with pytest.raises(ValueError) as exc_info:
+        response = Response({"test": "hello"}, False, {})
+    assert (
+        str(exc_info.value)
+        == "Response must be serialized to a dict with a list of choices"
+    )
+    with pytest.raises(ValueError) as exc_info:
+        response = Response({"choices": [{"blah": "hello"}]}, False, {})
+    assert str(exc_info.value) == (
+        "Response must be serialized to a dict "
+        "with a list of choices with text field"
+    )
+
+    response = Response({"choices": [{"text": "hello"}]}, False, {})
+    assert response._response == {"choices": [{"text": "hello"}]}
+    assert response._cached is False
+    assert response._request_params == {}
+
+    response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
+    assert response._response == {"choices": [{"text": "hello"}]}
+    assert response._cached is True
+    assert response._request_params == {"request": "yoyo"}
+
+
+def test_getters():
+    """Test response cached."""
+    response = Response({"choices": [{"text": "hello"}]}, False, {})
+    assert response.get_raw_response() == {"choices": [{"text": "hello"}]}
+    assert response.is_cached() is False
+    assert response.get_request() == {}
+
+    response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
+    assert response.get_raw_response() == {"choices": [{"text": "hello"}]}
+    assert response.is_cached() is True
+    assert response.get_request() == {"request": "yoyo"}
+
+
+def test_serialize():
+    """Test response serialization."""
+    response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
+    deserialized_response = Response.deserialize(response.serialize())
+    assert deserialized_response._response == {"choices": [{"text": "hello"}]}
+    assert deserialized_response.is_cached() is True
+    assert deserialized_response._request_params == {"request": "yoyo"}
+
+
+def test_get_results():
+    """Test response get results."""
+    response = Response({"choices": []}, True, {"request": "yoyo"})
+    assert response.get_response() is None
+    assert response.get_response(stop_token="ll") is None
+
+    response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
+    assert response.get_response() == "hello"
+    assert response.get_response(stop_token="ll") == "he"
+
+    response = Response(
+        {"choices": [{"text": "hello"}, {"text": "my"}, {"text": "name"}]},
+        True,
+        {"request": "yoyo"},
+    )
+    assert response.get_response() == ["hello", "my", "name"]
+    assert response.get_response(stop_token="m") == ["hello", "", "na"]
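Taken together, the new tests pin down the user-facing flow: run() returns plain text by default, a full Response when return_response=True, and stop_token truncates each result at the first occurrence of the token. A short sketch of that surface against the dummy client, with a hypothetical local cache file; the cache-hit assertion assumes the second identical query is served from the cache, as in the cache tests above:
```
from manifest import Manifest, Prompt

manifest = Manifest(
    client_name="dummy",
    cache_name="sqlite",
    cache_connection="sqlite.cache",  # hypothetical local cache file
    stop_token="ll",
)
prompt = Prompt(lambda x: f"{x} is a prompt")
assert manifest.run(prompt, "Hello") == "he"  # "hello" truncated at "ll"

result = manifest.run(prompt, "Hello", return_response=True)
assert result.is_cached()                      # assumed: repeat query hits the cache
assert result.get_response(stop_token="ll") == "he"
```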