diff --git a/Makefile b/Makefile
index 358bc88..63d96c2 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,7 @@
 dev:
 	poetry install
 	poetry run pre-commit install
+	poetry run mypy --install-types

 test: dev check
 	poetry run pytest tests
diff --git a/README.md b/README.md
index f4e2f2d..3962c50 100644
--- a/README.md
+++ b/README.md
@@ -3,19 +3,19 @@ Prompt programming with FMs.
 # Install
 Download the code:
-```
+```bash
 git clone git@github.com:HazyResearch/manifest.git
 cd manifest
 ```
 Install:
-```
+```bash
 pip install poetry
 poetry install
 poetry run pre-commit install
 ```
 or
-```
+```bash
 pip install poetry
 make dev
 ```
@@ -28,14 +28,14 @@ Manifest is meant to be a very light weight package to help with prompt iteratio
 ## Prompts
 A Manifest prompt is a function that accepts a single input to generate a string prompt to send to a model.
-```
+```python
 from manifest import Prompt
 prompt = Prompt(lambda x: f"Hello, my name is {x}")
-print(promt("Laurel"))
+print(prompt("Laurel"))
 >>> "Hello, my name is Laurel"
 ```
 We also let you use static strings.
-```
+```python
 prompt = Prompt("Hello, my name is static")
 print(prompt())
 >>> "Hello, my name is static"
 ```
@@ -46,13 +46,13 @@ print(prompt())
 ## Sessions
 Each Manifest run is a session that connects to a model endpoint and a backend database to record prompt queries. To start a Manifest session for OpenAI, make sure you run
-```
+```bash
 export OPENAI_API_KEY=
 ```
 so we can access OpenAI.

 Then, in a notebook, run:
-```
+```python
 from manifest import Manifest

 manifest = Manifest(
@@ -64,7 +64,7 @@ manifest = Manifest(
 This will start a session with OpenAI and save all results to a local file called `sqlite.cache`.

 We also support a Redis backend. If you have a Redis database running on port 6379, run
-```
+```python
 manifest = Manifest(
     client_name = "openai",
     cache_name = "redis",
@@ -77,18 +77,18 @@ We will explain [below](#huggingface-models) how to use Manifest for a locally h
 Once you have a session open, you can write and develop prompts.
-```
+```python
 prompt = Prompt(lambda x: f"Hello, my name is {x}")
 result = manifest.run(prompt, "Laurel")
 ```

 You can also run over multiple examples.
-```
+```python
 results = manifest.batch_run(prompt, ["Laurel", "Avanika"])
 ```

 If something doesn't go right, you can also ask to get the raw manifest `Response`.
-```
+```python
 result_object = manifest.batch_run(prompt, ["Laurel", "Avanika"], return_response=True)
 print(result_object.get_request())
 print(result_object.is_cached())
@@ -96,24 +96,24 @@ print(result_object.get_response())
 ```

 By default, we do not truncate results based on a stop token. You can change this by either passing a new stop token to a Manifest session or to a `run` or `batch_run`. If you set the stop token to `""`, we will not truncate the model output.
-```
+```python
 result = manifest.run(prompt, "Laurel", stop_token="and")
 ```

 If you want to change a model's default parameters, pass them as `kwargs` to the client.
-```
+```python
 result = manifest.run(prompt, "Laurel", max_tokens=50)
 ```

 # Huggingface Models
 To use a HuggingFace generative model, in `manifest/api` we have a Flask application that hosts the models for you.

 In a separate terminal or Tmux/Screen session, run
-```
+```bash
 python3 manifest/api/app.py --model_type huggingface --model_name EleutherAI/gpt-j-6B --device 0
 ```

 You will see the Flask session start and output a URL `http://127.0.0.1:5000`. Pass this in to Manifest. If you want to use a different port, set the `FLASK_PORT` environment variable.
-```
+```python
 manifest = Manifest(
     client_name = "huggingface",
     client_connection = "http://127.0.0.1:5000",
@@ -122,11 +122,13 @@ manifest = Manifest(
 )
 ```

+If you have a custom model you trained, pass the model path to `--model_name`.
+
 **Auto deployment coming soon**

 # Development
 Before submitting a PR, run
-```
+```bash
 export REDIS_PORT="6380"  # or whatever port local Redis is running on for these tests
 cd
 docker run -d -p 127.0.0.1:${REDIS_PORT}:6380 -v `pwd`:`pwd` -w `pwd` --name manifest_redis_test redis
@@ -134,12 +136,12 @@ make test
 ```

 To use our development Redis database, email [Laurel](lorr1@cs.stanford.edu). If you have access to our GCP account, in a separate terminal, run
-```
+```bash
 gcloud compute ssh "manifest-connect" --zone "europe-west4-a" --project "hai-gcp-head-models" -- -N -L 6379:10.152.93.107:6379
 ```

 Then if you issue
-```
+```bash
 redis-cli ping
 ```
 You should see a `PONG` response from our database.
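Putting the README workflow above end to end: a minimal sketch, assuming the Flask model server from `manifest/api/app.py` is running on port 5000 and using only the `Prompt`/`Manifest` calls documented above.

```python
from manifest import Manifest, Prompt

# Connect to the locally hosted HuggingFace model started above,
# caching all results in a local SQLite file.
manifest = Manifest(
    client_name="huggingface",
    client_connection="http://127.0.0.1:5000",
    cache_name="sqlite",
    cache_connection="sqlite.cache",
)

prompt = Prompt(lambda x: f"Hello, my name is {x}")
print(manifest.run(prompt, "Laurel", max_tokens=20))
```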
diff --git a/manifest/api/app.py b/manifest/api/app.py
index a89ace8..1d5c87c 100644
--- a/manifest/api/app.py
+++ b/manifest/api/app.py
@@ -81,6 +81,12 @@ def completions() -> Dict:
     return OpenAIResponse(results).__dict__()


+@app.route("/params", methods=["POST"])
+def params() -> Dict:
+    """Get model params."""
+    return model.get_init_params()
+
+
 @app.route("/")
 def index() -> str:
     """Get index completion."""
diff --git a/manifest/api/models/huggingface.py b/manifest/api/models/huggingface.py
index 1939386..2169a31 100644
--- a/manifest/api/models/huggingface.py
+++ b/manifest/api/models/huggingface.py
@@ -1,7 +1,10 @@
 """Huggingface model."""
-from typing import Any, List
+import json
+from pathlib import Path
+from typing import Any, Dict, List

 from transformers import (
+    AutoModelForSeq2SeqLM,
     AutoTokenizer,
     GPT2LMHeadModel,
     GPTJForCausalLM,
@@ -17,6 +20,18 @@ MODEL_REGISTRY = {
     "EleutherAI/gpt-neo-1.3B": GPTNeoForCausalLM,
     "EleutherAI/gpt-neo-2.7B": GPTNeoForCausalLM,
     "gpt2": GPT2LMHeadModel,
+    "bigscience/T0pp": AutoModelForSeq2SeqLM,
+    "bigscience/T0_3B": AutoModelForSeq2SeqLM,
+}
+
+MODEL_PIPELINE = {
+    "EleutherAI/gpt-j-6B": "text-generation",
+    "EleutherAI/gpt-neo-125M": "text-generation",
+    "EleutherAI/gpt-neo-1.3B": "text-generation",
+    "EleutherAI/gpt-neo-2.7B": "text-generation",
+    "gpt2": "text-generation",
+    "bigscience/T0pp": "text2text-generation",
+    "bigscience/T0_3B": "text2text-generation",
 }


@@ -32,13 +47,27 @@ class HuggingFaceModel(Model):
         Args:
             model_name: model name string.
         """
+        # Check if providing path
+        self.model_path = model_name
+        if Path(self.model_path).exists() and Path(self.model_path).is_dir():
+            # Try to find config
+            if (Path(self.model_path) / "config.json").exists():
+                config = json.load(open(Path(self.model_path) / "config.json"))
+                model_name = config["_name_or_path"]
+        self.model_name = model_name
+        print("Model Name:", self.model_name, "Model Path:", self.model_path)
         model = MODEL_REGISTRY[model_name].from_pretrained(
-            model_name, cache_dir=cache_dir
+            self.model_path, cache_dir=cache_dir
         )
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.pipeline = pipeline(
-            "text-generation", model=model, tokenizer=tokenizer, device=device
+            MODEL_PIPELINE[model_name], model=model, tokenizer=tokenizer, device=device
        )
+        self.returns_input = MODEL_PIPELINE[model_name] == "text-generation"
+
+    def get_init_params(self) -> Dict:
+        """Return init params to determine what model is being used."""
+        return {"model_name": self.model_name, "model_path": self.model_path}

     def generate(self, prompt: str, **kwargs: Any) -> List[str]:
         """
@@ -66,14 +95,12 @@ class HuggingFaceModel(Model):
             top_p=kwargs.get("top_p"),
             num_return_sequences=num_return,
         )
-        # Removes tokens removed from tokenization
-        decoded_prompt = self.pipeline.tokenizer.decode(
-            encoded_prompt, clean_up_tokenization_spaces=True
-        )
-        if num_return == 1:
-            final_results.append(result[0]["generated_text"][len(decoded_prompt) :])
+        if self.returns_input:
+            start_idx = len(prompt)
         else:
-            final_results.append(
-                [r["generated_text"][len(decoded_prompt) :] for r in result]
-            )
+            start_idx = 0
+        if num_return == 1:
+            final_results.append(result[0]["generated_text"][start_idx:])
+        else:
+            final_results.append([r["generated_text"][start_idx:] for r in result])
         return final_results
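The new `returns_input` flag exists because `text-generation` pipelines echo the prompt at the start of `generated_text`, while `text2text-generation` models such as T0pp return only the completion. A standalone sketch of the slicing logic above (the strings are made up for illustration):

```python
# text-generation output repeats the prompt; text2text-generation output does not.
prompt = "Hello, my name is"
generated_text = "Hello, my name is Laurel"  # example text-generation output

returns_input = True  # would be False for a "text2text-generation" model
start_idx = len(prompt) if returns_input else 0
print(generated_text[start_idx:])  # -> " Laurel"
```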
""" + # Check if providing path + self.model_path = model_name + if Path(self.model_path).exists() and Path(self.model_path).is_dir(): + # Try to find config + if (Path(self.model_path) / "config.json").exists(): + config = json.load(open(Path(self.model_path) / "config.json")) + model_name = config["_name_or_path"] + self.model_name = model_name + print("Model Name:", self.model_name, "Model Path:", self.model_path) model = MODEL_REGISTRY[model_name].from_pretrained( - model_name, cache_dir=cache_dir + self.model_path, cache_dir=cache_dir ) tokenizer = AutoTokenizer.from_pretrained(model_name) self.pipeline = pipeline( - "text-generation", model=model, tokenizer=tokenizer, device=device + MODEL_PIPELINE[model_name], model=model, tokenizer=tokenizer, device=device ) + self.returns_input = MODEL_PIPELINE[model_name] == "text-generation" + + def get_init_params(self) -> Dict: + """Return init params to determine what model is being used.""" + return {"model_name": self.model_name, "model_path": self.model_path} def generate(self, prompt: str, **kwargs: Any) -> List[str]: """ @@ -66,14 +95,12 @@ class HuggingFaceModel(Model): top_p=kwargs.get("top_p"), num_return_sequences=num_return, ) - # Removes tokens removed from tokenization - decoded_prompt = self.pipeline.tokenizer.decode( - encoded_prompt, clean_up_tokenization_spaces=True - ) - if num_return == 1: - final_results.append(result[0]["generated_text"][len(decoded_prompt) :]) + if self.returns_input: + start_idx = len(prompt) else: - final_results.append( - [r["generated_text"][len(decoded_prompt) :] for r in result] - ) + start_idx = 0 + if num_return == 1: + final_results.append(result[0]["generated_text"][start_idx:]) + else: + final_results.append([r["generated_text"][start_idx:] for r in result]) return final_results diff --git a/manifest/api/models/model.py b/manifest/api/models/model.py index 428c765..7723a85 100644 --- a/manifest/api/models/model.py +++ b/manifest/api/models/model.py @@ -1,6 +1,6 @@ """Model class.""" from abc import ABC, abstractmethod -from typing import Any, List +from typing import Any, Dict, List class Model(ABC): @@ -18,6 +18,11 @@ class Model(ABC): """ raise NotImplementedError() + @abstractmethod + def get_init_params(self) -> Dict: + """Return init params to determine what model is being used.""" + raise NotImplementedError() + @abstractmethod def generate(self, prompt: str, **kwargs: Any) -> List[str]: """ diff --git a/manifest/caches/cache.py b/manifest/caches/cache.py index 5e1bd00..e1a6d94 100644 --- a/manifest/caches/cache.py +++ b/manifest/caches/cache.py @@ -61,19 +61,20 @@ def key_to_response(key: str) -> Dict: class Cache(ABC): """A cache for request/response pairs.""" - def __init__(self, connection_str: str, **kwargs: Any): + def __init__(self, connection_str: str, cache_args: Dict[str, Any] = {}): """ Initialize client. - kwargs are passed to client as default parameters. + cache_args are passed to client as default parameters. For clients like OpenAI that do not require a connection, the connection_str can be None. Args: connection_str: connection string for client. + cache_args: cache arguments. """ - self.connect(connection_str, **kwargs) + self.connect(connection_str, cache_args) @abstractmethod def close(self) -> None: @@ -81,7 +82,7 @@ class Cache(ABC): raise NotImplementedError() @abstractmethod - def connect(self, connection_str: str, **kwargs: Any) -> None: + def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None: """ Connect to client. 
diff --git a/manifest/caches/redis.py b/manifest/caches/redis.py
index 7e5c1eb..0b9cf3e 100644
--- a/manifest/caches/redis.py
+++ b/manifest/caches/redis.py
@@ -1,5 +1,5 @@
 """Redis cache."""
-from typing import Any, Union
+from typing import Any, Dict, Union

 import redis

@@ -9,12 +9,13 @@ from manifest.caches import Cache
 class RedisCache(Cache):
     """A Redis cache for request/response pairs."""

-    def connect(self, connection_str: str, **kwargs: Any) -> None:
+    def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
         """
         Connect to client.

         Args:
             connection_str: connection string.
+            cache_args: cache arguments.
         """
         host, port = connection_str.split(":")
         self.redis = redis.Redis(host=host, port=int(port), db=0)
diff --git a/manifest/caches/sqlite.py b/manifest/caches/sqlite.py
index 7f76f7c..64ffc72 100644
--- a/manifest/caches/sqlite.py
+++ b/manifest/caches/sqlite.py
@@ -1,6 +1,6 @@
 """SQLite cache."""
 import logging
-from typing import Any, Union
+from typing import Any, Dict, Union

 from sqlitedict import SqliteDict

@@ -12,12 +12,13 @@ logging.getLogger("sqlitedict").setLevel(logging.WARNING)
 class SQLiteCache(Cache):
     """A SQLite cache for request/response pairs."""

-    def connect(self, connection_str: str, **kwargs: Any) -> None:
+    def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
         """
         Connect to client.

         Args:
             connection_str: connection string.
+            cache_args: cache arguments.
         """
         self.cache_file = connection_str
         self.cache = SqliteDict(self.cache_file, autocommit=True)
diff --git a/manifest/clients/client.py b/manifest/clients/client.py
index afd0d18..d77a04c 100644
--- a/manifest/clients/client.py
+++ b/manifest/clients/client.py
@@ -6,7 +6,9 @@ from typing import Any, Callable, Dict, Optional, Tuple
 class Client(ABC):
     """Client class."""

-    def __init__(self, connection_str: Optional[str] = None, **kwargs: Any):
+    def __init__(
+        self, connection_str: Optional[str] = None, client_args: Dict[str, Any] = {}
+    ):
         """
         Initialize client.

@@ -17,8 +19,9 @@

         Args:
             connection_str: connection string for client.
+            client_args: client arguments.
         """
-        self.connect(connection_str, **kwargs)
+        self.connect(connection_str, client_args)

     @abstractmethod
     def close(self) -> None:
@@ -26,7 +29,7 @@
         raise NotImplementedError()

     @abstractmethod
-    def connect(self, connection_str: str, **kwargs: Any) -> None:
+    def connect(self, connection_str: str, client_args: Dict[str, Any]) -> None:
         """
         Connect to client.
diff --git a/manifest/clients/dummy.py b/manifest/clients/dummy.py
index 12af093..73617a6 100644
--- a/manifest/clients/dummy.py
+++ b/manifest/clients/dummy.py
@@ -13,15 +13,18 @@ class DummyClient(Client):
     def connect(
         self,
         connection_str: Optional[str] = None,
-        num_results: Optional[int] = 1,
-        **kwargs: Any,
+        client_args: Dict[str, Any] = {},
     ) -> None:
         """
         Connect to dummy server.

         This is a dummy client that returns identity responses. Used for testing.
+
+        Args:
+            connection_str: connection string.
+            client_args: client arguments.
         """
-        self.num_results = num_results
+        self.num_results = client_args.pop("num_results", 1)

     def close(self) -> None:
         """Close the client."""
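The `client_args` dict replaces keyword arguments throughout: each client `pop`s the keys it understands, so leftovers can be detected by the caller. A quick check of the new convention using the dummy client from the diff (this mirrors the updated tests below):

```python
from manifest.clients.dummy import DummyClient

args = {"num_results": 3}
client = DummyClient(connection_str=None, client_args=args)
print(client.num_results)  # 3
print(args)                # {} -- "num_results" was popped by connect()
```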
""" - self.num_results = num_results + self.num_results = client_args.pop("num_results", 1) def close(self) -> None: """Close the client.""" diff --git a/manifest/clients/huggingface.py b/manifest/clients/huggingface.py index b877b81..6cea8ad 100644 --- a/manifest/clients/huggingface.py +++ b/manifest/clients/huggingface.py @@ -15,27 +15,41 @@ class HuggingFaceClient(Client): def connect( self, connection_str: Optional[str] = None, - temperature: Optional[float] = 1.0, - max_tokens: Optional[int] = 10, - top_p: Optional[float] = 1.0, - top_k: Optional[int] = 0, - repetition_penalty: Optional[float] = 1.0, - n: Optional[int] = 1, - **kwargs: Any, + client_args: Dict[str, Any] = {}, ) -> None: - """Connect to the HuggingFace url.""" + """ + Connect to the HuggingFace url. + + Arsg: + connection_str: connection string. + client_args: client arguments. + """ self.host = connection_str.rstrip("/") - self.temperature = temperature - self.max_tokens = max_tokens - self.top_p = top_p - self.top_k = top_k - self.repetition_penalty = repetition_penalty - self.n = n + self.temperature = client_args.pop("temperature", 0.00001) + self.max_tokens = client_args.pop("max_tokens", 10) + self.top_p = client_args.pop("top_p", 1.0) + self.top_k = client_args.pop("top_k", 50) + self.repetition_penalty = client_args.pop("repetition_penalty", 1.0) + self.n = client_args.pop("n", 1) + self.model_params = self.get_model_params() def close(self) -> None: """Close the client.""" pass + def get_model_params(self) -> Dict: + """ + Get model params. + + By getting model params from the server, we can add to request + and make sure cache keys are unique to model. + + Returns: + model params. + """ + res = requests.post(self.host + "/params") + return res.json() + def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]: """ Get request string function. @@ -58,6 +72,7 @@ class HuggingFaceClient(Client): ), "n": kwargs.get("n", self.n), } + request_params.update(self.model_params) def _run_completion() -> Dict: post_str = self.host + "/completions" diff --git a/manifest/clients/openai.py b/manifest/clients/openai.py index ed6593a..a1ae098 100644 --- a/manifest/clients/openai.py +++ b/manifest/clients/openai.py @@ -24,35 +24,32 @@ class OpenAIClient(Client): def connect( self, connection_str: Optional[str] = None, - engine: Optional[str] = "text-ada-001", - temperature: Optional[float] = 0.0, - max_tokens: Optional[int] = 10, - top_p: Optional[float] = 1.0, - frequency_penalty: Optional[int] = 0, - presence_penalty: Optional[int] = 0, - n: Optional[int] = 1, - **kwargs: Any, + client_args: Dict[str, Any] = {}, ) -> None: """ Connect to the OpenAI server. connection_str is passed as default OPENAI_API_KEY if variable not set. + + Args: + connection_str: connection string. + client_args: client arguments. """ openai.api_key = os.environ.get("OPENAI_API_KEY", connection_str) if openai.api_key is None: raise ValueError( - "OpenAI API key not set. Set OPENAI_API_KEY environment ", - "svariable or pass through `connection_str`.", + "OpenAI API key not set. Set OPENAI_API_KEY environment " + "variable or pass through `connection_str`." ) - self.engine = engine + self.engine = client_args.pop("engine", "text-davinci-002") if self.engine not in OPENAI_ENGINES: raise ValueError(f"Invalid engine {self.engine}. 
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.top_p = top_p
-        self.frequency_penalty = frequency_penalty
-        self.presence_penalty = presence_penalty
-        self.n = n
+        self.temperature = client_args.pop("temperature", 0.0)
+        self.max_tokens = client_args.pop("max_tokens", 10)
+        self.top_p = client_args.pop("top_p", 1.0)
+        self.frequency_penalty = client_args.pop("frequency_penalty", 0.0)
+        self.presence_penalty = client_args.pop("presence_penalty", 0.0)
+        self.n = client_args.pop("n", 1)

     def close(self) -> None:
         """Close the client."""
diff --git a/manifest/clients/opt.py b/manifest/clients/opt.py
new file mode 100644
index 0000000..a314502
--- /dev/null
+++ b/manifest/clients/opt.py
@@ -0,0 +1,61 @@
+"""OPT client."""
+import logging
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import requests
+
+from manifest.clients.client import Client
+
+logger = logging.getLogger(__name__)
+
+
+class OPTClient(Client):
+    """OPT client."""
+
+    def connect(
+        self,
+        connection_str: Optional[str] = None,
+        client_args: Dict[str, Any] = {},
+    ) -> None:
+        """
+        Connect to the OPT url.
+
+        Args:
+            connection_str: connection string.
+            client_args: client arguments.
+        """
+        self.host = connection_str.rstrip("/")
+        self.temperature = client_args.pop("temperature", 1.0)
+        self.max_tokens = client_args.pop("max_tokens", 10)
+        self.top_p = client_args.pop("top_p", 0)
+        self.n = client_args.pop("n", 1)
+
+    def close(self) -> None:
+        """Close the client."""
+        pass
+
+    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
+        """
+        Get request string function.
+
+        Args:
+            query: query string.
+
+        Returns:
+            request function that takes no input.
+            request parameters as dict.
+        """
+        request_params = {
+            "prompt": query,
+            "temperature": kwargs.get("temperature", self.temperature),
+            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
+            "top_p": kwargs.get("top_p", self.top_p),
+            "n": kwargs.get("n", self.n),
+        }
+
+        def _run_completion() -> Dict:
+            post_str = self.host + "/completions"
+            res = requests.post(post_str, json=request_params)
+            return res.json()
+
+        return _run_completion, request_params
diff --git a/manifest/manifest.py b/manifest/manifest.py
index d448a7a..b392c5e 100644
--- a/manifest/manifest.py
+++ b/manifest/manifest.py
@@ -2,21 +2,24 @@
 import logging
 from typing import Any, Iterable, List, Optional, Union

-from manifest.response import Response
-
-logging.getLogger("openai").setLevel(logging.WARNING)
-logger = logging.getLogger(__name__)
+from tqdm.auto import tqdm

 from manifest.caches.redis import RedisCache
 from manifest.caches.sqlite import SQLiteCache
 from manifest.clients.dummy import DummyClient
 from manifest.clients.huggingface import HuggingFaceClient
 from manifest.clients.openai import OpenAIClient
+from manifest.clients.opt import OPTClient
 from manifest.prompt import Prompt
+from manifest.response import Response
+
+logging.getLogger("openai").setLevel(logging.WARNING)
+logger = logging.getLogger(__name__)

 CLIENT_CONSTRUCTORS = {
     "openai": OpenAIClient,
     "huggingface": HuggingFaceClient,
+    "opt": OPTClient,
     "dummy": DummyClient,
 }

@@ -62,10 +65,14 @@ class Manifest:
                 f"Choices are {list(CACHE_CONSTRUCTORS.keys())}"
             )
         self.client_name = client_name
+        # Pass kwargs as a dict so each client's "pop" calls remove the arguments it uses
         self.client = CLIENT_CONSTRUCTORS[client_name](  # type: ignore
-            client_connection, **kwargs
+            client_connection, client_args=kwargs
         )
-        self.cache = CACHE_CONSTRUCTORS[cache_name](cache_connection, **kwargs)
+        self.cache = CACHE_CONSTRUCTORS[cache_name](cache_connection, cache_args=kwargs)
+        if len(kwargs) > 0:
+            raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")
+
         self.stop_token = stop_token

     def close(self) -> None:
@@ -119,6 +126,7 @@
         overwrite_cache: bool = False,
         stop_token: Optional[str] = None,
         return_response: bool = False,
+        verbose: bool = False,
         **kwargs: Any,
     ) -> Iterable[Union[str, List[str], Response]]:
         """
@@ -141,7 +149,7 @@
             self.run(
                 prompt, inp, overwrite_cache, stop_token, return_response, **kwargs
             )
-            for inp in input
+            for inp in tqdm(input, desc="Running batch", disable=not verbose)
         ]

     def save_prompt(self, name: str, prompt: Prompt) -> None:
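With `tqdm` wired into `batch_run`, progress reporting is opt-in per call. A usage sketch with the dummy client, which needs no server:

```python
from manifest import Manifest, Prompt

manifest = Manifest(
    client_name="dummy",
    cache_name="sqlite",
    cache_connection="sqlite.cache",
)
prompt = Prompt(lambda x: f"Hello, my name is {x}")

# verbose=True shows a "Running batch" progress bar over the inputs.
results = manifest.batch_run(prompt, ["Laurel", "Avanika"], verbose=True)
```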
diff --git a/poetry.lock b/poetry.lock
index 5213252..55b415c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -984,6 +984,33 @@ torch-speech = ["torchaudio", "librosa", "pyctcdecode (>=0.3.0)", "phonemizer"]
 torchhub = ["filelock", "huggingface-hub (>=0.1.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.0)", "tokenizers (>=0.11.1,!=0.11.3,<0.13)", "tqdm (>=4.27)"]
 vision = ["pillow"]

+[[package]]
+name = "types-redis"
+version = "4.2.6"
+description = "Typing stubs for redis"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
+name = "types-requests"
+version = "2.27.29"
+description = "Typing stubs for requests"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+types-urllib3 = "<1.27"
+
+[[package]]
+name = "types-urllib3"
+version = "1.26.15"
+description = "Typing stubs for urllib3"
+category = "main"
+optional = false
+python-versions = "*"
+
 [[package]]
 name = "typing-extensions"
 version = "4.2.0"
@@ -1057,7 +1084,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.8"
-content-hash = "b0f92eed3a0f80cc9976f68098b772a19d19b4742426299bd63cb1c636d8a84a"
+content-hash = "9f1c7010530f8668850294c74a7dd60fdaad2086901a5bd0f4438b870666c1a7"

 [metadata.files]
 alabaster = [
@@ -1687,6 +1714,18 @@ transformers = [
     {file = "transformers-4.19.2-py3-none-any.whl", hash = "sha256:1416315b7c5ff1f56d3915f416b67aa254a9907fbb73ef7f7bffc9210446b5fa"},
     {file = "transformers-4.19.2.tar.gz", hash = "sha256:e19a4ff07458eda143c738e5259caf48449fcf078a63d6b1bd1aa806543440a3"},
 ]
+types-redis = [
+    {file = "types-redis-4.2.6.tar.gz", hash = "sha256:d6adc77185cf40b300816767a64c0ee9ee0b21dc174e8e5c23b7e83d43189cb8"},
+    {file = "types_redis-4.2.6-py3-none-any.whl", hash = "sha256:1136af954ade0be33b487f440c8cbcbee29f089a83e685484ec91f363c6c69fe"},
+]
+types-requests = [
+    {file = "types-requests-2.27.29.tar.gz", hash = "sha256:fb453b3a76a48eca66381cea8004feaaea12835e838196f5c7ac87c75c5c19ef"},
+    {file = "types_requests-2.27.29-py3-none-any.whl", hash = "sha256:014f4f82db7b96c41feea9adaea30e68cd64c230eeab34b70c29bebb26ec74ac"},
+]
+types-urllib3 = [
+    {file = "types-urllib3-1.26.15.tar.gz", hash = "sha256:c89283541ef92e344b7f59f83ea9b5a295b16366ceee3f25ecfc5593c79f794e"},
+    {file = "types_urllib3-1.26.15-py3-none-any.whl", hash = "sha256:6011befa13f901fc934f59bb1fd6973be6f3acf4ebfce427593a27e7f492918f"},
+]
 typing-extensions = [
     {file = "typing_extensions-4.2.0-py3-none-any.whl", hash = "sha256:6657594ee297170d19f67d55c05852a874e7eb634f4f753dbd667855e07c1708"},
"typing_extensions-4.2.0.tar.gz", hash = "sha256:f1c24655a0da0d1b67f07e17a5e6b2a105894e6824b92096378bb3668ef02376"}, diff --git a/pyproject.toml b/pyproject.toml index eea9245..ce3bf13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,9 @@ Flask = "^2.1.2" transformers = "^4.19.2" torch = "^1.11.0" requests = "^2.27.1" +tqdm = "^4.64.0" +types-redis = "^4.2.6" +types-requests = "^2.27.29" [tool.poetry.dev-dependencies] black = "^22.3.0" @@ -55,6 +58,7 @@ module = [ "tqdm", "sqlitedict", "dill", + "tqdm.auto", ] [tool.isort] diff --git a/tests/test_client.py b/tests/test_client.py index 90d84e5..1e9acbf 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,13 +9,15 @@ from manifest.clients.dummy import DummyClient def test_init(): """Test client initialization.""" - client = DummyClient(connection_str=None, num_results=3) + args = {"num_results": 3} + client = DummyClient(connection_str=None, client_args=args) assert client.num_results == 3 def test_get_request(): """Test client get request.""" - client = DummyClient(connection_str=None, num_results=3) + args = {"num_results": 3} + client = DummyClient(connection_str=None, client_args=args) request_func, request_params = client.get_request("hello") assert request_params == {"prompt": "hello", "num_results": 3} assert request_func() == {"choices": [{"text": "hello"}] * 3} diff --git a/tests/test_manifest.py b/tests/test_manifest.py index be0bc44..7b4207e 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -10,6 +10,15 @@ from manifest.clients.dummy import DummyClient @pytest.mark.usefixtures("sqlite_cache") def test_init(sqlite_cache): """Test manifest initialization.""" + with pytest.raises(ValueError) as exc_info: + Manifest( + client_name="dummy", + cache_name="sqlite", + cache_connection=sqlite_cache, + sep_tok="", + ) + assert str(exc_info.value) == "[('sep_tok', '')] arguments are not recognized." + manifest = Manifest( client_name="dummy", cache_name="sqlite",