mirror of
https://github.com/HazyResearch/manifest
synced 2024-10-31 15:20:26 +00:00
First main commit
This commit is contained in:
parent
6ada1f2d2c
commit
ed301be2a7
10
.flake8
Normal file
10
.flake8
Normal file
@ -0,0 +1,10 @@
|
||||
# This is our code-style check. We currently allow the following exceptions:
|
||||
# - E731: do not assign a lambda expression, use a def
|
||||
# - E402: module level import not at top of file
|
||||
# - W503: line break before binary operator
|
||||
|
||||
[flake8]
|
||||
exclude = .git
|
||||
max-line-length = 88
|
||||
ignore = E731, E402, W503
|
||||
per-file-ignores = __init__.py:F401
|
28
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
28
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@ -0,0 +1,28 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
|
||||
---
|
||||
|
||||
## Description of the bug
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
## To Reproduce
|
||||
Steps to reproduce the behavior:
|
||||
1. Go to '...'
|
||||
2. Click on '....'
|
||||
3. Execute the code '...'
|
||||
If necessary, attach example data which can be used to replicate the issue.
|
||||
|
||||
## Expected behavior
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
## Error Logs/Screenshots
|
||||
If applicable, add error logs or screenshots to help explain your problem.
|
||||
|
||||
## Environment (please complete the following information)
|
||||
- OS: [e.g. Ubuntu 18.04]
|
||||
- manifest Version: [e.g. 0.0.1]
|
||||
|
||||
## Additional context
|
||||
Add any other context about the problem here.
|
19
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
19
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
|
||||
---
|
||||
|
||||
## Description of the feature request
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
## Description of the solution you'd like
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
## Description of the alternatives you've considered
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
## Additional context
|
||||
Add any other context or screenshots about the feature request here.
|
65
.github/workflows/ci.yaml
vendored
Normal file
65
.github/workflows/ci.yaml
vendored
Normal file
@ -0,0 +1,65 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
# Allows you to run this workflow manually from the Actions tab
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: ["3.8", "3.9", "3.10"]
|
||||
services:
|
||||
# Label used to access the service container
|
||||
redis:
|
||||
# Docker Hub image
|
||||
image: redislabs/redis
|
||||
# Set health checks to wait until redis has started
|
||||
options: >-
|
||||
--health-cmd "redis-cli ping"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
# Maps port 6379 on service container to the host
|
||||
- 6379:6379
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install Poetry
|
||||
uses: snok/install-poetry@v1
|
||||
with:
|
||||
virtualenvs-in-project: true
|
||||
virtualenvs-create: true
|
||||
installer-parallel: true
|
||||
- name: Load cached venv if cache exists
|
||||
id: cached-poetry-deps
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: .venv
|
||||
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
|
||||
- name: Install dependencies
|
||||
if: steps.cached-poetry-deps.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
poetry install --no-interaction --no-root
|
||||
- name: Install Manifest
|
||||
run: |
|
||||
make dev
|
||||
- name: Run preliminary checks
|
||||
run: |
|
||||
make check
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
poetry run pytest tests
|
120
.gitignore
vendored
Normal file
120
.gitignore
vendored
Normal file
@ -0,0 +1,120 @@
|
||||
runs/*
|
||||
|
||||
*._*
|
||||
# Pickle Saved
|
||||
*.pt
|
||||
*.pk
|
||||
**/*.pt
|
||||
**/*.pk
|
||||
|
||||
# PyCharm
|
||||
*.idea
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
|
||||
*.tsv
|
||||
*.7z
|
||||
.DS_Store
|
23
.pre-commit-config.yaml
Normal file
23
.pre-commit-config.yaml
Normal file
@ -0,0 +1,23 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v3.2.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-toml
|
||||
- id: check-merge-conflict
|
||||
- id: check-added-large-files
|
||||
- repo: https://github.com/timothycrosley/isort
|
||||
rev: 5.9.3
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 3.9.2
|
||||
hooks:
|
||||
- id: flake8
|
28
Makefile
Normal file
28
Makefile
Normal file
@ -0,0 +1,28 @@
|
||||
dev:
|
||||
poetry install
|
||||
poetry run pre-commit install
|
||||
|
||||
test: dev check
|
||||
poetry install
|
||||
poetry run pytest tests
|
||||
|
||||
format:
|
||||
isort --atomic manifest/ tests/
|
||||
black manifest/ tests/
|
||||
|
||||
check:
|
||||
isort -c manifest/ tests/
|
||||
black manifest/ tests/ --check
|
||||
flake8 manifest/ tests/
|
||||
mypy manifest/
|
||||
|
||||
clean:
|
||||
pip uninstall -y manifest
|
||||
rm -rf src/manifest.egg-info
|
||||
rm -rf build/ dist/
|
||||
|
||||
# Delete local branches whose upstream branch no longer exists.
# `$$` escapes `$` from make so that the shell (not make) performs the command
# substitution and the variable expansions. The commands run directly in the
# recipe shell: wrapping them in `bash -c "..."` would let the outer shell
# expand `$1` and `$branch` inside the double quotes before bash ever ran.
prune:
	@git fetch -p
	@for branch in $$(git branch -vv | grep -v '^[*+]' | grep ': gone]' | awk '{print $$1}'); do git branch -d $$branch; done

.PHONY: dev test format clean check prune
|
19
README.md
19
README.md
@ -1,2 +1,21 @@
|
||||
# manifest
|
||||
Prompt programming with foundation models (FMs).
|
||||
|
||||
# Install
|
||||
Download the code:
|
||||
```
|
||||
git clone git@github.com:HazyResearch/manifest.git
|
||||
cd manifest
|
||||
```
|
||||
|
||||
Install:
|
||||
```
|
||||
pip install poetry
|
||||
poetry install
|
||||
poetry run pre-commit install
|
||||
```
|
||||
or
|
||||
```
|
||||
pip install poetry
|
||||
make dev
|
||||
```
|
3
manifest/__init__.py
Normal file
3
manifest/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""Manifest init."""
|
||||
from manifest.manifest import Manifest
|
||||
from manifest.prompt import Prompt
|
1
manifest/api/app.py
Normal file
1
manifest/api/app.py
Normal file
@ -0,0 +1 @@
|
||||
"""Flask app."""
|
1
manifest/api/models/huggingface_model.py
Normal file
1
manifest/api/models/huggingface_model.py
Normal file
@ -0,0 +1 @@
|
||||
"""Huggingface model."""
|
1
manifest/api/models/model.py
Normal file
1
manifest/api/models/model.py
Normal file
@ -0,0 +1 @@
|
||||
"""Model class."""
|
2
manifest/caches/__init__.py
Normal file
2
manifest/caches/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Cache init."""
|
||||
from manifest.caches.cache import Cache
|
113
manifest/caches/cache.py
Normal file
113
manifest/caches/cache.py
Normal file
@ -0,0 +1,113 @@
|
||||
"""Cache for queries and responses."""
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
|
||||
from manifest.clients.response import Response
|
||||
|
||||
|
||||
def request_to_key(request: Dict) -> str:
    """
    Build the cache key for a request.

    The request is dumped to JSON with sorted keys, so two dicts holding the
    same items map to the same key regardless of insertion order.

    Args:
        request: request to normalize.

    Returns:
        normalized key.
    """
    normalized_key = json.dumps(request, sort_keys=True)
    return normalized_key
|
||||
|
||||
|
||||
def key_to_request(key: str) -> Dict:
    """
    Recover the request from its normalized key.

    Args:
        key: normalized key to convert (a JSON string).

    Returns:
        unnormalized request dict.
    """
    request = json.loads(key)
    return request
|
||||
|
||||
|
||||
class Cache(ABC):
    """Abstract key/value cache for request/response pairs."""

    def __init__(self, connection_str: str, **kwargs: Any):
        """
        Initialize the cache by connecting to its backend.

        The connection string and all kwargs are passed on to ``connect``.

        Args:
            connection_str: connection string for the cache backend.
        """
        self.connect(connection_str, **kwargs)

    @abstractmethod
    def close(self) -> None:
        """Close the cache."""
        raise NotImplementedError()

    @abstractmethod
    def connect(self, connection_str: str, **kwargs: Any) -> None:
        """
        Connect to the cache backend.

        Args:
            connection_str: connection string.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_key(self, key: str, table: str = "default") -> Union[str, None]:
        """
        Look up the value stored under a key.

        Returns None if the key is not in the cache.

        Args:
            key: key for cache.
            table: table to get key in.
        """
        raise NotImplementedError()

    @abstractmethod
    def set_key(self, key: str, value: str, table: str = "default") -> None:
        """
        Store a value under a key, replacing any previous value.

        Args:
            key: key for cache.
            value: new value for key.
            table: table to set key in.
        """
        raise NotImplementedError()

    @abstractmethod
    def commit(self) -> None:
        """Commit any results."""
        raise NotImplementedError()

    def get(
        self, request: Dict, overwrite_cache: bool, compute: Callable[[], Response]
    ) -> Tuple[Response, bool]:
        """
        Return the response for a request, computing it when needed.

        Args:
            request: request parameters; normalized into the cache key.
            overwrite_cache: when True, ignore any stored response and recompute.
            compute: zero-argument callable producing the response.

        Returns:
            the response and whether it was served from the cache.
        """
        cache_key = request_to_key(request)
        stored = self.get_key(cache_key)
        # Serve the stored response unless the caller asked for a refresh.
        if stored and not overwrite_cache:
            return Response.deserialize(stored), True
        fresh = compute()
        self.set_key(cache_key, fresh.serialize())
        return fresh, False
|
52
manifest/caches/redis.py
Normal file
52
manifest/caches/redis.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""Redis cache."""
|
||||
from typing import Any, Union
|
||||
|
||||
import redis
|
||||
|
||||
from manifest.caches import Cache
|
||||
|
||||
|
||||
class RedisCache(Cache):
    """A Redis cache for request/response pairs."""

    def connect(self, connection_str: str, **kwargs: Any) -> None:
        """
        Connect to the Redis server.

        Args:
            connection_str: server address formatted as ``host:port``.
        """
        # Split on the last colon only so hosts that contain colons still parse.
        host, port = connection_str.rsplit(":", 1)
        self.redis = redis.Redis(host=host, port=int(port))

    def close(self) -> None:
        """Close the client."""
        self.redis.close()

    @staticmethod
    def _table_key(key: str, table: str) -> str:
        """Return the Redis key for ``key`` in ``table``.

        Keys of the ``default`` table are stored unprefixed; every other table
        is namespaced as ``<table>:<key>`` so tables cannot overwrite each other.
        """
        return key if table == "default" else f"{table}:{key}"

    def get_key(self, key: str, table: str = "default") -> Union[str, None]:
        """
        Get the value stored for a key.

        Returns None if the key is not in the cache.

        Args:
            key: key for cache.
            table: table to get key in.
        """
        stored = self.redis.get(self._table_key(key, table))
        if stored is None:
            return None
        # redis-py returns bytes unless the client was built with
        # decode_responses=True; callers expect str.
        if isinstance(stored, bytes):
            return stored.decode("utf-8")
        return stored

    def set_key(self, key: str, value: str, table: str = "default") -> None:
        """
        Set the value for the key.

        Will override old value.

        Args:
            key: key for cache.
            value: new value for key.
            table: table to set key in.
        """
        self.redis.set(self._table_key(key, table), value)

    def commit(self) -> None:
        """Commit any results (nothing to do: writes are sent immediately)."""
        pass
|
79
manifest/caches/sqlite.py
Normal file
79
manifest/caches/sqlite.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""SQLite cache."""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
from sqlitedict import SqliteDict
|
||||
|
||||
from manifest.caches import Cache
|
||||
|
||||
logging.getLogger("sqlitedict").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class SQLiteCache(Cache):
    """A SQLite cache for request/response pairs."""

    def connect(self, connection_str: str, **kwargs: Any) -> None:
        """
        Open the query and prompt stores inside the cache directory.

        The directory is created if it does not exist.

        Args:
            connection_str: path of the cache directory.
        """
        self.cache_dir = connection_str
        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
        # If more than two tables, switch to full on SQL connection
        self.query_file = Path(self.cache_dir, "query.sqlite")
        self.prompt_file = Path(self.cache_dir, "prompts.sqlite")
        self.cache = SqliteDict(self.query_file, autocommit=False)
        self.prompt_cache = SqliteDict(self.prompt_file, autocommit=False)

    def close(self) -> None:
        """Close both underlying stores."""
        try:
            self.cache.close()
        finally:
            # Close the prompt store even if closing the query store failed.
            self.prompt_cache.close()

    def _get_table(self, table: str) -> Any:
        """Return the store backing ``table``.

        Raises:
            ValueError: if table is neither ``default`` nor ``prompt``.
        """
        if table == "prompt":
            return self.prompt_cache
        if table == "default":
            return self.cache
        raise ValueError("SQLiteDict only support table of `default` or `prompt`")

    def get_key(self, key: str, table: str = "default") -> Union[str, None]:
        """
        Get the value stored for a key.

        Returns None if the key is not in the cache.

        Args:
            key: key for cache.
            table: table to get key in.

        Raises:
            ValueError: if table is neither ``default`` nor ``prompt``.
        """
        return self._get_table(table).get(key)

    def set_key(self, key: str, value: str, table: str = "default") -> None:
        """
        Set the value for the key and commit it.

        Will override old value.

        Args:
            key: key for cache.
            value: new value for key.
            table: table to set key in.

        Raises:
            ValueError: if table is neither ``default`` nor ``prompt``.
        """
        self._get_table(table)[key] = value
        # The stores are opened with autocommit=False, so persist explicitly.
        self.commit()

    def commit(self) -> None:
        """Commit any results."""
        self.prompt_cache.commit()
        self.cache.commit()
|
3
manifest/clients/__init__.py
Normal file
3
manifest/clients/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""Client init."""
|
||||
from manifest.clients.client import Client
|
||||
from manifest.clients.response import Response
|
58
manifest/clients/client.py
Normal file
58
manifest/clients/client.py
Normal file
@ -0,0 +1,58 @@
|
||||
"""Client class."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
# Import from the submodule rather than the package: the package __init__
# imports this module before it binds Response, so `from manifest.clients
# import Response` fails on the partially initialized package.
from manifest.clients.response import Response


class Client(ABC):
    """Abstract base class for model clients."""

    def __init__(self, connection_str: Optional[str] = None, **kwargs: Any):
        """
        Initialize client.

        kwargs are passed to ``connect`` as default parameters.

        For clients like OpenAI that do not require a connection,
        the connection_str can be None.

        Args:
            connection_str: connection string for client.
        """
        self.connect(connection_str, **kwargs)

    @abstractmethod
    def close(self) -> None:
        """Close the client."""
        raise NotImplementedError()

    @abstractmethod
    def connect(self, connection_str: Optional[str], **kwargs: Any) -> None:
        """
        Connect to client.

        Args:
            connection_str: connection string; may be None, as ``__init__``
                forwards its own optional argument unchanged.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_request(
        self, query: str, **kwargs: Any
    ) -> Tuple[Callable[[], Response], Dict]:
        """
        Get request function.

        kwargs override default parameters.

        Calling the returned function will run the request.

        Args:
            query: query string.

        Returns:
            request function that takes no input.
            request parameters as dict.
        """
        raise NotImplementedError()
|
52
manifest/clients/dummy.py
Normal file
52
manifest/clients/dummy.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""Dummy client."""
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
from manifest.clients import Client
|
||||
from manifest.clients.response import Response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DummyClient(Client):
    """Dummy client."""

    def connect(
        self,
        connection_str: Optional[str] = None,
        num_results: Optional[int] = 1,
        **kwargs: Any,
    ) -> None:
        """
        Connect to dummy server.

        This is a dummy client that returns fixed responses. Used for testing.

        Args:
            connection_str: unused.
            num_results: default number of choices per response.
        """
        self.num_results = num_results

    def close(self) -> None:
        """Close the client."""
        pass

    def get_request(
        self, query: str, **kwargs: Any
    ) -> Tuple[Callable[[], Response], Dict]:
        """
        Get request string function.

        Args:
            query: query string.

        Returns:
            request function that takes no input.
            request parameters as dict.
        """
        request_params = {
            "prompt": query,
            "num_results": kwargs.get("num_results", self.num_results),
        }

        def _run_completion() -> Response:
            # Honor a per-call num_results override instead of always using the
            # connect-time default, so the response matches request_params.
            num_results = request_params["num_results"]
            return Response(
                {"choices": [{"text": "hello"} for _ in range(num_results)]}
            )

        return _run_completion, request_params
|
95
manifest/clients/openai.py
Normal file
95
manifest/clients/openai.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""OpenAI client."""
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
import openai
|
||||
|
||||
from manifest.clients import Response
|
||||
from manifest.clients.client import Client
|
||||
|
||||
logging.getLogger("openai").setLevel(logging.WARNING)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENAI_ENGINES = {
|
||||
"text-davinci-002",
|
||||
"text-curie-001",
|
||||
"text-babbage-001",
|
||||
"text-ada-001",
|
||||
}
|
||||
|
||||
|
||||
class OpenAIClient(Client):
    """OpenAI client."""

    def connect(
        self,
        connection_str: Optional[str] = None,
        engine: Optional[str] = "text-ada-001",
        temperature: Optional[float] = 0.0,
        max_tokens: Optional[int] = 10,
        top_p: Optional[float] = 1,
        frequency_penalty: Optional[float] = 0,
        presence_penalty: Optional[float] = 0,
        n: Optional[int] = 1,
        **kwargs: Any,
    ) -> None:
        """
        Connect to the OpenAI server.

        The API key is read from the OPENAI_API_KEY environment variable;
        connection_str is used as the key only when that variable is not set.
        The remaining arguments are stored as defaults for ``get_request``.

        Args:
            connection_str: fallback API key.
            engine: default engine; must be one of OPENAI_ENGINES.
            temperature: default ``temperature`` request parameter.
            max_tokens: default ``max_tokens`` request parameter.
            top_p: default ``top_p`` request parameter.
            frequency_penalty: default ``frequency_penalty`` request parameter.
            presence_penalty: default ``presence_penalty`` request parameter.
            n: default ``n`` request parameter.

        Raises:
            ValueError: if no API key is available or the engine is unknown.
        """
        openai.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
        if openai.api_key is None:
            # Adjacent literals form ONE message; a comma here would hand
            # ValueError two separate arguments.
            raise ValueError(
                "OpenAI API key not set. Set OPENAI_API_KEY environment "
                "variable or pass through `connection_str`."
            )
        self.engine = engine
        if self.engine not in OPENAI_ENGINES:
            raise ValueError(f"Invalid engine {self.engine}. Must be {OPENAI_ENGINES}.")
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self.n = n

    def close(self) -> None:
        """Close the client."""
        pass

    def get_request(
        self, query: str, **kwargs: Any
    ) -> Tuple[Callable[[], Response], Dict]:
        """
        Get request string function.

        Each request parameter is taken from kwargs when present and from the
        defaults stored by ``connect`` otherwise.

        Args:
            query: query string.

        Returns:
            request function that takes no input.
            request parameters as dict.
        """
        request_params = {
            "engine": kwargs.get("engine", self.engine),
            "prompt": query,
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "top_p": kwargs.get("top_p", self.top_p),
            "frequency_penalty": kwargs.get(
                "frequency_penalty", self.frequency_penalty
            ),
            "presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
            "n": kwargs.get("n", self.n),
        }

        def _run_completion() -> Response:
            try:
                return Response(openai.Completion.create(**request_params))
            except openai.error.OpenAIError as e:
                logger.error(e)
                # Bare raise re-raises the active exception unchanged.
                raise

        return _run_completion, request_params
|
70
manifest/clients/response.py
Normal file
70
manifest/clients/response.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""Client response."""
|
||||
import json
|
||||
from typing import Dict, List, Union
|
||||
|
||||
|
||||
class Response:
    """Validated wrapper around a client response payload."""

    def __init__(self, response: Union[str, Dict]):
        """
        Initialize response.

        Args:
            response: payload as a dict, or as a JSON string decoding to a dict.

        Raises:
            ValueError: if the payload is not a str or dict, does not decode to
                a dict holding a list under ``choices``, or if any choice lacks
                a ``text`` field.
        """
        if isinstance(response, str):
            self.response = json.loads(response)
        elif isinstance(response, dict):
            self.response = response
        else:
            raise ValueError("Response must be str or dict")
        # A JSON string may decode to a non-dict (list, number, ...).
        if (
            not isinstance(self.response, dict)
            or "choices" not in self.response
            or not isinstance(self.response["choices"], list)
        ):
            raise ValueError(
                "Response must be serialized to a dict with a list of choices"
            )
        # get_results reads "text" from every choice, so validate all of them.
        if any(
            not isinstance(choice, dict) or "text" not in choice
            for choice in self.response["choices"]
        ):
            # Adjacent literals form ONE message; a comma here would hand
            # ValueError two separate arguments.
            raise ValueError(
                "Response must be serialized to a dict with a "
                "list of choices with text field"
            )

    def __getitem__(self, key: str) -> str:
        """
        Return the response given the key.

        Args:
            key: key to get.

        Returns:
            value of key.
        """
        return self.response[key]

    def get_results(self) -> Union[str, List[str], None]:
        """
        Get all text results from response.

        Returns:
            None when there are no choices, the single text when there is
            exactly one choice, otherwise the list of all texts.
        """
        texts = [choice["text"] for choice in self.response["choices"]]
        if not texts:
            return None
        if len(texts) == 1:
            return texts[0]
        return texts

    def serialize(self) -> str:
        """
        Serialize response to string.

        Returns:
            serialized response.
        """
        return json.dumps(self.response, sort_keys=True)

    @classmethod
    def deserialize(cls, value: str) -> "Response":
        """
        Deserialize string to response.

        Args:
            value: serialized response.

        Returns:
            deserialized response.
        """
        # Build through cls so subclasses deserialize to their own type.
        return cls(value)
|
138
manifest/manifest.py
Normal file
138
manifest/manifest.py
Normal file
@ -0,0 +1,138 @@
|
||||
"""Manifest class."""
|
||||
import logging
|
||||
from typing import Any, Iterable, List, Optional, Union
|
||||
|
||||
# Raise the openai logger's threshold to WARNING; this runs before the client
# modules below are imported.
logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.dummy import DummyClient
from manifest.clients.openai import OpenAIClient

# Import Prompt from its submodule: the package __init__ imports this module
# before it binds Prompt, so `from manifest import Prompt` fails on the
# partially initialized package.
from manifest.prompt import Prompt

# Client name -> client class, looked up by Manifest.__init__.
CLIENT_CONSTRUCTORS = {
    "openai": OpenAIClient,
    # "huggingface": manifest.clients.huggingface.HuggingFaceClient,
    "dummy": DummyClient,
}

# Cache name -> cache class, looked up by Manifest.__init__.
CACHE_CONSTRUCTORS = {
    "redis": RedisCache,
    "sqlite": SQLiteCache,
}
|
||||
|
||||
|
||||
class Manifest:
    """Session object pairing a model client with a response cache."""

    def __init__(
        self,
        client_name: str = "openai",
        client_connection: Optional[str] = None,
        cache_name: str = "redis",
        cache_connection: str = "localhost:6379",
        **kwargs: Any,
    ):
        """
        Build the client and the cache selected by name.

        Both names are validated before anything is constructed. The remaining
        kwargs are sent to the client constructor and to the cache constructor.

        Args:
            client_name: key into CLIENT_CONSTRUCTORS.
            client_connection: connection string handed to the client.
            cache_name: key into CACHE_CONSTRUCTORS.
            cache_connection: connection string handed to the cache.

        Raises:
            ValueError: if client_name or cache_name is not a known choice.
        """
        client_cls = CLIENT_CONSTRUCTORS.get(client_name)
        if client_cls is None:
            raise ValueError(
                f"Unknown client name: {client_name}. "
                f"Choices are {list(CLIENT_CONSTRUCTORS.keys())}"
            )
        cache_cls = CACHE_CONSTRUCTORS.get(cache_name)
        if cache_cls is None:
            raise ValueError(
                f"Unknown cache name: {cache_name}. "
                f"Choices are {list(CACHE_CONSTRUCTORS.keys())}"
            )
        self.client_name = client_name
        self.client = client_cls(client_connection, **kwargs)
        self.cache = cache_cls(cache_connection, **kwargs)

    def close(self) -> None:
        """Close the client, then the cache."""
        self.client.close()
        self.cache.close()

    def run(
        self,
        prompt: Prompt,
        input: Optional[Any] = None,
        overwrite_cache: bool = False,
        **kwargs: Any,
    ) -> Union[str, List[str]]:
        """
        Run the prompt.

        Args:
            prompt: prompt to run.
            input: input to prompt.
            overwrite_cache: whether to overwrite cache.

        Returns:
            response from prompt.
        """
        prompt_str = prompt(input)
        request_fn, request_params = self.client.get_request(prompt_str, **kwargs)
        # The cache key is the full request plus the client name and the
        # rendered prompt, so it depends on both the model and the prompt.
        cache_key = {
            **request_params,
            "client_name": self.client_name,
            "prompt": prompt_str,
        }
        response, _ = self.cache.get(cache_key, overwrite_cache, request_fn)
        return response.get_results()

    def run_batch(
        self,
        prompt: Prompt,
        input: Optional[Iterable[Any]] = None,
        overwrite_cache: bool = False,
        **kwargs: Any,
    ) -> Iterable[Union[str, List[str]]]:
        """
        Run the prompt on a batch of inputs.

        Args:
            prompt: prompt to run.
            input: batch of inputs; None is treated as a single None input.
            overwrite_cache: whether to overwrite cache.

        Returns:
            batch of responses.
        """
        batch = [None] if input is None else input
        results = []
        for item in batch:
            results.append(self.run(prompt, item, overwrite_cache, **kwargs))
        return results

    def save_prompt(self, name: str, prompt: Prompt) -> None:
        """
        Save the prompt to the cache for long term storage.

        Args:
            name: name of prompt.
            prompt: prompt to save.
        """
        serialized = prompt.serialize()
        self.cache.set_key(name, serialized, table="prompt")

    def load_prompt(self, name: str) -> Prompt:
        """
        Load the prompt from the cache.

        Args:
            name: name of prompt.

        Returns:
            Prompt saved with name.
        """
        serialized = self.cache.get_key(name, table="prompt")
        return Prompt.deserialize(serialized)

    def open_explorer(self) -> None:
        """Open the explorer for jupyter widget."""
        # TODO: implement
        return None
|
72
manifest/prompt.py
Normal file
72
manifest/prompt.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""Prompt class."""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
logging.getLogger("openai").setLevel(logging.WARNING)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Prompt:
    """Callable wrapper that renders a prompt string."""

    def __init__(self, prompt_obj: Union[str, Callable, "Prompt", List["Prompt"]]):
        """
        Initialize prompt.

        A string is wrapped in a zero-argument function returning it; a
        callable is stored as is. Anything else (such as a list of prompts to
        compose) is not supported yet.

        Raises:
            NotImplementedError: if prompt_obj is neither a str nor callable.
            ValueError: if the prompt function takes more than one parameter.
        """
        # TODO: figure out how to compose prompts to keep the
        # interface simple? Can we make a function
        # such that a single call will run the composition?
        is_text = isinstance(prompt_obj, str)
        if not (is_text or callable(prompt_obj)):
            # TODO: implement
            raise NotImplementedError()
        self.prompt_func = (  # type: ignore
            (lambda: prompt_obj) if is_text else prompt_obj
        )
        signature = inspect.signature(self.prompt_func)
        self.num_args = len(signature.parameters)
        if self.num_args > 1:
            raise ValueError("Prompts must have zero or one input.")

    def __call__(self, input: Optional[Any] = None) -> str:
        """
        Return the prompt given the inputs.

        The input is passed on only when the prompt function takes a parameter.

        Args:
            input: input to prompt.

        Returns:
            prompt string.
        """
        call_args = (input,) if self.num_args >= 1 else ()
        return self.prompt_func(*call_args)  # type: ignore

    def serialize(self) -> str:
        """
        Return the prompt as str.

        Not implemented yet: currently returns None.

        Returns:
            json object.
        """
        # TODO: implement
        return None

    @classmethod
    def deserialize(cls, obj: str) -> "Prompt":
        """
        Return the prompt from a json object.

        Not implemented yet: currently returns None.

        Args:
            obj: json object.

        Return:
            prompt.
        """
        # TODO: implement
        return None
|
1418
poetry.lock
generated
Normal file
1418
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
69
pyproject.toml
Normal file
69
pyproject.toml
Normal file
@ -0,0 +1,69 @@
|
||||
[tool.poetry]
|
||||
# One "name <email>" entry per author, as Poetry expects.
authors = [
    "Laurel Orr <lorr1@cs.stanford.edu>",
    "Avanika Narayan <avanikan@stanford.edu>",
]
|
||||
classifiers = [
|
||||
"Programming Language :: Python",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10"
|
||||
]
|
||||
description = "Manifest for Prompt Programming"
|
||||
name = "manifest"
|
||||
repository = "https://github.com/HazyResearch/manifest"
|
||||
version = "0.0.1"
|
||||
|
||||
[tool.poetry.urls]
|
||||
"Bug Tracker" = "https://github.com/HazyResearch/manifest/issues"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.8"
|
||||
sqlitedict = "^2.0.0"
|
||||
openai = "^0.18.1"
|
||||
redis = "^4.3.1"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
black = "^22.3.0"
|
||||
flake8 = "^4.0.0"
|
||||
flake8-docstrings = "^1.6.0"
|
||||
isort = "^5.9.3"
|
||||
mypy = "^0.950"
|
||||
pep8-naming = "^0.12.1"
|
||||
pre-commit = "^2.14.0"
|
||||
pytest = "^7.0.0"
|
||||
pytest-cov = "^3.0.0"
|
||||
python-dotenv = "^0.20.0"
|
||||
recommonmark = "^0.7.1"
|
||||
|
||||
[build-system]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
|
||||
# Additional Tool Configurations
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = true
|
||||
strict_optional = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
ignore_missing_imports = true
|
||||
module = [
|
||||
"numpy",
|
||||
"tqdm",
|
||||
"sqlitedict",
|
||||
]
|
||||
|
||||
[tool.isort]
|
||||
combine_as_imports = true
|
||||
force_grid_wrap = 0
|
||||
include_trailing_comma = true
|
||||
known_first_party = ["manifest"]
|
||||
line_length = 88
|
||||
multi_line_output = 3
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
log_format = "[%(levelname)s] %(message)s"
|
||||
log_date_format = "%Y-%m-%d %H:%M:%S"
|
||||
addopts = "-v -rsXx"
|
||||
# The following options are useful for local debugging
|
||||
# addopts = "-v -rsXx -s -x --pdb"
|
||||
# log_cli_level = "DEBUG"
|
||||
# log_cli = true
|
1
tests/clients/test_client.py
Normal file
1
tests/clients/test_client.py
Normal file
@ -0,0 +1 @@
|
||||
"""Test client."""
|
59
tests/clients/test_response.py
Normal file
59
tests/clients/test_response.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""Response test."""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from manifest.clients import Response
|
||||
|
||||
|
||||
def test_init():
    """
    Test response initialization.

    Invalid payloads must raise ``ValueError`` with the expected message;
    dict and JSON-string payloads must both end up stored as a dict.
    """
    # Message checks sit after each ``with`` block: a statement placed after
    # the raising call inside the block would never execute.
    with pytest.raises(ValueError) as exc_info:
        Response(4)
    assert str(exc_info.value) == "Response must be str or dict"

    with pytest.raises(ValueError) as exc_info:
        Response({"test": "hello"})
    assert (
        str(exc_info.value)
        == "Response must be serialized to a dict with a list of choices"
    )

    with pytest.raises(ValueError) as exc_info:
        Response({"choices": [{"blah": "hello"}]})
    # Adjacent literals (no commas) form ONE string. With commas the right-hand
    # side is a tuple, which can never equal ``str(...)``.
    # NOTE(review): assumes Response raises with a single message string —
    # confirm against Response.__init__.
    assert str(exc_info.value) == (
        "Response must be serialized to a dict "
        "with a list of choices with text field"
    )

    response = Response({"choices": [{"text": "hello"}]})
    assert response.response == {"choices": [{"text": "hello"}]}

    response = Response(json.dumps({"choices": [{"text": "hello"}]}))
    assert response.response == {"choices": [{"text": "hello"}]}
|
||||
|
||||
|
||||
def test_getitem():
    """Test response getitem: indexing by "choices" yields the choice list."""
    wrapped = Response({"choices": [{"text": "hello"}]})
    choices = wrapped["choices"]
    assert choices == [{"text": "hello"}]
|
||||
|
||||
|
||||
def test_serialize():
    """Test that a response survives a serialize/deserialize round trip."""
    original = Response({"choices": [{"text": "hello"}]})
    restored = Response.deserialize(original.serialize())
    assert restored.response == {"choices": [{"text": "hello"}]}
|
||||
|
||||
|
||||
def test_get_results():
    """Test response get results for zero, one and several choices."""
    # No choices: nothing to return.
    no_choices = Response({"choices": []})
    assert no_choices.get_results() is None

    # A single choice is returned as a bare string.
    one_choice = Response({"choices": [{"text": "hello"}]})
    assert one_choice.get_results() == "hello"

    # Several choices are returned as a list, in order.
    texts = ["hello", "my", "name"]
    many_choices = Response({"choices": [{"text": text} for text in texts]})
    assert many_choices.get_results() == ["hello", "my", "name"]
|
36
tests/conftest.py
Normal file
36
tests/conftest.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""Setup for all tests."""
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def sqlite_cache(tmp_path):
    """
    Sqlite Cache.

    Yields the path (as ``str``) of a sqlite file located inside pytest's
    per-test ``tmp_path`` directory, then removes that directory on teardown.
    """
    cache = str(tmp_path / "sqlite_cache.sqlite")
    yield cache
    # ``cache`` names a file, and ``shutil.rmtree`` only removes directories:
    # called on a file it errors, and ``ignore_errors=True`` hides that, so the
    # file was never deleted. Remove the enclosing temp directory instead.
    shutil.rmtree(tmp_path, ignore_errors=True)
|
||||
|
||||
|
||||
@pytest.fixture
def redis_cache():
    """
    Redis cache.

    Yields a ``host:port`` connection string. The host comes from
    ``REDIS_HOST`` (default ``localhost``). Outside CI, ``REDIS_PORT`` must be
    set explicitly; when ``CI`` is in the environment it defaults to 6379.
    """
    host = os.environ.get("REDIS_HOST", "localhost")
    if "CI" in os.environ:
        port = os.environ.get("REDIS_PORT", 6379)
    else:
        # Give a clear warning on setting REDIS_PORT before running tests.
        try:
            port = os.environ["REDIS_PORT"]
        except KeyError:
            raise KeyError(
                "Set REDIS_PORT env var to the instance you want to use "
                + "for testing. Note that doing so WILL delete the db at "
                + "localhost:REDIS_PORT, db=0, so BE CAREFUL."
            )
    yield f"{host}:{port}"
    # Clear out the database (currently disabled)
    # db = redis.Redis(host=host, port=port)
    # db.flushdb()
|
69
tests/test_cache.py
Normal file
69
tests/test_cache.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""Cache test."""
|
||||
import pytest
|
||||
from sqlitedict import SqliteDict
|
||||
|
||||
from manifest.caches.redis import RedisCache
|
||||
from manifest.caches.sqlite import SQLiteCache
|
||||
from manifest.clients import Response
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.parametrize("cache_type", ["sqlite"])
def test_init(sqlite_cache, redis_cache, cache_type):
    """Test cache initialization for the parametrized cache backend."""
    if cache_type != "sqlite":
        # Redis: only check that construction succeeds.
        RedisCache(redis_cache)
        return

    cache = SQLiteCache(sqlite_cache)
    # Both tables of the sqlite cache are SqliteDict instances.
    assert isinstance(cache.cache, SqliteDict)
    assert isinstance(cache.prompt_cache, SqliteDict)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.parametrize("cache_type", ["sqlite"])
def test_key_get_and_set(sqlite_cache, redis_cache, cache_type):
    """
    Test cache key get and set.

    Covers storing, reading and overwriting keys in the default table, and
    that the "prompt" table is kept separate from the default one.
    """
    if cache_type == "sqlite":
        cache = SQLiteCache(sqlite_cache)
    else:
        cache = RedisCache(redis_cache)

    cache.set_key("test", "valueA")
    cache.set_key("testA", "valueB")

    assert cache.get_key("test") == "valueA"
    assert cache.get_key("testA") == "valueB"

    # Setting an existing key overwrites its value.
    cache.set_key("testA", "valueC")
    assert cache.get_key("testA") == "valueC"

    # These comparisons were bare expressions (no ``assert``), so the prompt
    # table was never actually checked.
    # NOTE(review): assumes get_key returns None for a missing key — confirm.
    assert cache.get_key("test", table="prompt") is None
    cache.set_key("test", "valueA", table="prompt")
    assert cache.get_key("test", table="prompt") == "valueA"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.parametrize("cache_type", ["sqlite"])
def test_get(sqlite_cache, redis_cache, cache_type):
    """
    Test cache get.

    The first lookup computes the response, the second is served from the
    cache, and ``overwrite_cache=True`` reports a fresh computation again.
    """
    cache = (
        SQLiteCache(sqlite_cache)
        if cache_type == "sqlite"
        else RedisCache(redis_cache)
    )
    test_request = {"test": "hello", "testA": "world"}

    def compute():
        return Response({"choices": [{"text": "hello"}]})

    # (overwrite_cache, whether the result is expected to come from cache)
    scenarios = ((False, False), (False, True), (True, False))
    for overwrite, expect_cached in scenarios:
        response, cached = cache.get(
            test_request, overwrite_cache=overwrite, compute=compute
        )
        assert response.get_results() == "hello"
        assert bool(cached) is expect_cached
|
105
tests/test_manifest.py
Normal file
105
tests/test_manifest.py
Normal file
@ -0,0 +1,105 @@
|
||||
"""Manifest test."""
|
||||
import pytest
|
||||
|
||||
from manifest import Manifest, Prompt
|
||||
from manifest.caches.cache import request_to_key
|
||||
from manifest.caches.sqlite import SQLiteCache
|
||||
from manifest.clients.dummy import DummyClient
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("sqlite_cache")
def test_init(sqlite_cache):
    """Test manifest initialization, with and without ``num_results``."""
    common_kwargs = {
        "client_name": "dummy",
        "cache_name": "sqlite",
        "cache_connection": sqlite_cache,
    }

    manifest = Manifest(**common_kwargs)
    assert manifest.client_name == "dummy"
    assert isinstance(manifest.client, DummyClient)
    assert isinstance(manifest.cache, SQLiteCache)

    # Extra keyword arguments such as num_results show up on the client.
    manifest = Manifest(num_results=3, **common_kwargs)
    assert manifest.client_name == "dummy"
    assert isinstance(manifest.client, DummyClient)
    assert isinstance(manifest.cache, SQLiteCache)
    assert manifest.client.num_results == 3
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.parametrize("num_results", [1, 2])
def test_run(sqlite_cache, num_results):
    """
    Test manifest run.

    Runs a constant prompt and a one-argument prompt, checking both that the
    request was written to the cache and that the returned result has the
    shape expected for ``num_results``.
    """
    manifest = Manifest(
        client_name="dummy",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
        num_results=num_results,
    )

    def cached_entry(prompt_text):
        """Look up the cache entry for a request with the given prompt text."""
        request = {
            "prompt": prompt_text,
            "client_name": "dummy",
            "num_results": num_results,
        }
        return manifest.cache.get_key(request_to_key(request))

    # One result comes back as a bare string, several as a list.
    expected = "hello" if num_results == 1 else ["hello", "hello"]

    prompt = Prompt("This is a prompt")
    res = manifest.run(prompt)
    assert cached_entry("This is a prompt") is not None
    assert res == expected

    prompt = Prompt(lambda x: f"{x} is a prompt")
    res = manifest.run(prompt, "Hello")
    assert cached_entry("Hello is a prompt") is not None
    assert res == expected
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.parametrize("num_results", [1, 2])
def test_batch_run(sqlite_cache, num_results):
    """
    Test manifest batch run.

    ``run_batch`` returns one entry per input; each entry is a string when
    ``num_results`` is 1 and a list of strings otherwise.
    """
    manifest = Manifest(
        client_name="dummy",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
        num_results=num_results,
    )
    per_input = "hello" if num_results == 1 else ["hello", "hello"]

    # Constant prompt, no inputs given: a single entry comes back.
    prompt = Prompt("This is a prompt")
    res = manifest.run_batch(prompt)
    assert res == [per_input]

    # One-argument prompt with two inputs: two entries come back.
    prompt = Prompt(lambda x: f"{x} is a prompt")
    res = manifest.run_batch(prompt, ["Hello", "Hello"])
    assert res == [per_input, per_input]
|
54
tests/test_prompt.py
Normal file
54
tests/test_prompt.py
Normal file
@ -0,0 +1,54 @@
|
||||
"""Prompt test."""
|
||||
import pytest
|
||||
|
||||
from manifest import Prompt
|
||||
|
||||
|
||||
def test_init():
    """
    Test prompt initialization.

    Covers string prompts, zero- and one-argument functions, a function taking
    a list, and rejection of functions with two parameters.
    """
    expected = "This is a test prompt"
    # TODO: add list of prompt tests

    # String prompt
    prompt = Prompt(expected)
    assert prompt(None) == expected
    assert prompt() == expected

    # Function no inputs
    prompt = Prompt(lambda: "This is a test prompt")
    assert prompt(None) == expected
    assert prompt() == expected

    # Function single inputs
    prompt = Prompt(lambda x: f"{x} is a test prompt")
    assert prompt("This") == expected
    assert prompt("Hello") == "Hello is a test prompt"

    # Function list inputs
    prompt = Prompt(lambda x: f"{x[0]} is a test {x[1]}")
    assert prompt(["This", "prompt"]) == expected
    assert prompt(["Hello", "prompt"]) == "Hello is a test prompt"

    # Function two inputs
    with pytest.raises(ValueError) as exc_info:
        Prompt(lambda x, y: f"{x} is a test {y}")
    assert str(exc_info.value) == "Prompts must have zero or one input."
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not implemented")
def test_serialize():
    """Test prompt serialization round trip (skipped until implemented)."""
    sources = (
        # String prompt
        "This is a test prompt",
        # Function single inputs
        lambda x: f"{x} is a test prompt",
    )
    for source in sources:
        prompt = Prompt(source)
        assert Prompt.deserialize(prompt.serialize()) == prompt
|
Loading…
Reference in New Issue
Block a user