mirror of
https://github.com/HazyResearch/manifest
synced 2024-10-31 15:20:26 +00:00
First main commit
This commit is contained in:
parent
6ada1f2d2c
commit
ed301be2a7
10
.flake8
Normal file
10
.flake8
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
# This is our code-style check. We currently allow the following exceptions:
|
||||||
|
# - E731: do not assign a lambda expression, use a def
|
||||||
|
# - E402: module level import not at top of file
|
||||||
|
# - W503: line break before binary operator
|
||||||
|
|
||||||
|
[flake8]
|
||||||
|
exclude = .git
|
||||||
|
max-line-length = 88
|
||||||
|
ignore = E731, E402, W503
|
||||||
|
per-file-ignores = __init__.py:F401
|
28
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
28
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
---
|
||||||
|
name: Bug report
|
||||||
|
about: Create a report to help us improve
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Description of the bug
|
||||||
|
A clear and concise description of what the bug is.
|
||||||
|
|
||||||
|
## To Reproduce
|
||||||
|
Steps to reproduce the behavior:
|
||||||
|
1. Go to '...'
|
||||||
|
2. Click on '....'
|
||||||
|
3. Execute the code '...'
|
||||||
|
If necessary, attach example data which can be used to replicate the issue.
|
||||||
|
|
||||||
|
## Expected behavior
|
||||||
|
A clear and concise description of what you expected to happen.
|
||||||
|
|
||||||
|
## Error Logs/Screenshots
|
||||||
|
If applicable, add error logs or screenshots to help explain your problem.
|
||||||
|
|
||||||
|
## Environment (please complete the following information)
|
||||||
|
- OS: [e.g. Ubuntu 18.04]
|
||||||
|
- bootleg Version: [e.g. 0.6.0]
|
||||||
|
|
||||||
|
## Additional context
|
||||||
|
Add any other context about the problem here.
|
19
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
19
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
---
|
||||||
|
name: Feature request
|
||||||
|
about: Suggest an idea for this project
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Description of the feature request
|
||||||
|
|
||||||
|
**Is your feature request related to a problem? Please describe.**
|
||||||
|
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||||
|
|
||||||
|
## Description of the solution you'd like
|
||||||
|
A clear and concise description of what you want to happen.
|
||||||
|
|
||||||
|
## Description of the alternatives you've considered
|
||||||
|
A clear and concise description of any alternative solutions or features you've considered.
|
||||||
|
|
||||||
|
## Additional context
|
||||||
|
Add any other context or screenshots about the feature request here.
|
65
.github/workflows/ci.yaml
vendored
Normal file
65
.github/workflows/ci.yaml
vendored
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
# Allows you to run this workflow manually from the Actions tab
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
timeout-minutes: 30
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
python-version: ["3.8", "3.9", "3.10"]
|
||||||
|
services:
|
||||||
|
# Label used to access the service container
|
||||||
|
redis:
|
||||||
|
# Docker Hub image
|
||||||
|
image: redislabs/redis
|
||||||
|
# Set health checks to wait until redis has started
|
||||||
|
options: >-
|
||||||
|
--health-cmd "redis-cli ping"
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
ports:
|
||||||
|
# Maps port 6379 on service container to the host
|
||||||
|
- 6379:6379
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- uses: actions/setup-python@v3
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- name: Install Poetry
|
||||||
|
uses: snok/install-poetry@v1
|
||||||
|
with:
|
||||||
|
virtualenvs-in-project: true
|
||||||
|
virtualenvs-create: true
|
||||||
|
installer-parallel: true
|
||||||
|
- name: Load cached venv if cache exists
|
||||||
|
id: cached-poetry-deps
|
||||||
|
uses: actions/cache@v3
|
||||||
|
with:
|
||||||
|
path: .venv
|
||||||
|
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cached-poetry-deps.outputs.cache-hit != 'true'
|
||||||
|
run: |
|
||||||
|
poetry install --no-interaction --no-root
|
||||||
|
- name: Install Manifest
|
||||||
|
run: |
|
||||||
|
make dev
|
||||||
|
- name: Run preliminary checks
|
||||||
|
run: |
|
||||||
|
make check
|
||||||
|
- name: Test with pytest
|
||||||
|
run: |
|
||||||
|
poetry run pytest tests
|
120
.gitignore
vendored
Normal file
120
.gitignore
vendored
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
runs/*
|
||||||
|
|
||||||
|
*._*
|
||||||
|
# Pickle Saved
|
||||||
|
*.pt
|
||||||
|
*.pk
|
||||||
|
**/*.pt
|
||||||
|
**/*.pk
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
*.idea
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# celery beat schedule file
|
||||||
|
celerybeat-schedule
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
|
||||||
|
*.tsv
|
||||||
|
*.7z
|
||||||
|
.DS_Store
|
23
.pre-commit-config.yaml
Normal file
23
.pre-commit-config.yaml
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v3.2.0
|
||||||
|
hooks:
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: check-yaml
|
||||||
|
- id: check-toml
|
||||||
|
- id: check-merge-conflict
|
||||||
|
- id: check-added-large-files
|
||||||
|
- repo: https://github.com/timothycrosley/isort
|
||||||
|
rev: 5.9.3
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
- repo: https://github.com/psf/black
|
||||||
|
rev: 22.3.0
|
||||||
|
hooks:
|
||||||
|
- id: black
|
||||||
|
language_version: python3
|
||||||
|
- repo: https://gitlab.com/pycqa/flake8
|
||||||
|
rev: 3.9.2
|
||||||
|
hooks:
|
||||||
|
- id: flake8
|
28
Makefile
Normal file
28
Makefile
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
dev:
|
||||||
|
poetry install
|
||||||
|
poetry run pre-commit install
|
||||||
|
|
||||||
|
test: dev check
|
||||||
|
poetry install
|
||||||
|
poetry run pytest tests
|
||||||
|
|
||||||
|
format:
|
||||||
|
isort --atomic manifest/ tests/
|
||||||
|
black manifest/ tests/
|
||||||
|
|
||||||
|
check:
|
||||||
|
isort -c manifest/ tests/
|
||||||
|
black manifest/ tests/ --check
|
||||||
|
flake8 manifest/ tests/
|
||||||
|
mypy manifest/
|
||||||
|
|
||||||
|
clean:
|
||||||
|
pip uninstall -y manifest
|
||||||
|
rm -rf src/manifest.egg-info
|
||||||
|
rm -rf build/ dist/
|
||||||
|
|
||||||
|
prune:
|
||||||
|
@bash -c "git fetch -p";
|
||||||
|
@bash -c "for branch in $(git branch -vv | grep ': gone]' | awk '{print $1}'); do git branch -d $branch; done";
|
||||||
|
|
||||||
|
.PHONY: dev test clean check prune
|
19
README.md
19
README.md
@ -1,2 +1,21 @@
|
|||||||
# manifest
|
# manifest
|
||||||
Prompt programming with FMs.
|
Prompt programming with FMs.
|
||||||
|
|
||||||
|
# Install
|
||||||
|
Download the code:
|
||||||
|
```
|
||||||
|
git clone git@github.com:HazyResearch/manifest.git
|
||||||
|
cd manifest
|
||||||
|
```
|
||||||
|
|
||||||
|
Install:
|
||||||
|
```
|
||||||
|
pip install poetry
|
||||||
|
poetry install
|
||||||
|
poetry run pre-commit install
|
||||||
|
```
|
||||||
|
or
|
||||||
|
```
|
||||||
|
pip install poetry
|
||||||
|
make dev
|
||||||
|
```
|
3
manifest/__init__.py
Normal file
3
manifest/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
"""Manifest init."""
|
||||||
|
from manifest.manifest import Manifest
|
||||||
|
from manifest.prompt import Prompt
|
1
manifest/api/app.py
Normal file
1
manifest/api/app.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Flask app."""
|
1
manifest/api/models/huggingface_model.py
Normal file
1
manifest/api/models/huggingface_model.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Huggingface model."""
|
1
manifest/api/models/model.py
Normal file
1
manifest/api/models/model.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Model class."""
|
2
manifest/caches/__init__.py
Normal file
2
manifest/caches/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
"""Cache init."""
|
||||||
|
from manifest.caches.cache import Cache
|
113
manifest/caches/cache.py
Normal file
113
manifest/caches/cache.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
"""Cache for queries and responses."""
|
||||||
|
import json
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Callable, Dict, Tuple, Union
|
||||||
|
|
||||||
|
from manifest.clients.response import Response
|
||||||
|
|
||||||
|
|
||||||
|
def request_to_key(request: Dict) -> str:
|
||||||
|
"""
|
||||||
|
Normalize a request into a key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: request to normalize.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
normalized key.
|
||||||
|
"""
|
||||||
|
return json.dumps(request, sort_keys=True)
|
||||||
|
|
||||||
|
|
||||||
|
def key_to_request(key: str) -> Dict:
|
||||||
|
"""
|
||||||
|
Convert the normalized version to the request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: normalized key to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
unnormalized request dict.
|
||||||
|
"""
|
||||||
|
return json.loads(key)
|
||||||
|
|
||||||
|
|
||||||
|
class Cache(ABC):
|
||||||
|
"""A cache for request/response pairs."""
|
||||||
|
|
||||||
|
def __init__(self, connection_str: str, **kwargs: Any):
|
||||||
|
"""
|
||||||
|
Initialize client.
|
||||||
|
|
||||||
|
kwargs are passed to client as default parameters.
|
||||||
|
|
||||||
|
For clients like OpenAI that do not require a connection,
|
||||||
|
the connection_str can be None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_str: connection string for client.
|
||||||
|
"""
|
||||||
|
self.connect(connection_str, **kwargs)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the client."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def connect(self, connection_str: str, **kwargs: Any) -> None:
|
||||||
|
"""
|
||||||
|
Connect to client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_str: connection string.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_key(self, key: str, table: str = "default") -> Union[str, None]:
|
||||||
|
"""
|
||||||
|
Get the key for a request.
|
||||||
|
|
||||||
|
With return None if key is not in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: key for cache.
|
||||||
|
table: table to get key in.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_key(self, key: str, value: str, table: str = "default") -> None:
|
||||||
|
"""
|
||||||
|
Set the value for the key.
|
||||||
|
|
||||||
|
Will override old value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: key for cache.
|
||||||
|
value: new value for key.
|
||||||
|
table: table to set key in.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def commit(self) -> None:
|
||||||
|
"""Commit any results."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self, request: Dict, overwrite_cache: bool, compute: Callable[[], Response]
|
||||||
|
) -> Tuple[Response, bool]:
|
||||||
|
"""Get the result of request (by calling compute as needed)."""
|
||||||
|
key = request_to_key(request)
|
||||||
|
cached_response = self.get_key(key)
|
||||||
|
if cached_response and not overwrite_cache:
|
||||||
|
cached = True
|
||||||
|
response = Response.deserialize(cached_response)
|
||||||
|
else:
|
||||||
|
# Type Response
|
||||||
|
response = compute()
|
||||||
|
self.set_key(key, response.serialize())
|
||||||
|
cached = False
|
||||||
|
return response, cached
|
52
manifest/caches/redis.py
Normal file
52
manifest/caches/redis.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
"""Redis cache."""
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
import redis
|
||||||
|
|
||||||
|
from manifest.caches import Cache
|
||||||
|
|
||||||
|
|
||||||
|
class RedisCache(Cache):
|
||||||
|
"""A Redis cache for request/response pairs."""
|
||||||
|
|
||||||
|
def connect(self, connection_str: str, **kwargs: Any) -> None:
|
||||||
|
"""
|
||||||
|
Connect to client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_str: connection string.
|
||||||
|
"""
|
||||||
|
host, port = connection_str.split(":")
|
||||||
|
self.redis = redis.Redis(host=host, port=int(port))
|
||||||
|
return
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the client."""
|
||||||
|
self.redis.close()
|
||||||
|
|
||||||
|
def get_key(self, key: str, table: str = "default") -> Union[str, None]:
|
||||||
|
"""
|
||||||
|
Get the key for a request.
|
||||||
|
|
||||||
|
With return None if key is not in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: key for cache.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_key(self, key: str, value: str, table: str = "default") -> None:
|
||||||
|
"""
|
||||||
|
Set the value for the key.
|
||||||
|
|
||||||
|
Will override old value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: key for cache.
|
||||||
|
value: new value for key.
|
||||||
|
"""
|
||||||
|
self.redis[key] = value
|
||||||
|
|
||||||
|
def commit(self) -> None:
|
||||||
|
"""Commit any results."""
|
||||||
|
pass
|
79
manifest/caches/sqlite.py
Normal file
79
manifest/caches/sqlite.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
"""SQLite cache."""
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from sqlitedict import SqliteDict
|
||||||
|
|
||||||
|
from manifest.caches import Cache
|
||||||
|
|
||||||
|
logging.getLogger("sqlitedict").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteCache(Cache):
|
||||||
|
"""A SQLite cache for request/response pairs."""
|
||||||
|
|
||||||
|
def connect(self, connection_str: str, **kwargs: Any) -> None:
|
||||||
|
"""
|
||||||
|
Connect to client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_str: connection string.
|
||||||
|
"""
|
||||||
|
self.cache_dir = connection_str
|
||||||
|
Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
# If more than two tables, switch to full on SQL connection
|
||||||
|
self.query_file = Path(self.cache_dir, "query.sqlite")
|
||||||
|
self.prompt_file = Path(self.cache_dir, "prompts.sqlite")
|
||||||
|
self.cache = SqliteDict(self.query_file, autocommit=False)
|
||||||
|
self.prompt_cache = SqliteDict(self.prompt_file, autocommit=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the client."""
|
||||||
|
self.cache.close()
|
||||||
|
|
||||||
|
def get_key(self, key: str, table: str = "default") -> Union[str, None]:
|
||||||
|
"""
|
||||||
|
Get the key for a request.
|
||||||
|
|
||||||
|
With return None if key is not in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: key for cache.
|
||||||
|
table: table to get key in.
|
||||||
|
"""
|
||||||
|
if table == "prompt":
|
||||||
|
return self.prompt_cache.get(key)
|
||||||
|
else:
|
||||||
|
if table != "default":
|
||||||
|
raise ValueError(
|
||||||
|
"SQLiteDict only support table of `default` or `prompt`"
|
||||||
|
)
|
||||||
|
return self.cache.get(key)
|
||||||
|
|
||||||
|
def set_key(self, key: str, value: str, table: str = "default") -> None:
|
||||||
|
"""
|
||||||
|
Set the value for the key.
|
||||||
|
|
||||||
|
Will override old value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: key for cache.
|
||||||
|
value: new value for key.
|
||||||
|
table: table to set key in.
|
||||||
|
"""
|
||||||
|
if table == "prompt":
|
||||||
|
self.prompt_cache[key] = value
|
||||||
|
else:
|
||||||
|
if table != "default":
|
||||||
|
raise ValueError(
|
||||||
|
"SQLiteDict only support table of `default` or `prompt`"
|
||||||
|
)
|
||||||
|
self.cache[key] = value
|
||||||
|
self.commit()
|
||||||
|
|
||||||
|
def commit(self) -> None:
|
||||||
|
"""Commit any results."""
|
||||||
|
self.prompt_cache.commit()
|
||||||
|
self.cache.commit()
|
3
manifest/clients/__init__.py
Normal file
3
manifest/clients/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
"""Client init."""
|
||||||
|
from manifest.clients.client import Client
|
||||||
|
from manifest.clients.response import Response
|
58
manifest/clients/client.py
Normal file
58
manifest/clients/client.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
"""Client class."""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from manifest.clients import Response
|
||||||
|
|
||||||
|
|
||||||
|
class Client(ABC):
|
||||||
|
"""Client class."""
|
||||||
|
|
||||||
|
def __init__(self, connection_str: Optional[str] = None, **kwargs: Any):
|
||||||
|
"""
|
||||||
|
Initialize client.
|
||||||
|
|
||||||
|
kwargs are passed to client as default parameters.
|
||||||
|
|
||||||
|
For clients like OpenAI that do not require a connection,
|
||||||
|
the connection_str can be None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_str: connection string for client.
|
||||||
|
"""
|
||||||
|
self.connect(connection_str, **kwargs)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the client."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def connect(self, connection_str: str, **kwargs: Any) -> None:
|
||||||
|
"""
|
||||||
|
Connect to client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_str: connection string.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_request(
|
||||||
|
self, query: str, **kwargs: Any
|
||||||
|
) -> Tuple[Callable[[], Response], Dict]:
|
||||||
|
"""
|
||||||
|
Get request function.
|
||||||
|
|
||||||
|
kwargs override default parameters.
|
||||||
|
|
||||||
|
Calling the returned function will run the request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: query string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
request function that takes no input.
|
||||||
|
request parameters as dict.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
52
manifest/clients/dummy.py
Normal file
52
manifest/clients/dummy.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
"""Dummy client."""
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from manifest.clients import Client
|
||||||
|
from manifest.clients.response import Response
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyClient(Client):
|
||||||
|
"""Dummy client."""
|
||||||
|
|
||||||
|
def connect(
|
||||||
|
self,
|
||||||
|
connection_str: Optional[str] = None,
|
||||||
|
num_results: Optional[int] = 1,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Connect to dummpy server.
|
||||||
|
|
||||||
|
This is a dummy client that returns identity responses. Used for testing.
|
||||||
|
"""
|
||||||
|
self.num_results = num_results
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the client."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_request(
|
||||||
|
self, query: str, **kwargs: Any
|
||||||
|
) -> Tuple[Callable[[], Response], Dict]:
|
||||||
|
"""
|
||||||
|
Get request string function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: query string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
request function that takes no input.
|
||||||
|
request parameters as dict.
|
||||||
|
"""
|
||||||
|
request_params = {
|
||||||
|
"prompt": query,
|
||||||
|
"num_results": kwargs.get("num_results", self.num_results),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _run_completion() -> Response:
|
||||||
|
return Response({"choices": [{"text": "hello"}] * self.num_results})
|
||||||
|
|
||||||
|
return _run_completion, request_params
|
95
manifest/clients/openai.py
Normal file
95
manifest/clients/openai.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
"""OpenAI client."""
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
from manifest.clients import Response
|
||||||
|
from manifest.clients.client import Client
|
||||||
|
|
||||||
|
logging.getLogger("openai").setLevel(logging.WARNING)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
OPENAI_ENGINES = {
|
||||||
|
"text-davinci-002",
|
||||||
|
"text-curie-001",
|
||||||
|
"text-babbage-001",
|
||||||
|
"text-ada-001",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIClient(Client):
|
||||||
|
"""OpenAI client."""
|
||||||
|
|
||||||
|
def connect(
|
||||||
|
self,
|
||||||
|
connection_str: Optional[str] = None,
|
||||||
|
engine: Optional[str] = "text-ada-001",
|
||||||
|
temperature: Optional[float] = 0.0,
|
||||||
|
max_tokens: Optional[int] = 10,
|
||||||
|
top_p: Optional[int] = 1,
|
||||||
|
frequency_penalty: Optional[int] = 0,
|
||||||
|
presence_penalty: Optional[int] = 0,
|
||||||
|
n: Optional[int] = 1,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Connect to the OpenAI server.
|
||||||
|
|
||||||
|
connection_str is passed as default OPENAI_API_KEY if variable not set.
|
||||||
|
"""
|
||||||
|
openai.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
|
||||||
|
if openai.api_key is None:
|
||||||
|
raise ValueError(
|
||||||
|
"OpenAI API key not set. Set OPENAI_API_KEY environment ",
|
||||||
|
"svariable or pass through `connection_str`.",
|
||||||
|
)
|
||||||
|
self.engine = engine
|
||||||
|
if self.engine not in OPENAI_ENGINES:
|
||||||
|
raise ValueError(f"Invalid engine {self.engine}. Must be {OPENAI_ENGINES}.")
|
||||||
|
self.temperature = temperature
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.top_p = top_p
|
||||||
|
self.frequency_penalty = frequency_penalty
|
||||||
|
self.presence_penalty = presence_penalty
|
||||||
|
self.n = n
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the client."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_request(
|
||||||
|
self, query: str, **kwargs: Any
|
||||||
|
) -> Tuple[Callable[[], Response], Dict]:
|
||||||
|
"""
|
||||||
|
Get request string function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: query string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
request function that takes no input.
|
||||||
|
request parameters as dict.
|
||||||
|
"""
|
||||||
|
request_params = {
|
||||||
|
"engine": kwargs.get("engine", self.engine),
|
||||||
|
"prompt": query,
|
||||||
|
"temperature": kwargs.get("temperature", self.temperature),
|
||||||
|
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||||
|
"top_p": kwargs.get("top_p", self.top_p),
|
||||||
|
"frequency_penalty": kwargs.get(
|
||||||
|
"frequency_penalty", self.frequency_penalty
|
||||||
|
),
|
||||||
|
"presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
|
||||||
|
"n": kwargs.get("n", self.n),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _run_completion() -> Response:
|
||||||
|
try:
|
||||||
|
return Response(openai.Completion.create(**request_params))
|
||||||
|
except openai.error.OpenAIError as e:
|
||||||
|
logger.error(e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return _run_completion, request_params
|
70
manifest/clients/response.py
Normal file
70
manifest/clients/response.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
"""Client response."""
|
||||||
|
import json
|
||||||
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
|
|
||||||
|
class Response:
|
||||||
|
"""Response class."""
|
||||||
|
|
||||||
|
def __init__(self, response: Union[str, Dict]):
|
||||||
|
"""Initialize response."""
|
||||||
|
if isinstance(response, str):
|
||||||
|
self.response = json.loads(response)
|
||||||
|
elif isinstance(response, dict):
|
||||||
|
self.response = response
|
||||||
|
else:
|
||||||
|
raise ValueError("Response must be str or dict")
|
||||||
|
if ("choices" not in self.response) or (
|
||||||
|
not isinstance(self.response["choices"], list)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Response must be serialized to a dict with a list of choices"
|
||||||
|
)
|
||||||
|
if len(self.response["choices"]) > 0:
|
||||||
|
if "text" not in self.response["choices"][0]:
|
||||||
|
raise ValueError(
|
||||||
|
"Response must be serialized to a dict with a ",
|
||||||
|
"list of choices with text field",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> str:
|
||||||
|
"""
|
||||||
|
Return the response given the key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: key to get.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
value of key.
|
||||||
|
"""
|
||||||
|
return self.response[key]
|
||||||
|
|
||||||
|
def get_results(self) -> Union[str, List[str]]:
|
||||||
|
"""Get all text results from response."""
|
||||||
|
if len(self.response["choices"]) == 0:
|
||||||
|
return None
|
||||||
|
if len(self.response["choices"]) == 1:
|
||||||
|
return self.response["choices"][0]["text"]
|
||||||
|
return [choice["text"] for choice in self.response["choices"]]
|
||||||
|
|
||||||
|
def serialize(self) -> str:
|
||||||
|
"""
|
||||||
|
Serialize response to string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
serialized response.
|
||||||
|
"""
|
||||||
|
return json.dumps(self.response, sort_keys=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deserialize(cls, value: str) -> "Response":
|
||||||
|
"""
|
||||||
|
Deserialize string to response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: serialized response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
serialized response.
|
||||||
|
"""
|
||||||
|
return Response(value)
|
138
manifest/manifest.py
Normal file
138
manifest/manifest.py
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
"""Manifest class."""
|
||||||
|
import logging
|
||||||
|
from typing import Any, Iterable, List, Optional, Union
|
||||||
|
|
||||||
|
logging.getLogger("openai").setLevel(logging.WARNING)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from manifest import Prompt
|
||||||
|
from manifest.caches.redis import RedisCache
|
||||||
|
from manifest.caches.sqlite import SQLiteCache
|
||||||
|
from manifest.clients.dummy import DummyClient
|
||||||
|
from manifest.clients.openai import OpenAIClient
|
||||||
|
|
||||||
|
CLIENT_CONSTRUCTORS = {
|
||||||
|
"openai": OpenAIClient,
|
||||||
|
# "huggingface": manifest.clients.huggingface.HuggingFaceClient,
|
||||||
|
"dummy": DummyClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
CACHE_CONSTRUCTORS = {
|
||||||
|
"redis": RedisCache,
|
||||||
|
"sqlite": SQLiteCache,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Manifest:
|
||||||
|
"""Manifest session object."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
client_name: str = "openai",
|
||||||
|
client_connection: Optional[str] = None,
|
||||||
|
cache_name: str = "redis",
|
||||||
|
cache_connection: str = "localhost:6379",
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize manifest.
|
||||||
|
|
||||||
|
Remaining kwargs sent to client and cache.
|
||||||
|
"""
|
||||||
|
if client_name not in CLIENT_CONSTRUCTORS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown client name: {client_name}. "
|
||||||
|
f"Choices are {list(CLIENT_CONSTRUCTORS.keys())}"
|
||||||
|
)
|
||||||
|
if cache_name not in CACHE_CONSTRUCTORS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown cache name: {cache_name}. "
|
||||||
|
f"Choices are {list(CACHE_CONSTRUCTORS.keys())}"
|
||||||
|
)
|
||||||
|
self.client_name = client_name
|
||||||
|
self.client = CLIENT_CONSTRUCTORS[client_name](client_connection, **kwargs)
|
||||||
|
self.cache = CACHE_CONSTRUCTORS[cache_name](cache_connection, **kwargs)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the client and cache."""
|
||||||
|
self.client.close()
|
||||||
|
self.cache.close()
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
prompt: Prompt,
|
||||||
|
input: Optional[Any] = None,
|
||||||
|
overwrite_cache: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[str, List[str]]:
|
||||||
|
"""
|
||||||
|
Run the prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: prompt to run.
|
||||||
|
input: input to prompt.
|
||||||
|
overwrite_cache: whether to overwrite cache.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
response from prompt.
|
||||||
|
"""
|
||||||
|
prompt_str = prompt(input)
|
||||||
|
possible_request, full_kwargs = self.client.get_request(prompt_str, **kwargs)
|
||||||
|
# Create cacke key
|
||||||
|
cache_key = full_kwargs.copy()
|
||||||
|
# Make query model dependent
|
||||||
|
cache_key["client_name"] = self.client_name
|
||||||
|
# Make query prompt dependent
|
||||||
|
cache_key["prompt"] = prompt_str
|
||||||
|
response, _ = self.cache.get(cache_key, overwrite_cache, possible_request)
|
||||||
|
return response.get_results()
|
||||||
|
|
||||||
|
def run_batch(
|
||||||
|
self,
|
||||||
|
prompt: Prompt,
|
||||||
|
input: Optional[Iterable[Any]] = None,
|
||||||
|
overwrite_cache: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterable[Union[str, List[str]]]:
|
||||||
|
"""
|
||||||
|
Run the prompt on a batch of inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: prompt to run.
|
||||||
|
input: batch of inputs.
|
||||||
|
overwrite_cache: whether to overwrite cache.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
batch of responses.
|
||||||
|
"""
|
||||||
|
if input is None:
|
||||||
|
input = [None]
|
||||||
|
return [self.run(prompt, inp, overwrite_cache, **kwargs) for inp in input]
|
||||||
|
|
||||||
|
def save_prompt(self, name: str, prompt: Prompt) -> None:
|
||||||
|
"""
|
||||||
|
Save the prompt to the cache for long term storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: name of prompt.
|
||||||
|
prompt: prompt to save.
|
||||||
|
"""
|
||||||
|
self.cache.set_key(name, prompt.serialize(), table="prompt")
|
||||||
|
|
||||||
|
def load_prompt(self, name: str) -> Prompt:
|
||||||
|
"""
|
||||||
|
Load the prompt from the cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: name of prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Prompt saved with name.
|
||||||
|
"""
|
||||||
|
return Prompt.deserialize(self.cache.get_key(name, table="prompt"))
|
||||||
|
|
||||||
|
def open_explorer(self) -> None:
|
||||||
|
"""Open the explorer for jupyter widget."""
|
||||||
|
# Open explorer
|
||||||
|
# TODO: implement
|
||||||
|
pass
|
72
manifest/prompt.py
Normal file
72
manifest/prompt.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
"""Prompt class."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
|
logging.getLogger("openai").setLevel(logging.WARNING)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Prompt:
|
||||||
|
"""Prompt class."""
|
||||||
|
|
||||||
|
def __init__(self, prompt_obj: Union[str, Callable, "Prompt", List["Prompt"]]):
|
||||||
|
"""
|
||||||
|
Initialize prompt.
|
||||||
|
|
||||||
|
If prompt_obj is a string, it will be cast as function.
|
||||||
|
If prompt_obj is list of promts, it will be composed.
|
||||||
|
"""
|
||||||
|
# TODO: figure out how to compose prompts to keep the
|
||||||
|
# interface simple? Can we make a function
|
||||||
|
# such that a single call will run the composition?
|
||||||
|
if isinstance(prompt_obj, str):
|
||||||
|
self.prompt_func = lambda: prompt_obj
|
||||||
|
elif callable(prompt_obj):
|
||||||
|
self.prompt_func = prompt_obj
|
||||||
|
else:
|
||||||
|
# TODO: implement
|
||||||
|
raise NotImplementedError()
|
||||||
|
self.num_args = len(inspect.signature(self.prompt_func).parameters)
|
||||||
|
if self.num_args > 1:
|
||||||
|
raise ValueError("Prompts must have zero or one input.")
|
||||||
|
|
||||||
|
def __call__(self, input: Optional[Any] = None) -> str:
|
||||||
|
"""
|
||||||
|
Return the prompt given the inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: input to prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
prompt string.
|
||||||
|
"""
|
||||||
|
if self.num_args >= 1:
|
||||||
|
return self.prompt_func(input) # type: ignore
|
||||||
|
else:
|
||||||
|
return self.prompt_func()
|
||||||
|
|
||||||
|
def serialize(self) -> str:
|
||||||
|
"""
|
||||||
|
Return the prompt as str.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
json object.
|
||||||
|
"""
|
||||||
|
# TODO: implement
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deserialize(cls, obj: str) -> "Prompt":
|
||||||
|
"""
|
||||||
|
Return the prompt from a json object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: json object.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
prompt.
|
||||||
|
"""
|
||||||
|
# TODO: implement
|
||||||
|
pass
|
1418
poetry.lock
generated
Normal file
1418
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
69
pyproject.toml
Normal file
69
pyproject.toml
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
[tool.poetry]
|
||||||
|
authors = ["Laurel Orr <lorr1@cs.stanford.edu>, Avanika Narayan <avanikan@stanford.edu>"]
|
||||||
|
classifiers = [
|
||||||
|
"Programming Language :: Python",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.8",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Programming Language :: Python :: 3.10"
|
||||||
|
]
|
||||||
|
description = "Manifest for Prompt Programming"
|
||||||
|
name = "manifest"
|
||||||
|
repository = "https://github.com/HazyResearch/manifest"
|
||||||
|
version = "0.0.1"
|
||||||
|
|
||||||
|
[tool.poetry.urls]
|
||||||
|
"Bug Tracker" = "https://github.com/HazyResearch/manifest/issues"
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = "^3.8"
|
||||||
|
sqlitedict = "^2.0.0"
|
||||||
|
openai = "^0.18.1"
|
||||||
|
redis = "^4.3.1"
|
||||||
|
|
||||||
|
[tool.poetry.dev-dependencies]
|
||||||
|
black = "^22.3.0"
|
||||||
|
flake8 = "^4.0.0"
|
||||||
|
flake8-docstrings = "^1.6.0"
|
||||||
|
isort = "^5.9.3"
|
||||||
|
mypy = "^0.950"
|
||||||
|
pep8-naming = "^0.12.1"
|
||||||
|
pre-commit = "^2.14.0"
|
||||||
|
pytest = "^7.0.0"
|
||||||
|
pytest-cov = "^3.0.0"
|
||||||
|
python-dotenv = "^0.20.0"
|
||||||
|
recommonmark = "^0.7.1"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
requires = ["poetry-core>=1.0.0"]
|
||||||
|
|
||||||
|
# Additional Tool Configurations
|
||||||
|
[tool.mypy]
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
strict_optional = false
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
ignore_missing_imports = true
|
||||||
|
module = [
|
||||||
|
"numpy",
|
||||||
|
"tqdm",
|
||||||
|
"sqlitedict",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.isort]
|
||||||
|
combine_as_imports = true
|
||||||
|
force_grid_wrap = 0
|
||||||
|
include_trailing_comma = true
|
||||||
|
known_first_party = ["manifest"]
|
||||||
|
line_length = 88
|
||||||
|
multi_line_output = 3
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
log_format = "[%(levelname)s] %(message)s"
|
||||||
|
log_date_format = "%Y-%m-%d %H:%M:%S"
|
||||||
|
addopts = "-v -rsXx"
|
||||||
|
# The following options are useful for local debugging
|
||||||
|
# addopts = "-v -rsXx -s -x --pdb"
|
||||||
|
# log_cli_level = "DEBUG"
|
||||||
|
# log_cli = true
|
1
tests/clients/test_client.py
Normal file
1
tests/clients/test_client.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Test client."""
|
59
tests/clients/test_response.py
Normal file
59
tests/clients/test_response.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
"""Response test."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from manifest.clients import Response
|
||||||
|
|
||||||
|
|
||||||
|
def test_init():
|
||||||
|
"""Test response initialization."""
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
response = Response(4)
|
||||||
|
assert str(exc_info.value) == "Response must be str or dict"
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
response = Response({"test": "hello"})
|
||||||
|
assert (
|
||||||
|
str(exc_info.value)
|
||||||
|
== "Response must be serialized to a dict with a list of choices"
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
response = Response({"choices": [{"blah": "hello"}]})
|
||||||
|
assert str(exc_info.value) == (
|
||||||
|
"Response must be serialized to a dict ",
|
||||||
|
"with a list of choices with text field",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = Response({"choices": [{"text": "hello"}]})
|
||||||
|
assert response.response == {"choices": [{"text": "hello"}]}
|
||||||
|
|
||||||
|
response = Response(json.dumps({"choices": [{"text": "hello"}]}))
|
||||||
|
assert response.response == {"choices": [{"text": "hello"}]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_getitem():
|
||||||
|
"""Test response getitem."""
|
||||||
|
response = Response({"choices": [{"text": "hello"}]})
|
||||||
|
assert response["choices"] == [{"text": "hello"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialize():
|
||||||
|
"""Test response serialization."""
|
||||||
|
response = Response({"choices": [{"text": "hello"}]})
|
||||||
|
assert Response.deserialize(response.serialize()).response == {
|
||||||
|
"choices": [{"text": "hello"}]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_results():
|
||||||
|
"""Test response get results."""
|
||||||
|
response = Response({"choices": []})
|
||||||
|
assert response.get_results() is None
|
||||||
|
|
||||||
|
response = Response({"choices": [{"text": "hello"}]})
|
||||||
|
assert response.get_results() == "hello"
|
||||||
|
|
||||||
|
response = Response(
|
||||||
|
{"choices": [{"text": "hello"}, {"text": "my"}, {"text": "name"}]}
|
||||||
|
)
|
||||||
|
assert response.get_results() == ["hello", "my", "name"]
|
36
tests/conftest.py
Normal file
36
tests/conftest.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
"""Setup for all tests."""
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sqlite_cache(tmp_path):
|
||||||
|
"""Sqlite Cache."""
|
||||||
|
cache = str(tmp_path / "sqlite_cache.sqlite")
|
||||||
|
yield cache
|
||||||
|
shutil.rmtree(cache, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def redis_cache():
|
||||||
|
"""Redis cache."""
|
||||||
|
if "CI" not in os.environ:
|
||||||
|
# Give a clear warning on setting REDIS_PORT before running tests.
|
||||||
|
try:
|
||||||
|
port = os.environ["REDIS_PORT"]
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError(
|
||||||
|
"Set REDIS_PORT env var to the instance you want to use "
|
||||||
|
+ "for testing. Note that doing so WILL delete the db at "
|
||||||
|
+ "localhost:REDIS_PORT, db=0, so BE CAREFUL."
|
||||||
|
)
|
||||||
|
host = os.environ.get("REDIS_HOST", "localhost")
|
||||||
|
else:
|
||||||
|
host = os.environ.get("REDIS_HOST", "localhost")
|
||||||
|
port = os.environ.get("REDIS_PORT", 6379)
|
||||||
|
yield f"{host}:{port}"
|
||||||
|
# Clear out the database
|
||||||
|
# db = redis.Redis(host=host, port=port)
|
||||||
|
# db.flushdb()
|
69
tests/test_cache.py
Normal file
69
tests/test_cache.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
"""Cache test."""
|
||||||
|
import pytest
|
||||||
|
from sqlitedict import SqliteDict
|
||||||
|
|
||||||
|
from manifest.caches.redis import RedisCache
|
||||||
|
from manifest.caches.sqlite import SQLiteCache
|
||||||
|
from manifest.clients import Response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("sqlite_cache")
|
||||||
|
@pytest.mark.usefixtures("redis_cache")
|
||||||
|
@pytest.mark.parametrize("cache_type", ["sqlite"])
|
||||||
|
def test_init(sqlite_cache, redis_cache, cache_type):
|
||||||
|
"""Test cache initialization."""
|
||||||
|
if cache_type == "sqlite":
|
||||||
|
cache = SQLiteCache(sqlite_cache)
|
||||||
|
assert isinstance(cache.cache, SqliteDict)
|
||||||
|
assert isinstance(cache.prompt_cache, SqliteDict)
|
||||||
|
else:
|
||||||
|
cache = RedisCache(redis_cache)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("sqlite_cache")
|
||||||
|
@pytest.mark.usefixtures("redis_cache")
|
||||||
|
@pytest.mark.parametrize("cache_type", ["sqlite"])
|
||||||
|
def test_key_get_and_set(sqlite_cache, redis_cache, cache_type):
|
||||||
|
"""Test cache key get and set."""
|
||||||
|
if cache_type == "sqlite":
|
||||||
|
cache = SQLiteCache(sqlite_cache)
|
||||||
|
else:
|
||||||
|
cache = RedisCache(redis_cache)
|
||||||
|
|
||||||
|
cache.set_key("test", "valueA")
|
||||||
|
cache.set_key("testA", "valueB")
|
||||||
|
|
||||||
|
assert cache.get_key("test") == "valueA"
|
||||||
|
assert cache.get_key("testA") == "valueB"
|
||||||
|
|
||||||
|
cache.set_key("testA", "valueC")
|
||||||
|
assert cache.get_key("testA") == "valueC"
|
||||||
|
|
||||||
|
cache.get_key("test", table="prompt") is None
|
||||||
|
cache.set_key("test", "valueA", table="prompt")
|
||||||
|
cache.get_key("test", table="prompt") == "valueA"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("sqlite_cache")
|
||||||
|
@pytest.mark.usefixtures("redis_cache")
|
||||||
|
@pytest.mark.parametrize("cache_type", ["sqlite"])
|
||||||
|
def test_get(sqlite_cache, redis_cache, cache_type):
|
||||||
|
"""Test cache save prompt."""
|
||||||
|
if cache_type == "sqlite":
|
||||||
|
cache = SQLiteCache(sqlite_cache)
|
||||||
|
else:
|
||||||
|
cache = RedisCache(redis_cache)
|
||||||
|
test_request = {"test": "hello", "testA": "world"}
|
||||||
|
compute = lambda: Response({"choices": [{"text": "hello"}]})
|
||||||
|
|
||||||
|
response, cached = cache.get(test_request, overwrite_cache=False, compute=compute)
|
||||||
|
assert response.get_results() == "hello"
|
||||||
|
assert not cached
|
||||||
|
|
||||||
|
response, cached = cache.get(test_request, overwrite_cache=False, compute=compute)
|
||||||
|
assert response.get_results() == "hello"
|
||||||
|
assert cached
|
||||||
|
|
||||||
|
response, cached = cache.get(test_request, overwrite_cache=True, compute=compute)
|
||||||
|
assert response.get_results() == "hello"
|
||||||
|
assert not cached
|
105
tests/test_manifest.py
Normal file
105
tests/test_manifest.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
"""Manifest test."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from manifest import Manifest, Prompt
|
||||||
|
from manifest.caches.cache import request_to_key
|
||||||
|
from manifest.caches.sqlite import SQLiteCache
|
||||||
|
from manifest.clients.dummy import DummyClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("sqlite_cache")
|
||||||
|
def test_init(sqlite_cache):
|
||||||
|
"""Test manifest initialization."""
|
||||||
|
manifest = Manifest(
|
||||||
|
client_name="dummy",
|
||||||
|
cache_name="sqlite",
|
||||||
|
cache_connection=sqlite_cache,
|
||||||
|
)
|
||||||
|
assert manifest.client_name == "dummy"
|
||||||
|
assert isinstance(manifest.client, DummyClient)
|
||||||
|
assert isinstance(manifest.cache, SQLiteCache)
|
||||||
|
|
||||||
|
manifest = Manifest(
|
||||||
|
client_name="dummy",
|
||||||
|
cache_name="sqlite",
|
||||||
|
cache_connection=sqlite_cache,
|
||||||
|
num_results=3,
|
||||||
|
)
|
||||||
|
assert manifest.client_name == "dummy"
|
||||||
|
assert isinstance(manifest.client, DummyClient)
|
||||||
|
assert isinstance(manifest.cache, SQLiteCache)
|
||||||
|
assert manifest.client.num_results == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("sqlite_cache")
|
||||||
|
@pytest.mark.parametrize("num_results", [1, 2])
|
||||||
|
def test_run(sqlite_cache, num_results):
|
||||||
|
"""Test manifest run."""
|
||||||
|
manifest = Manifest(
|
||||||
|
client_name="dummy",
|
||||||
|
cache_name="sqlite",
|
||||||
|
cache_connection=sqlite_cache,
|
||||||
|
num_results=num_results,
|
||||||
|
)
|
||||||
|
prompt = Prompt("This is a prompt")
|
||||||
|
res = manifest.run(prompt)
|
||||||
|
assert (
|
||||||
|
manifest.cache.get_key(
|
||||||
|
request_to_key(
|
||||||
|
{
|
||||||
|
"prompt": "This is a prompt",
|
||||||
|
"client_name": "dummy",
|
||||||
|
"num_results": num_results,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
if num_results == 1:
|
||||||
|
assert res == "hello"
|
||||||
|
else:
|
||||||
|
assert res == ["hello", "hello"]
|
||||||
|
|
||||||
|
prompt = Prompt(lambda x: f"{x} is a prompt")
|
||||||
|
res = manifest.run(prompt, "Hello")
|
||||||
|
assert (
|
||||||
|
manifest.cache.get_key(
|
||||||
|
request_to_key(
|
||||||
|
{
|
||||||
|
"prompt": "Hello is a prompt",
|
||||||
|
"client_name": "dummy",
|
||||||
|
"num_results": num_results,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
if num_results == 1:
|
||||||
|
assert res == "hello"
|
||||||
|
else:
|
||||||
|
assert res == ["hello", "hello"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("sqlite_cache")
|
||||||
|
@pytest.mark.parametrize("num_results", [1, 2])
|
||||||
|
def test_batch_run(sqlite_cache, num_results):
|
||||||
|
"""Test manifest run."""
|
||||||
|
manifest = Manifest(
|
||||||
|
client_name="dummy",
|
||||||
|
cache_name="sqlite",
|
||||||
|
cache_connection=sqlite_cache,
|
||||||
|
num_results=num_results,
|
||||||
|
)
|
||||||
|
prompt = Prompt("This is a prompt")
|
||||||
|
res = manifest.run_batch(prompt)
|
||||||
|
if num_results == 1:
|
||||||
|
assert res == ["hello"]
|
||||||
|
else:
|
||||||
|
assert res == [["hello", "hello"]]
|
||||||
|
|
||||||
|
prompt = Prompt(lambda x: f"{x} is a prompt")
|
||||||
|
res = manifest.run_batch(prompt, ["Hello", "Hello"])
|
||||||
|
if num_results == 1:
|
||||||
|
assert res == ["hello", "hello"]
|
||||||
|
else:
|
||||||
|
assert res == [["hello", "hello"], ["hello", "hello"]]
|
54
tests/test_prompt.py
Normal file
54
tests/test_prompt.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
"""Prompt test."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from manifest import Prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_init():
|
||||||
|
"""Test prompt initialization."""
|
||||||
|
str_prompt = "This is a test prompt"
|
||||||
|
func_prompt = lambda: "This is a test prompt"
|
||||||
|
func_single_prompt = lambda x: f"{x} is a test prompt"
|
||||||
|
func_list_prompt = lambda x: f"{x[0]} is a test {x[1]}"
|
||||||
|
func_double_prompt = lambda x, y: f"{x} is a test {y}"
|
||||||
|
# TODO: add list of prompt tests
|
||||||
|
|
||||||
|
# String prompt
|
||||||
|
prompt = Prompt(str_prompt)
|
||||||
|
assert prompt(None) == str_prompt
|
||||||
|
assert prompt() == str_prompt
|
||||||
|
|
||||||
|
# Function no inputs
|
||||||
|
prompt = Prompt(func_prompt)
|
||||||
|
assert prompt(None) == str_prompt
|
||||||
|
assert prompt() == str_prompt
|
||||||
|
|
||||||
|
# Function single inputs
|
||||||
|
prompt = Prompt(func_single_prompt)
|
||||||
|
assert prompt("This") == str_prompt
|
||||||
|
assert prompt("Hello") == "Hello is a test prompt"
|
||||||
|
|
||||||
|
# Function list inputs
|
||||||
|
prompt = Prompt(func_list_prompt)
|
||||||
|
assert prompt(["This", "prompt"]) == str_prompt
|
||||||
|
assert prompt(["Hello", "prompt"]) == "Hello is a test prompt"
|
||||||
|
|
||||||
|
# Function two inputs
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
Prompt(func_double_prompt)
|
||||||
|
assert str(exc_info.value) == "Prompts must have zero or one input."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Not implemented")
|
||||||
|
def test_serialize():
|
||||||
|
"""Test prompt serialization."""
|
||||||
|
str_prompt = "This is a test prompt"
|
||||||
|
func_single_prompt = lambda x: f"{x} is a test prompt"
|
||||||
|
|
||||||
|
# String prompt
|
||||||
|
prompt = Prompt(str_prompt)
|
||||||
|
assert Prompt.deserialize(prompt.serialize()) == prompt
|
||||||
|
|
||||||
|
# Function single inputs
|
||||||
|
prompt = Prompt(func_single_prompt)
|
||||||
|
assert Prompt.deserialize(prompt.serialize()) == prompt
|
Loading…
Reference in New Issue
Block a user