First main commit

laurel/helm
Laurel Orr 2 years ago
parent 6ada1f2d2c
commit ed301be2a7

@ -0,0 +1,10 @@
# This is our code-style check. We currently allow the following exceptions:
# - E731: do not assign a lambda expression, use a def
# - E402: module level import not at top of file
# - W503: line break before binary operator
[flake8]
exclude = .git
max-line-length = 88
ignore = E731, E402, W503
per-file-ignores = __init__.py:F401

@ -0,0 +1,28 @@
---
name: Bug report
about: Create a report to help us improve
---
## Description of the bug
A clear and concise description of what the bug is.
## To Reproduce
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Execute the code '...'
If necessary, attach example data which can be used to replicate the issue.
## Expected behavior
A clear and concise description of what you expected to happen.
## Error Logs/Screenshots
If applicable, add error logs or screenshots to help explain your problem.
## Environment (please complete the following information)
- OS: [e.g. Ubuntu 18.04]
- manifest Version: [e.g. 0.0.1]
## Additional context
Add any other context about the problem here.

@ -0,0 +1,19 @@
---
name: Feature request
about: Suggest an idea for this project
---
## Description of the feature request
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
## Description of the solution you'd like
A clear and concise description of what you want to happen.
## Description of the alternatives you've considered
A clear and concise description of any alternative solutions or features you've considered.
## Additional context
Add any other context or screenshots about the feature request here.

@ -0,0 +1,65 @@
# Continuous integration: runs the style checks and the pytest suite against
# every supported Python version, with a Redis service container available.
name: CI
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
  # Allows you to run this workflow manually from the Actions tab
  workflow_dispatch:
jobs:
  test:
    runs-on: ${{ matrix.os }}
    timeout-minutes: 30
    strategy:
      matrix:
        os:
          - ubuntu-latest
        python-version:
          - "3.8"
          - "3.9"
          - "3.10"
    services:
      # Label used to access the service container
      redis:
        # Docker Hub image
        image: redislabs/redis
        # Set health checks to wait until redis has started
        options: >-
          --health-cmd "redis-cli ping"
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          # Maps port 6379 on service container to the host
          - 6379:6379
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install Poetry
        uses: snok/install-poetry@v1
        with:
          virtualenvs-in-project: true
          virtualenvs-create: true
          installer-parallel: true
      # The in-project virtualenv is cached per OS, Python version and lock file.
      - name: Load cached venv if cache exists
        id: cached-poetry-deps
        uses: actions/cache@v3
        with:
          path: .venv
          key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
      # Only resolve and install dependencies when the cache missed.
      - name: Install dependencies
        if: steps.cached-poetry-deps.outputs.cache-hit != 'true'
        run: poetry install --no-interaction --no-root
      - name: Install Manifest
        run: make dev
      - name: Run preliminary checks
        run: make check
      - name: Test with pytest
        run: poetry run pytest tests

120
.gitignore vendored

@ -0,0 +1,120 @@
runs/*
*._*
# Pickle Saved
*.pt
*.pk
**/*.pt
**/*.pk
# PyCharm
*.idea
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
*.tsv
*.7z
.DS_Store

@ -0,0 +1,23 @@
# Git pre-commit hooks; installed by `poetry run pre-commit install`.
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.2.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-toml
      - id: check-merge-conflict
      - id: check-added-large-files
  # isort lives under the PyCQA organisation; use its canonical URL.
  - repo: https://github.com/pycqa/isort
    rev: 5.9.3
    hooks:
      - id: isort
  - repo: https://github.com/psf/black
    rev: 22.3.0
    hooks:
      - id: black
        language_version: python3
  # flake8 moved from GitLab to GitHub; the old gitlab.com URL no longer
  # resolves, which makes `pre-commit install-hooks` fail.
  - repo: https://github.com/pycqa/flake8
    rev: 3.9.2
    hooks:
      - id: flake8

@ -0,0 +1,28 @@
# Developer entry points. Recipe lines must start with a TAB character.

# Install all dependencies and the git pre-commit hooks.
dev:
	poetry install
	poetry run pre-commit install

# Run the style checks, then the test suite. `dev` already ran
# `poetry install`, so it is not repeated here.
test: dev check
	poetry run pytest tests

# Rewrite sources in place with isort and black.
format:
	isort --atomic manifest/ tests/
	black manifest/ tests/

# Verify formatting, lint and types without modifying any file.
check:
	isort -c manifest/ tests/
	black manifest/ tests/ --check
	flake8 manifest/ tests/
	mypy manifest/

# Remove the installed package and build artefacts.
clean:
	pip uninstall -y manifest
	rm -rf src/manifest.egg-info
	rm -rf build/ dist/

# Delete local branches whose upstream branch is gone.
# `$$` is required so that make passes a literal `$` to the shell; with a
# single `$`, make itself expands `$(git ...)`, `$1` and `$branch` to empty
# text before the shell ever runs. The recipe is also no longer wrapped in
# `bash -c "..."`, because the outer shell would expand `$branch` and `$1`
# inside those double quotes. The checked-out branch (marked with `*`) is
# skipped so that `*` is never glob-expanded into file names.
prune:
	@git fetch -p
	@for branch in $$(git branch -vv | grep ': gone]' | grep -v '^\*' | awk '{print $$1}'); do git branch -d $$branch; done

.PHONY: dev test format check clean prune

@ -1,2 +1,21 @@
# manifest
Prompt programming with foundation models (FMs).
# Install
Download the code:
```
git clone git@github.com:HazyResearch/manifest.git
cd manifest
```
Install:
```
pip install poetry
poetry install
poetry run pre-commit install
```
or
```
pip install poetry
make dev
```

@ -0,0 +1,3 @@
"""Manifest init."""
from manifest.manifest import Manifest
from manifest.prompt import Prompt

@ -0,0 +1 @@
"""Flask app."""

@ -0,0 +1 @@
"""Huggingface model."""

@ -0,0 +1 @@
"""Model class."""

@ -0,0 +1,2 @@
"""Cache init."""
from manifest.caches.cache import Cache

@ -0,0 +1,113 @@
"""Cache for queries and responses."""
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Tuple, Union
from manifest.clients.response import Response
def request_to_key(request: Dict) -> str:
    """
    Turn a request dict into its canonical string form.

    The dict is dumped as JSON with sorted keys, so two requests with the
    same contents give the same key whatever their insertion order.

    Args:
        request: request to normalize.

    Returns:
        normalized key.
    """
    normalized = json.dumps(request, sort_keys=True)
    return normalized
def key_to_request(key: str) -> Dict:
    """
    Parse a normalized key back into the request it was built from.

    Args:
        key: normalized key to convert.

    Returns:
        unnormalized request dict.
    """
    request = json.loads(key)
    return request
class Cache(ABC):
    """Abstract store of serialized request/response pairs."""

    def __init__(self, connection_str: str, **kwargs: Any):
        """
        Set up the cache by delegating to ``connect``.

        Args:
            connection_str: connection string handed to ``connect``.
            kwargs: extra keyword arguments forwarded unchanged to ``connect``.
        """
        self.connect(connection_str, **kwargs)

    @abstractmethod
    def close(self) -> None:
        """Release whatever ``connect`` opened."""
        raise NotImplementedError()

    @abstractmethod
    def connect(self, connection_str: str, **kwargs: Any) -> None:
        """
        Open the backing store.

        Args:
            connection_str: connection string.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_key(self, key: str, table: str = "default") -> Union[str, None]:
        """
        Look up the value stored under a key.

        Will return None if key is not in cache.

        Args:
            key: key for cache.
            table: table to get key in.
        """
        raise NotImplementedError()

    @abstractmethod
    def set_key(self, key: str, value: str, table: str = "default") -> None:
        """
        Store a value under a key, replacing any previous value.

        Args:
            key: key for cache.
            value: new value for key.
            table: table to set key in.
        """
        raise NotImplementedError()

    @abstractmethod
    def commit(self) -> None:
        """Commit any results."""
        raise NotImplementedError()

    def get(
        self, request: Dict, overwrite_cache: bool, compute: Callable[[], Response]
    ) -> Tuple[Response, bool]:
        """
        Return the response for a request, computing it when necessary.

        Args:
            request: request parameters; normalized into the cache key.
            overwrite_cache: when True, ignore any stored value and recompute.
            compute: zero-argument callable that produces a fresh response.

        Returns:
            the response, and whether it was served from the cache.
        """
        key = request_to_key(request)
        stored = self.get_key(key)
        # A falsy stored value (None or "") counts as a miss.
        serve_from_cache = bool(stored) and not overwrite_cache
        if not serve_from_cache:
            fresh = compute()
            self.set_key(key, fresh.serialize())
            return fresh, False
        return Response.deserialize(stored), True

@ -0,0 +1,52 @@
"""Redis cache."""
from typing import Any, Union
import redis
from manifest.caches import Cache
class RedisCache(Cache):
    """A Redis cache for request/response pairs."""

    def connect(self, connection_str: str, **kwargs: Any) -> None:
        """
        Connect to client.

        Args:
            connection_str: connection string of the form ``host:port``.
            kwargs: accepted for interface compatibility; not used here.
        """
        host, port = connection_str.split(":")
        self.redis = redis.Redis(host=host, port=int(port))

    def close(self) -> None:
        """Close the client."""
        self.redis.close()

    def get_key(self, key: str, table: str = "default") -> Union[str, None]:
        """
        Get the value stored for a key.

        Will return None if key is not in cache.

        Args:
            key: key for cache.
            table: accepted for interface compatibility; this class does not
                use it, so all tables share one keyspace (as in ``set_key``).

        Returns:
            the stored value, or None when the key is absent.
        """
        value = self.redis.get(key)
        # The client above is created without response decoding, so values
        # are expected back as bytes; decode them to honour the ``str``
        # contract of the base class.
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return value

    def set_key(self, key: str, value: str, table: str = "default") -> None:
        """
        Set the value for the key.

        Will override old value.

        Args:
            key: key for cache.
            value: new value for key.
            table: accepted for interface compatibility; not used here.
        """
        self.redis[key] = value

    def commit(self) -> None:
        """Commit any results (nothing to do in this class)."""
        pass

@ -0,0 +1,79 @@
"""SQLite cache."""
import logging
from pathlib import Path
from typing import Any, Union
from sqlitedict import SqliteDict
from manifest.caches import Cache
logging.getLogger("sqlitedict").setLevel(logging.WARNING)
class SQLiteCache(Cache):
    """A SQLite cache for request/response pairs."""

    def connect(self, connection_str: str, **kwargs: Any) -> None:
        """
        Connect to client.

        Creates the cache directory if needed and opens one SqliteDict file
        for queries and one for prompts.

        Args:
            connection_str: path of the directory holding the cache files.
            kwargs: accepted for interface compatibility; not used here.
        """
        self.cache_dir = connection_str
        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
        # If more than two tables, switch to full on SQL connection
        self.query_file = Path(self.cache_dir, "query.sqlite")
        self.prompt_file = Path(self.cache_dir, "prompts.sqlite")
        self.cache = SqliteDict(self.query_file, autocommit=False)
        self.prompt_cache = SqliteDict(self.prompt_file, autocommit=False)

    def close(self) -> None:
        """Close both underlying stores."""
        self.cache.close()
        # The prompt store is opened in connect() too and must be released.
        self.prompt_cache.close()

    def _select_table(self, table: str) -> SqliteDict:
        """Return the SqliteDict backing ``table`` or raise ValueError."""
        if table == "prompt":
            return self.prompt_cache
        if table != "default":
            raise ValueError("SQLiteDict only support table of `default` or `prompt`")
        return self.cache

    def get_key(self, key: str, table: str = "default") -> Union[str, None]:
        """
        Get the value stored for a key.

        Will return None if key is not in cache.

        Args:
            key: key for cache.
            table: table to get key in (``default`` or ``prompt``).

        Raises:
            ValueError: if ``table`` is not ``default`` or ``prompt``.
        """
        return self._select_table(table).get(key)

    def set_key(self, key: str, value: str, table: str = "default") -> None:
        """
        Set the value for the key.

        Will override old value. The write is committed immediately.

        Args:
            key: key for cache.
            value: new value for key.
            table: table to set key in (``default`` or ``prompt``).

        Raises:
            ValueError: if ``table`` is not ``default`` or ``prompt``.
        """
        self._select_table(table)[key] = value
        self.commit()

    def commit(self) -> None:
        """Commit any results in both stores."""
        self.prompt_cache.commit()
        self.cache.commit()

@ -0,0 +1,3 @@
"""Client init."""
from manifest.clients.client import Client
from manifest.clients.response import Response

@ -0,0 +1,58 @@
"""Client class."""
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Tuple
from manifest.clients import Response
class Client(ABC):
    """Abstract base for model clients."""

    def __init__(self, connection_str: Optional[str] = None, **kwargs: Any):
        """
        Set up the client by delegating to ``connect``.

        kwargs are passed to ``connect`` as default parameters.
        For clients like OpenAI that do not require a connection,
        the connection_str can be None.

        Args:
            connection_str: connection string for client.
        """
        self.connect(connection_str, **kwargs)

    @abstractmethod
    def close(self) -> None:
        """Release whatever ``connect`` opened."""
        raise NotImplementedError()

    @abstractmethod
    def connect(self, connection_str: str, **kwargs: Any) -> None:
        """
        Open the connection to the backing service.

        Args:
            connection_str: connection string.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_request(
        self, query: str, **kwargs: Any
    ) -> Tuple[Callable[[], Response], Dict]:
        """
        Build a deferred request for a query.

        kwargs override default parameters.
        The request only runs when the returned function is called.

        Args:
            query: query string.

        Returns:
            request function that takes no input.
            request parameters as dict.
        """
        raise NotImplementedError()

@ -0,0 +1,52 @@
"""Dummy client."""
import logging
from typing import Any, Callable, Dict, Optional, Tuple
from manifest.clients import Client
from manifest.clients.response import Response
logger = logging.getLogger(__name__)
class DummyClient(Client):
    """Dummy client."""

    def connect(
        self,
        connection_str: Optional[str] = None,
        num_results: Optional[int] = 1,
        **kwargs: Any,
    ) -> None:
        """
        Connect to dummy server.

        This is a dummy client that returns fixed responses. Used for testing.

        Args:
            connection_str: unused; accepted for interface compatibility.
            num_results: default number of choices in each response.
        """
        self.num_results = num_results

    def close(self) -> None:
        """Close the client (nothing to release)."""
        pass

    def get_request(
        self, query: str, **kwargs: Any
    ) -> Tuple[Callable[[], Response], Dict]:
        """
        Get request string function.

        Args:
            query: query string.
            kwargs: may contain ``num_results`` to override the default.

        Returns:
            request function that takes no input.
            request parameters as dict.
        """
        num_results = kwargs.get("num_results", self.num_results)
        request_params = {
            "prompt": query,
            "num_results": num_results,
        }

        def _run_completion() -> Response:
            # Use the per-request value so the response always agrees with
            # the parameters returned to the caller.
            choices = [{"text": "hello"} for _ in range(num_results)]
            return Response({"choices": choices})

        return _run_completion, request_params

@ -0,0 +1,95 @@
"""OpenAI client."""
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple
import openai
from manifest.clients import Response
from manifest.clients.client import Client
logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
# Engine names accepted by OpenAIClient.connect.
OPENAI_ENGINES = {
    "text-davinci-002",
    "text-curie-001",
    "text-babbage-001",
    "text-ada-001",
}


class OpenAIClient(Client):
    """OpenAI client."""

    def connect(
        self,
        connection_str: Optional[str] = None,
        engine: Optional[str] = "text-ada-001",
        temperature: Optional[float] = 0.0,
        max_tokens: Optional[int] = 10,
        top_p: Optional[float] = 1,
        frequency_penalty: Optional[float] = 0,
        presence_penalty: Optional[float] = 0,
        n: Optional[int] = 1,
        **kwargs: Any,
    ) -> None:
        """
        Connect to the OpenAI server.

        The OPENAI_API_KEY environment variable takes precedence;
        connection_str is only used as the API key when that variable is
        not set. The remaining arguments become the default request
        parameters used by ``get_request``.

        Raises:
            ValueError: if no API key is available or ``engine`` is not one
                of OPENAI_ENGINES.
        """
        openai.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
        if openai.api_key is None:
            # One concatenated string: passing two arguments to ValueError
            # would render the message as a tuple.
            raise ValueError(
                "OpenAI API key not set. Set OPENAI_API_KEY environment "
                "variable or pass through `connection_str`."
            )
        self.engine = engine
        if self.engine not in OPENAI_ENGINES:
            raise ValueError(f"Invalid engine {self.engine}. Must be {OPENAI_ENGINES}.")
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self.n = n

    def close(self) -> None:
        """Close the client (nothing to release)."""
        pass

    def get_request(
        self, query: str, **kwargs: Any
    ) -> Tuple[Callable[[], Response], Dict]:
        """
        Get request string function.

        Args:
            query: query string.
            kwargs: per-request overrides of the defaults set in ``connect``.

        Returns:
            request function that takes no input.
            request parameters as dict.
        """
        request_params = {
            "engine": kwargs.get("engine", self.engine),
            "prompt": query,
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "top_p": kwargs.get("top_p", self.top_p),
            "frequency_penalty": kwargs.get(
                "frequency_penalty", self.frequency_penalty
            ),
            "presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
            "n": kwargs.get("n", self.n),
        }

        def _run_completion() -> Response:
            try:
                return Response(openai.Completion.create(**request_params))
            except openai.error.OpenAIError as e:
                # Log, then let the caller see the original exception.
                logger.error(e)
                raise

        return _run_completion, request_params

@ -0,0 +1,70 @@
"""Client response."""
import json
from typing import Dict, List, Union
class Response:
    """Response class wrapping a dict with a ``choices`` list."""

    def __init__(self, response: Union[str, Dict]):
        """
        Initialize response.

        Args:
            response: a dict, or a JSON string that decodes to one. It must
                hold a ``choices`` list whose items each have a ``text`` field.

        Raises:
            ValueError: if the input has the wrong type or shape.
        """
        if isinstance(response, str):
            self.response = json.loads(response)
        elif isinstance(response, dict):
            self.response = response
        else:
            raise ValueError("Response must be str or dict")
        # A JSON string may decode to a non-dict (list, number, ...); reject
        # it with the same error instead of failing on the membership test.
        if (
            not isinstance(self.response, dict)
            or ("choices" not in self.response)
            or (not isinstance(self.response["choices"], list))
        ):
            raise ValueError(
                "Response must be serialized to a dict with a list of choices"
            )
        # Check every choice, since get_results reads "text" from all of them.
        if any("text" not in choice for choice in self.response["choices"]):
            # One concatenated string: two arguments would make the message
            # render as a tuple.
            raise ValueError(
                "Response must be serialized to a dict with a "
                "list of choices with text field"
            )

    def __getitem__(self, key: str) -> str:
        """
        Return the response given the key.

        Args:
            key: key to get.

        Returns:
            value of key.
        """
        return self.response[key]

    def get_results(self) -> Union[str, List[str], None]:
        """
        Get all text results from response.

        Returns:
            None when there are no choices, the text itself when there is
            exactly one choice, otherwise the list of texts.
        """
        texts = [choice["text"] for choice in self.response["choices"]]
        if not texts:
            return None
        if len(texts) == 1:
            return texts[0]
        return texts

    def serialize(self) -> str:
        """
        Serialize response to string.

        Returns:
            serialized response.
        """
        return json.dumps(self.response, sort_keys=True)

    @classmethod
    def deserialize(cls, value: str) -> "Response":
        """
        Deserialize string to response.

        Args:
            value: serialized response.

        Returns:
            deserialized response.
        """
        # Build through cls so subclasses deserialize to their own type.
        return cls(value)

@ -0,0 +1,138 @@
"""Manifest class."""
import logging
from typing import Any, Iterable, List, Optional, Union
logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
from manifest import Prompt
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.dummy import DummyClient
from manifest.clients.openai import OpenAIClient
# Registry of client names to their classes.
CLIENT_CONSTRUCTORS = {
    "openai": OpenAIClient,
    # "huggingface": manifest.clients.huggingface.HuggingFaceClient,
    "dummy": DummyClient,
}

# Registry of cache names to their classes.
CACHE_CONSTRUCTORS = {
    "redis": RedisCache,
    "sqlite": SQLiteCache,
}


class Manifest:
    """Session object pairing a model client with a response cache."""

    def __init__(
        self,
        client_name: str = "openai",
        client_connection: Optional[str] = None,
        cache_name: str = "redis",
        cache_connection: str = "localhost:6379",
        **kwargs: Any,
    ):
        """
        Build the client and the cache selected by name.

        Remaining kwargs are sent to both the client and the cache.

        Args:
            client_name: key into CLIENT_CONSTRUCTORS.
            client_connection: connection string for the client.
            cache_name: key into CACHE_CONSTRUCTORS.
            cache_connection: connection string for the cache.

        Raises:
            ValueError: if either name is not registered.
        """
        client_cls = CLIENT_CONSTRUCTORS.get(client_name)
        if client_cls is None:
            raise ValueError(
                f"Unknown client name: {client_name}. "
                f"Choices are {list(CLIENT_CONSTRUCTORS.keys())}"
            )
        cache_cls = CACHE_CONSTRUCTORS.get(cache_name)
        if cache_cls is None:
            raise ValueError(
                f"Unknown cache name: {cache_name}. "
                f"Choices are {list(CACHE_CONSTRUCTORS.keys())}"
            )
        self.client_name = client_name
        self.client = client_cls(client_connection, **kwargs)
        self.cache = cache_cls(cache_connection, **kwargs)

    def close(self) -> None:
        """Close the client, then the cache."""
        self.client.close()
        self.cache.close()

    def run(
        self,
        prompt: Prompt,
        input: Optional[Any] = None,
        overwrite_cache: bool = False,
        **kwargs: Any,
    ) -> Union[str, List[str]]:
        """
        Render the prompt and return the (possibly cached) model output.

        Args:
            prompt: prompt to run.
            input: input to prompt.
            overwrite_cache: whether to overwrite cache.

        Returns:
            response from prompt.
        """
        prompt_str = prompt(input)
        possible_request, full_kwargs = self.client.get_request(prompt_str, **kwargs)
        # The cache key is the request parameters, made dependent on both
        # the client in use and the rendered prompt.
        cache_key = {
            **full_kwargs,
            "client_name": self.client_name,
            "prompt": prompt_str,
        }
        response, _ = self.cache.get(cache_key, overwrite_cache, possible_request)
        return response.get_results()

    def run_batch(
        self,
        prompt: Prompt,
        input: Optional[Iterable[Any]] = None,
        overwrite_cache: bool = False,
        **kwargs: Any,
    ) -> Iterable[Union[str, List[str]]]:
        """
        Run the prompt once per element of a batch of inputs.

        Args:
            prompt: prompt to run.
            input: batch of inputs; None is treated as a batch of one None.
            overwrite_cache: whether to overwrite cache.

        Returns:
            batch of responses.
        """
        batch = [None] if input is None else input
        results = []
        for item in batch:
            results.append(self.run(prompt, item, overwrite_cache, **kwargs))
        return results

    def save_prompt(self, name: str, prompt: Prompt) -> None:
        """
        Store the serialized prompt in the cache's ``prompt`` table.

        Args:
            name: name of prompt.
            prompt: prompt to save.
        """
        serialized = prompt.serialize()
        self.cache.set_key(name, serialized, table="prompt")

    def load_prompt(self, name: str) -> Prompt:
        """
        Read a prompt back from the cache's ``prompt`` table.

        Args:
            name: name of prompt.

        Returns:
            Prompt saved with name.
        """
        serialized = self.cache.get_key(name, table="prompt")
        return Prompt.deserialize(serialized)

    def open_explorer(self) -> None:
        """Open the explorer for jupyter widget."""
        # Open explorer
        # TODO: implement
        pass

@ -0,0 +1,72 @@
"""Prompt class."""
import inspect
import logging
from typing import Any, Callable, List, Optional, Union
logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
class Prompt:
    """Prompt class wrapping a function that renders a prompt string."""

    def __init__(self, prompt_obj: Union[str, Callable, "Prompt", List["Prompt"]]):
        """
        Initialize prompt.

        If prompt_obj is a string, it will be cast as function.
        Composing a list of prompts is not implemented yet.

        Args:
            prompt_obj: a fixed string, or a callable taking zero or one input.

        Raises:
            NotImplementedError: if prompt_obj is neither a string nor callable.
            ValueError: if the callable takes more than one parameter.
        """
        # TODO: figure out how to compose prompts to keep the
        # interface simple? Can we make a function
        # such that a single call will run the composition?
        if isinstance(prompt_obj, str):
            self.prompt_func = lambda: prompt_obj
        elif callable(prompt_obj):
            self.prompt_func = prompt_obj
        else:
            # TODO: implement
            raise NotImplementedError()
        # The parameter count decides how __call__ invokes the function.
        self.num_args = len(inspect.signature(self.prompt_func).parameters)
        if self.num_args > 1:
            raise ValueError("Prompts must have zero or one input.")

    def __call__(self, input: Optional[Any] = None) -> str:
        """
        Return the prompt given the inputs.

        Args:
            input: input to prompt; ignored by zero-parameter prompts.

        Returns:
            prompt string.
        """
        if self.num_args >= 1:
            return self.prompt_func(input)  # type: ignore
        return self.prompt_func()

    def serialize(self) -> str:
        """
        Return the prompt as str.

        Returns:
            json object.

        Raises:
            NotImplementedError: always; serialization is not implemented.
        """
        # TODO: implement
        # Fail loudly instead of silently returning None, which callers
        # would otherwise store as if it were a serialized prompt.
        raise NotImplementedError("Prompt serialization is not implemented yet.")

    @classmethod
    def deserialize(cls, obj: str) -> "Prompt":
        """
        Return the prompt from a json object.

        Args:
            obj: json object.

        Returns:
            prompt.

        Raises:
            NotImplementedError: always; deserialization is not implemented.
        """
        # TODO: implement
        raise NotImplementedError("Prompt deserialization is not implemented yet.")

1418
poetry.lock generated

File diff suppressed because it is too large Load Diff

@ -0,0 +1,69 @@
[tool.poetry]
# One "Name <email>" entry per author; a single string holding both authors
# is read as one author with a malformed e-mail address.
authors = [
  "Laurel Orr <lorr1@cs.stanford.edu>",
  "Avanika Narayan <avanikan@stanford.edu>",
]
classifiers = [
  "Programming Language :: Python",
  "Programming Language :: Python :: 3",
  "Programming Language :: Python :: 3.8",
  "Programming Language :: Python :: 3.9",
  "Programming Language :: Python :: 3.10",
]
description = "Manifest for Prompt Programming"
name = "manifest"
repository = "https://github.com/HazyResearch/manifest"
version = "0.0.1"
[tool.poetry.urls]
"Bug Tracker" = "https://github.com/HazyResearch/manifest/issues"
[tool.poetry.dependencies]
python = "^3.8"
sqlitedict = "^2.0.0"
openai = "^0.18.1"
redis = "^4.3.1"
[tool.poetry.dev-dependencies]
black = "^22.3.0"
flake8 = "^4.0.0"
flake8-docstrings = "^1.6.0"
isort = "^5.9.3"
mypy = "^0.950"
pep8-naming = "^0.12.1"
pre-commit = "^2.14.0"
pytest = "^7.0.0"
pytest-cov = "^3.0.0"
python-dotenv = "^0.20.0"
recommonmark = "^0.7.1"
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core>=1.0.0"]
# Additional Tool Configurations
[tool.mypy]
disallow_untyped_defs = true
strict_optional = false
[[tool.mypy.overrides]]
ignore_missing_imports = true
module = [
"numpy",
"tqdm",
"sqlitedict",
]
[tool.isort]
combine_as_imports = true
force_grid_wrap = 0
include_trailing_comma = true
known_first_party = ["manifest"]
line_length = 88
multi_line_output = 3
[tool.pytest.ini_options]
log_format = "[%(levelname)s] %(message)s"
log_date_format = "%Y-%m-%d %H:%M:%S"
addopts = "-v -rsXx"
# The following options are useful for local debugging
# addopts = "-v -rsXx -s -x --pdb"
# log_cli_level = "DEBUG"
# log_cli = true

@ -0,0 +1 @@
"""Test client."""

@ -0,0 +1,59 @@
"""Response test."""
import json
import pytest
from manifest.clients import Response
def test_init():
    """Test response initialization."""
    # Neither a str nor a dict.
    with pytest.raises(ValueError) as exc_info:
        Response(4)
    assert str(exc_info.value) == "Response must be str or dict"

    # A dict without a "choices" list.
    with pytest.raises(ValueError) as exc_info:
        Response({"test": "hello"})
    assert (
        str(exc_info.value)
        == "Response must be serialized to a dict with a list of choices"
    )

    # A choice without a "text" field. The message is matched by substring:
    # comparing str(...) to a tuple, as before, can never be true.
    with pytest.raises(ValueError) as exc_info:
        Response({"choices": [{"blah": "hello"}]})
    assert "list of choices with text field" in str(exc_info.value)

    # Valid input, as a dict and as its JSON encoding.
    response = Response({"choices": [{"text": "hello"}]})
    assert response.response == {"choices": [{"text": "hello"}]}

    response = Response(json.dumps({"choices": [{"text": "hello"}]}))
    assert response.response == {"choices": [{"text": "hello"}]}
def test_getitem():
    """Indexing a response reads from the wrapped dict."""
    payload = {"choices": [{"text": "hello"}]}
    response = Response(payload)
    assert response["choices"] == [{"text": "hello"}]
def test_serialize():
    """A response survives a serialize/deserialize round trip."""
    original = Response({"choices": [{"text": "hello"}]})
    restored = Response.deserialize(original.serialize())
    assert restored.response == {"choices": [{"text": "hello"}]}
def test_get_results():
    """get_results returns None, a single text, or a list of texts."""
    # No choices.
    empty = Response({"choices": []})
    assert empty.get_results() is None

    # Exactly one choice.
    single = Response({"choices": [{"text": "hello"}]})
    assert single.get_results() == "hello"

    # Several choices.
    choices = [{"text": "hello"}, {"text": "my"}, {"text": "name"}]
    several = Response({"choices": choices})
    assert several.get_results() == ["hello", "my", "name"]

@ -0,0 +1,36 @@
"""Setup for all tests."""
import os
import shutil
import pytest
@pytest.fixture
def sqlite_cache(tmp_path):
    """Yield a cache path under tmp_path and remove it after the test."""
    cache_path = str(tmp_path / "sqlite_cache.sqlite")
    yield cache_path
    # Teardown: best-effort removal of whatever the test created there.
    shutil.rmtree(cache_path, ignore_errors=True)
@pytest.fixture
def redis_cache():
    """Yield a ``host:port`` connection string for the Redis test instance."""
    host = os.environ.get("REDIS_HOST", "localhost")
    if "CI" in os.environ:
        port = os.environ.get("REDIS_PORT", 6379)
    else:
        # Outside CI the port must be chosen explicitly, with a clear warning.
        try:
            port = os.environ["REDIS_PORT"]
        except KeyError:
            raise KeyError(
                "Set REDIS_PORT env var to the instance you want to use "
                "for testing. Note that doing so WILL delete the db at "
                "localhost:REDIS_PORT, db=0, so BE CAREFUL."
            )
    yield f"{host}:{port}"
    # Clear out the database
    # db = redis.Redis(host=host, port=port)
    # db.flushdb()

@ -0,0 +1,69 @@
"""Cache test."""
import pytest
from sqlitedict import SqliteDict
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients import Response
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.parametrize("cache_type", ["sqlite"])
def test_init(sqlite_cache, redis_cache, cache_type):
    """A cache can be constructed from its connection string."""
    if cache_type != "sqlite":
        cache = RedisCache(redis_cache)
    else:
        cache = SQLiteCache(sqlite_cache)
        # The SQLite cache opens one store for queries and one for prompts.
        for store in (cache.cache, cache.prompt_cache):
            assert isinstance(store, SqliteDict)
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.parametrize("cache_type", ["sqlite"])
def test_key_get_and_set(sqlite_cache, redis_cache, cache_type):
    """Test cache key get and set."""
    if cache_type == "sqlite":
        cache = SQLiteCache(sqlite_cache)
    else:
        cache = RedisCache(redis_cache)

    cache.set_key("test", "valueA")
    cache.set_key("testA", "valueB")
    assert cache.get_key("test") == "valueA"
    assert cache.get_key("testA") == "valueB"

    # Setting an existing key overrides its value.
    cache.set_key("testA", "valueC")
    assert cache.get_key("testA") == "valueC"

    # The prompt table is separate from the default one. These two checks
    # were bare comparisons before and so could never fail.
    assert cache.get_key("test", table="prompt") is None
    cache.set_key("test", "valueA", table="prompt")
    assert cache.get_key("test", table="prompt") == "valueA"
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.parametrize("cache_type", ["sqlite"])
def test_get(sqlite_cache, redis_cache, cache_type):
    """Cache.get computes on a miss, reuses on a hit, recomputes on overwrite."""
    cache = SQLiteCache(sqlite_cache) if cache_type == "sqlite" else RedisCache(
        redis_cache
    )
    test_request = {"test": "hello", "testA": "world"}

    def compute():
        return Response({"choices": [{"text": "hello"}]})

    # (overwrite_cache, expected "cached" flag) for three successive calls.
    scenarios = [(False, False), (False, True), (True, False)]
    for overwrite, expect_cached in scenarios:
        response, cached = cache.get(
            test_request, overwrite_cache=overwrite, compute=compute
        )
        assert response.get_results() == "hello"
        if expect_cached:
            assert cached
        else:
            assert not cached

@ -0,0 +1,105 @@
"""Manifest test."""
import pytest
from manifest import Manifest, Prompt
from manifest.caches.cache import request_to_key
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.dummy import DummyClient
@pytest.mark.usefixtures("sqlite_cache")
def test_init(sqlite_cache):
    """Manifest builds the requested client and cache."""
    base_args = dict(
        client_name="dummy",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
    )

    # Without extra keyword arguments.
    manifest = Manifest(**base_args)
    assert manifest.client_name == "dummy"
    assert isinstance(manifest.client, DummyClient)
    assert isinstance(manifest.cache, SQLiteCache)

    # Extra keyword arguments reach the client.
    manifest = Manifest(num_results=3, **base_args)
    assert manifest.client_name == "dummy"
    assert isinstance(manifest.client, DummyClient)
    assert isinstance(manifest.cache, SQLiteCache)
    assert manifest.client.num_results == 3
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.parametrize("num_results", [1, 2])
def test_run(sqlite_cache, num_results):
    """Manifest.run returns the client output and caches the request."""
    manifest = Manifest(
        client_name="dummy",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
        num_results=num_results,
    )
    expected = "hello" if num_results == 1 else ["hello", "hello"]

    # (prompt, input, rendered prompt string): a fixed string prompt first,
    # then a one-argument function prompt.
    cases = [
        (Prompt("This is a prompt"), None, "This is a prompt"),
        (Prompt(lambda x: f"{x} is a prompt"), "Hello", "Hello is a prompt"),
    ]
    for prompt, inp, rendered in cases:
        res = manifest.run(prompt) if inp is None else manifest.run(prompt, inp)
        cache_key = request_to_key(
            {
                "prompt": rendered,
                "client_name": "dummy",
                "num_results": num_results,
            }
        )
        assert manifest.cache.get_key(cache_key) is not None
        assert res == expected
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.parametrize("num_results", [1, 2])
def test_batch_run(sqlite_cache, num_results):
    """Manifest.run_batch returns one result per input."""
    manifest = Manifest(
        client_name="dummy",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
        num_results=num_results,
    )
    single = "hello" if num_results == 1 else ["hello", "hello"]

    # No input: treated as a batch holding one element.
    string_prompt = Prompt("This is a prompt")
    assert manifest.run_batch(string_prompt) == [single]

    # Two inputs give two results.
    func_prompt = Prompt(lambda x: f"{x} is a prompt")
    assert manifest.run_batch(func_prompt, ["Hello", "Hello"]) == [single, single]

@ -0,0 +1,54 @@
"""Prompt test."""
import pytest
from manifest import Prompt
def test_init():
    """Prompt accepts strings and callables of zero or one argument."""
    expected = "This is a test prompt"
    # TODO: add list of prompt tests

    # String prompt and zero-argument function: any input is ignored.
    for zero_arg_source in (expected, lambda: "This is a test prompt"):
        prompt = Prompt(zero_arg_source)
        assert prompt(None) == expected
        assert prompt() == expected

    # Function with a single scalar input.
    prompt = Prompt(lambda x: f"{x} is a test prompt")
    assert prompt("This") == expected
    assert prompt("Hello") == "Hello is a test prompt"

    # Function with a single list input.
    prompt = Prompt(lambda x: f"{x[0]} is a test {x[1]}")
    assert prompt(["This", "prompt"]) == expected
    assert prompt(["Hello", "prompt"]) == "Hello is a test prompt"

    # Function with two inputs is rejected.
    with pytest.raises(ValueError) as exc_info:
        Prompt(lambda x, y: f"{x} is a test {y}")
    assert str(exc_info.value) == "Prompts must have zero or one input."
@pytest.mark.skip(reason="Not implemented")
def test_serialize():
    """A prompt should survive a serialize/deserialize round trip."""
    sources = [
        # String prompt
        "This is a test prompt",
        # Function single inputs
        lambda x: f"{x} is a test prompt",
    ]
    for source in sources:
        prompt = Prompt(source)
        assert Prompt.deserialize(prompt.serialize()) == prompt
Loading…
Cancel
Save