feat: web app for manifest (#46)

Also fixed typing issues in tests
laurel/helm
Laurel Orr 1 year ago committed by GitHub
parent 6f5b64f0df
commit defc63bf36

@ -10,14 +10,14 @@ deepspeed:
pip install -e git+https://github.com/microsoft/DeepSpeed.git#egg=deepspeed
format:
isort --atomic manifest/ tests/
black manifest/ tests/
isort --atomic manifest/ tests/ web_app/
black manifest/ tests/ web_app/
check:
isort -c manifest/ tests/
black manifest/ tests/ --check
flake8 manifest/ tests/
mypy manifest/
isort -c manifest/ tests/ web_app/
black manifest/ tests/ web_app/ --check
flake8 manifest/ tests/ web_app/
mypy manifest/ tests/ web_app/
clean:
pip uninstall -y manifest

@ -160,9 +160,9 @@ class Client(ABC):
" Increase client_timeout."
)
raise e
except requests.exceptions.HTTPError as e:
logger.error(res.text)
raise e
except requests.exceptions.HTTPError:
logger.error(res.json())
raise requests.exceptions.HTTPError(res.json())
return self.format_response(res.json())
return _run_completion, request_params

@ -105,6 +105,42 @@ class Manifest:
self.client.close()
self.cache.close()
def change_client(
    self,
    client_name: Optional[str] = None,
    client_connection: Optional[str] = None,
    stop_token: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """
    Change manifest client.

    The replacement client is constructed and its arguments are validated
    before any attribute of the manifest is modified, so a rejected change
    leaves the current client and client name untouched.

    Args:
        client_name: name of client. When omitted, the client is left as is.
        client_connection: connection string for client.
        stop_token: stop token prompt generation.
            Can be overridden in run

    Remaining kwargs sent to client.

    Raises:
        ValueError: if ``client_name`` is not a known client, or if kwargs
            remain after the client constructor has run.
    """
    if client_name:
        if client_name not in CLIENT_CONSTRUCTORS:
            raise ValueError(
                f"Unknown client name: {client_name}. "
                f"Choices are {list(CLIENT_CONSTRUCTORS.keys())}"
            )
        new_client = CLIENT_CONSTRUCTORS[client_name](  # type: ignore
            client_connection, client_args=kwargs
        )
        # Whatever is still in kwargs after construction is reported as
        # unrecognized (presumably the constructor removes the entries it
        # consumes -- TODO confirm against the client base class).
        if kwargs:
            raise ValueError(
                f"{list(kwargs.items())} arguments are not recognized."
            )
        # Only commit the swap once validation has passed.
        self.client_name = client_name
        self.client = new_client

    if stop_token is not None:
        self.stop_token = stop_token
def run(
self,
prompt: Union[str, List[str]],
@ -164,7 +200,6 @@ class Manifest:
]
if len(request_unused_kwargs) > 0:
logger.warning(f"{list(request_unused_kwargs)} arguments are unused.")
# Create cache key
cache_key = full_kwargs.copy()
# Make query model dependent

@ -19,10 +19,10 @@ with open(ver_path) as ver_file:
# Package meta-data.
NAME = "manifest-ml"
DESCRIPTION = "Manifest for Prompt Programming Foundation Models."
DESCRIPTION = "Manifest for Prompting Foundation Models."
URL = "https://github.com/HazyResearch/manifest"
EMAIL = "lorr1@cs.stanford.edu"
AUTHOR = "Laurel Orr and Avanika Narayan"
AUTHOR = "Laurel Orr"
REQUIRES_PYTHON = ">=3.8.0"
VERSION = main_ns["__version__"]
@ -34,6 +34,8 @@ EXTRAS = {
"api": [
"diffusers>=0.6.0",
"Flask>=2.1.2",
"fastapi>=0.70.0",
"uvicorn>=0.18.0",
"accelerate>=0.10.0",
"transformers>=4.20.0",
"torch>=1.8.0",

@ -2,13 +2,14 @@
import os
import shutil
from pathlib import Path
from typing import Generator
import pytest
import redis
@pytest.fixture
def sqlite_cache(tmp_path):
def sqlite_cache(tmp_path: Path) -> Generator[str, None, None]:
"""Sqlite Cache."""
cache = str(tmp_path / "sqlite_cache.sqlite")
yield cache
@ -16,22 +17,22 @@ def sqlite_cache(tmp_path):
@pytest.fixture
def redis_cache():
def redis_cache() -> Generator[str, None, None]:
"""Redis cache."""
host = os.environ.get("REDIS_HOST", "localhost")
port = os.environ.get("REDIS_PORT", 6379)
port = int(os.environ.get("REDIS_PORT", 6379))
yield f"{host}:{port}"
# Clear out the database
try:
db = redis.Redis(host=host, port=port)
db.flushdb()
# For better local testing, pass if redis DB not started
except OSError:
except redis.exceptions.ConnectionError:
pass
@pytest.fixture
def session_cache(tmpdir):
def session_cache(tmpdir: str) -> Generator[Path, None, None]:
"""Session cache dir."""
os.environ["MANIFEST_HOME"] = str(tmpdir)
yield Path(tmpdir)

@ -7,16 +7,15 @@ import pytest
from manifest.caches.array_cache import ArrayCache
def test_init(tmpdir):
def test_init(tmpdir: Path) -> None:
"""Test cache initialization."""
tmpdir = Path(tmpdir)
cache = ArrayCache(tmpdir)
cache = ArrayCache(Path(tmpdir))
assert (tmpdir / "hash2arrloc.sqlite").exists()
assert cache.cur_file_idx == 0
assert cache.cur_offset == 0
def test_put_get(tmpdir):
def test_put_get(tmpdir: Path) -> None:
"""Test putting and getting."""
cache = ArrayCache(tmpdir)
cache.max_memmap_size = 5
@ -67,7 +66,7 @@ def test_put_get(tmpdir):
assert np.allclose(cache.get("key2"), arr2)
def test_contains_key(tmpdir):
def test_contains_key(tmpdir: Path) -> None:
"""Test contains key."""
cache = ArrayCache(tmpdir)
assert not cache.contains_key("key")

@ -1,9 +1,12 @@
"""Cache test."""
from typing import cast
import numpy as np
import pytest
from redis import Redis
from sqlitedict import SqliteDict
from manifest.caches.cache import Cache
from manifest.caches.noop import NoopCache
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
@ -12,25 +15,25 @@ from manifest.caches.sqlite import SQLiteCache
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.parametrize("cache_type", ["sqlite", "redis"])
def test_init(sqlite_cache, redis_cache, cache_type):
def test_init(sqlite_cache: str, redis_cache: str, cache_type: str) -> None:
"""Test cache initialization."""
if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache)
assert isinstance(cache.cache, SqliteDict)
sql_cache_obj = SQLiteCache(sqlite_cache)
assert isinstance(sql_cache_obj.cache, SqliteDict)
else:
cache = RedisCache(redis_cache)
assert isinstance(cache.redis, Redis)
redis_cache_obj = RedisCache(redis_cache)
assert isinstance(redis_cache_obj.redis, Redis)
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.parametrize("cache_type", ["sqlite", "redis"])
def test_key_get_and_set(sqlite_cache, redis_cache, cache_type):
def test_key_get_and_set(sqlite_cache: str, redis_cache: str, cache_type: str) -> None:
"""Test cache key get and set."""
if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache)
cache = cast(Cache, SQLiteCache(sqlite_cache))
else:
cache = RedisCache(redis_cache)
cache = cast(Cache, RedisCache(redis_cache))
cache.set_key("test", "valueA")
cache.set_key("testA", "valueB")
@ -48,46 +51,46 @@ def test_key_get_and_set(sqlite_cache, redis_cache, cache_type):
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("redis_cache")
@pytest.mark.parametrize("cache_type", ["sqlite", "redis"])
def test_get(sqlite_cache, redis_cache, cache_type):
def test_get(sqlite_cache: str, redis_cache: str, cache_type: str) -> None:
"""Test cache save prompt."""
if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache)
cache = cast(Cache, SQLiteCache(sqlite_cache))
else:
cache = RedisCache(redis_cache)
cache = cast(Cache, RedisCache(redis_cache))
test_request = {"test": "hello", "testA": "world"}
compute = lambda: {"choices": [{"text": "hello"}]}
# response = cache.get(test_request, overwrite_cache=False, compute=compute)
# assert response.get_response() == "hello"
# assert not response.is_cached()
# assert response.get_request() == test_request
response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert response.get_response() == "hello"
assert not response.is_cached()
assert response.get_request() == test_request
# response = cache.get(test_request, overwrite_cache=False, compute=compute)
# assert response.get_response() == "hello"
# assert response.is_cached()
# assert response.get_request() == test_request
response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert response.get_response() == "hello"
assert response.is_cached()
assert response.get_request() == test_request
# response = cache.get(test_request, overwrite_cache=True, compute=compute)
# assert response.get_response() == "hello"
# assert not response.is_cached()
# assert response.get_request() == test_request
response = cache.get(test_request, overwrite_cache=True, compute=compute)
assert response.get_response() == "hello"
assert not response.is_cached()
assert response.get_request() == test_request
arr = np.random.rand(4, 4)
test_request = {"test": "hello", "testA": "world of images"}
compute = lambda: {"choices": [{"array": arr}]}
compute_arr = lambda: {"choices": [{"array": arr}]}
# Test array
if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache, client_name="diffuser")
else:
cache = RedisCache(redis_cache, client_name="diffuser")
response = cache.get(test_request, overwrite_cache=False, compute=compute)
response = cache.get(test_request, overwrite_cache=False, compute=compute_arr)
assert np.allclose(response.get_response(), arr)
assert not response.is_cached()
assert response.get_request() == test_request
def test_noop_cache():
def test_noop_cache() -> None:
"""Test cache that is a no-op cache."""
cache = NoopCache(None)
cache.set_key("test", "valueA")

@ -6,38 +6,38 @@ We just test the dummy client as we don't want to load a model or use OpenAI tok
from manifest.clients.dummy import DummyClient
def test_init():
def test_init() -> None:
"""Test client initialization."""
client = DummyClient(connection_str=None)
assert client.n == 1
assert client.n == 1 # type: ignore
args = {"n": 3}
client = DummyClient(connection_str=None, client_args=args)
assert client.n == 3
assert client.n == 3 # type: ignore
def test_get_params():
def test_get_params() -> None:
"""Test get param functions."""
client = DummyClient(connection_str=None)
assert client.get_model_params() == {"engine": "dummy"}
assert client.get_model_inputs() == ["n"]
def test_get_request():
def test_get_request() -> None:
"""Test client get request."""
args = {"n": 3}
client = DummyClient(connection_str=None, client_args=args)
request_params = client.get_request_params("hello", {})
request_func, request_params = client.get_request(request_params)
assert request_params == {"prompt": "hello", "num_results": 3}
request_func, request_params_return = client.get_request(request_params)
assert request_params_return == {"prompt": "hello", "num_results": 3}
assert request_func() == {"choices": [{"text": "hello"}] * 3}
request_params = client.get_request_params("hello", {"n": 5})
request_func, request_params = client.get_request(request_params)
assert request_params == {"prompt": "hello", "num_results": 5}
request_func, request_params_return = client.get_request(request_params)
assert request_params_return == {"prompt": "hello", "num_results": 5}
assert request_func() == {"choices": [{"text": "hello"}] * 5}
request_params = client.get_request_params(["hello"] * 5, {"n": 1})
request_func, request_params = client.get_request(request_params)
assert request_params == {"prompt": ["hello"] * 5, "num_results": 1}
request_func, request_params_return = client.get_request(request_params)
assert request_params_return == {"prompt": ["hello"] * 5, "num_results": 1}
assert request_func() == {"choices": [{"text": "hello"}] * 5}

@ -28,14 +28,16 @@ except OSError:
MAXGPU = 0
if NOCUDA == 0:
try:
p = os.popen("nvidia-smi --query-gpu=index --format=csv,noheader,nounits")
i = p.read().split("\n")
p = os.popen( # type: ignore
"nvidia-smi --query-gpu=index --format=csv,noheader,nounits"
)
i = p.read().split("\n") # type: ignore
MAXGPU = int(i[-2]) + 1
except OSError:
NOCUDA = 1
def test_gpt_generate():
def test_gpt_generate() -> None:
"""Test pipeline generation from a gpt model."""
model = TextGenerationModel(
model_name_or_path="gpt2",
@ -80,7 +82,7 @@ def test_gpt_generate():
assert math.isclose(round(result[0][1], 3), -1.414)
def test_encdec_generate():
def test_encdec_generate() -> None:
"""Test pipeline generation from a gpt model."""
model = TextGenerationModel(
model_name_or_path="google/t5-small-lm-adapt",
@ -125,7 +127,7 @@ def test_encdec_generate():
assert math.isclose(round(result[0][1], 3), -4.233)
def test_gpt_score():
def test_gpt_score() -> None:
"""Test pipeline generation from a gpt model."""
model = TextGenerationModel(
model_name_or_path="gpt2",
@ -144,7 +146,7 @@ def test_gpt_score():
assert math.isclose(round(result[1], 3), -45.831)
def test_batch_gpt_generate():
def test_batch_gpt_generate() -> None:
"""Test pipeline generation from a gpt model."""
model = TextGenerationModel(
model_name_or_path="gpt2",
@ -193,7 +195,7 @@ def test_batch_gpt_generate():
assert math.isclose(round(result[1][1], 3), -6.246)
def test_batch_encdec_generate():
def test_batch_encdec_generate() -> None:
"""Test pipeline generation from a gpt model."""
model = TextGenerationModel(
model_name_or_path="google/t5-small-lm-adapt",
@ -247,7 +249,7 @@ def test_batch_encdec_generate():
@pytest.mark.skipif(
(NOCUDA == 1 or MAXGPU == 0), reason="No cuda or GPUs found through nvidia-smi"
)
def test_gpt_deepspeed_generate():
def test_gpt_deepspeed_generate() -> None:
"""Test deepspeed generation from a gpt model."""
model = TextGenerationModel(
model_name_or_path="gpt2",

@ -1,5 +1,6 @@
"""Manifest test."""
import json
from typing import cast
import pytest
@ -12,7 +13,7 @@ from manifest.session import Session
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("session_cache")
def test_init(sqlite_cache, session_cache):
def test_init(sqlite_cache: str, session_cache: str) -> None:
"""Test manifest initialization."""
with pytest.raises(ValueError) as exc_info:
Manifest(
@ -32,7 +33,7 @@ def test_init(sqlite_cache, session_cache):
assert isinstance(manifest.client, DummyClient)
assert isinstance(manifest.cache, SQLiteCache)
assert manifest.session is None
assert manifest.client.n == 1
assert manifest.client.n == 1 # type: ignore
assert manifest.stop_token == ""
manifest = Manifest(
@ -46,7 +47,34 @@ def test_init(sqlite_cache, session_cache):
assert isinstance(manifest.client, DummyClient)
assert isinstance(manifest.cache, NoopCache)
assert isinstance(manifest.session, Session)
assert manifest.client.n == 3
assert manifest.client.n == 3 # type: ignore
assert manifest.stop_token == "\n"
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("session_cache")
def test_change_manifest(sqlite_cache: str, session_cache: str) -> None:
"""Test manifest change."""
manifest = Manifest(
client_name="dummy",
cache_name="sqlite",
cache_connection=sqlite_cache,
)
manifest.change_client()
assert manifest.client_name == "dummy"
assert isinstance(manifest.client, DummyClient)
assert isinstance(manifest.cache, SQLiteCache)
assert manifest.session is None
assert manifest.client.n == 1 # type: ignore
assert manifest.stop_token == ""
manifest.change_client(stop_token="\n")
assert manifest.client_name == "dummy"
assert isinstance(manifest.client, DummyClient)
assert isinstance(manifest.cache, SQLiteCache)
assert manifest.session is None
assert manifest.client.n == 1 # type: ignore
assert manifest.stop_token == "\n"
@ -54,7 +82,9 @@ def test_init(sqlite_cache, session_cache):
@pytest.mark.usefixtures("session_cache")
@pytest.mark.parametrize("n", [1, 2])
@pytest.mark.parametrize("return_response", [True, False])
def test_run(sqlite_cache, session_cache, n, return_response):
def test_run(
sqlite_cache: str, session_cache: str, n: int, return_response: bool
) -> None:
"""Test manifest run."""
manifest = Manifest(
client_name="dummy",
@ -77,9 +107,9 @@ def test_run(sqlite_cache, session_cache, n, return_response):
result = manifest.run(prompt, return_response=return_response)
if return_response:
assert isinstance(result, Response)
res = result.get_response(manifest.stop_token)
res = cast(Response, result).get_response(manifest.stop_token)
else:
res = result
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
@ -102,9 +132,9 @@ def test_run(sqlite_cache, session_cache, n, return_response):
result = manifest.run(prompt, run_id="34", return_response=return_response)
if return_response:
assert isinstance(result, Response)
res = result.get_response(manifest.stop_token)
res = cast(Response, result).get_response(manifest.stop_token)
else:
res = result
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
@ -128,9 +158,9 @@ def test_run(sqlite_cache, session_cache, n, return_response):
result = manifest.run(prompt, return_response=return_response)
if return_response:
assert isinstance(result, Response)
res = result.get_response(manifest.stop_token)
res = cast(Response, result).get_response(manifest.stop_token)
else:
res = result
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
@ -153,9 +183,9 @@ def test_run(sqlite_cache, session_cache, n, return_response):
result = manifest.run(prompt, stop_token="ll", return_response=return_response)
if return_response:
assert isinstance(result, Response)
res = result.get_response(stop_token="ll")
res = cast(Response, result).get_response(stop_token="ll")
else:
res = result
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
@ -179,7 +209,9 @@ def test_run(sqlite_cache, session_cache, n, return_response):
@pytest.mark.usefixtures("session_cache")
@pytest.mark.parametrize("n", [1, 2])
@pytest.mark.parametrize("return_response", [True, False])
def test_batch_run(sqlite_cache, session_cache, n, return_response):
def test_batch_run(
sqlite_cache: str, session_cache: str, n: int, return_response: bool
) -> None:
"""Test manifest run."""
manifest = Manifest(
client_name="dummy",
@ -195,32 +227,38 @@ def test_batch_run(sqlite_cache, session_cache, n, return_response):
else:
result = manifest.run(prompt, return_response=return_response)
if return_response:
res = result.get_response(manifest.stop_token, is_batch=True)
res = cast(Response, result).get_response(
manifest.stop_token, is_batch=True
)
else:
res = result
res = cast(str, result)
assert res == ["hello"]
prompt = ["Hello is a prompt", "Hello is a prompt"]
result = manifest.run(prompt, return_response=return_response)
if return_response:
res = result.get_response(manifest.stop_token, is_batch=True)
res = cast(Response, result).get_response(
manifest.stop_token, is_batch=True
)
else:
res = result
res = cast(str, result)
assert res == ["hello", "hello"]
prompt = ["Hello is a prompt", "Hello is a prompt"]
result = manifest.run(prompt, stop_token="ll", return_response=return_response)
if return_response:
res = result.get_response(stop_token="ll", is_batch=True)
res = cast(Response, result).get_response(stop_token="ll", is_batch=True)
else:
res = result
res = cast(str, result)
assert res == ["he", "he"]
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("session_cache")
@pytest.mark.parametrize("return_response", [True, False])
def test_choices_run(sqlite_cache, session_cache, return_response):
def test_choices_run(
sqlite_cache: str, session_cache: str, return_response: bool
) -> None:
"""Test manifest run."""
manifest = Manifest(
client_name="dummy",
@ -234,9 +272,9 @@ def test_choices_run(sqlite_cache, session_cache, return_response):
result = manifest.run(prompt, gold_choices=choices, return_response=return_response)
if return_response:
assert isinstance(result, Response)
res = result.get_response(manifest.stop_token)
res = cast(Response, result).get_response(manifest.stop_token)
else:
res = result
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
@ -257,9 +295,9 @@ def test_choices_run(sqlite_cache, session_cache, return_response):
result = manifest.run(prompt, gold_choices=choices, return_response=return_response)
if return_response:
assert isinstance(result, Response)
res = result.get_response(manifest.stop_token)
res = cast(Response, result).get_response(manifest.stop_token)
else:
res = result
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
@ -285,9 +323,9 @@ def test_choices_run(sqlite_cache, session_cache, return_response):
)
if return_response:
assert isinstance(result, Response)
res = result.get_response(stop_token="ll")
res = cast(Response, result).get_response(stop_token="ll")
else:
res = result
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
@ -303,19 +341,19 @@ def test_choices_run(sqlite_cache, session_cache, return_response):
)
assert res == "ca"
prompt = ["Hello is a prompt", "Hello is a prompt"]
prompt_lst = ["Hello is a prompt", "Hello is a prompt"]
choices = ["callt", "dog"]
result = manifest.run(
prompt,
prompt_lst,
gold_choices=choices,
stop_token="ll",
return_response=return_response,
)
if return_response:
assert isinstance(result, Response)
res = result.get_response(stop_token="ll", is_batch=True)
res = cast(Response, result).get_response(stop_token="ll", is_batch=True)
else:
res = result
res = cast(str, result)
assert (
manifest.cache.get_key(
json.dumps(
@ -333,7 +371,7 @@ def test_choices_run(sqlite_cache, session_cache, return_response):
@pytest.mark.usefixtures("session_cache")
def test_log_query(session_cache):
def test_log_query(session_cache: str) -> None:
"""Test manifest session logging."""
manifest = Manifest(client_name="dummy", cache_name="noop", session_id="_default")
prompt = "This is a prompt"
@ -361,8 +399,8 @@ def test_log_query(session_cache):
]
prior_cache_item = (query_key, response_key)
prompt = ["This is a prompt", "This is a prompt2"]
_ = manifest.run(prompt, return_response=False)
prompt_lst = ["This is a prompt", "This is a prompt2"]
_ = manifest.run(prompt_lst, return_response=False)
query_key = {
"prompt": ["This is a prompt", "This is a prompt2"],
"engine": "dummy",

@ -2,7 +2,7 @@
from manifest.request import DiffusionRequest, LMRequest
def test_llm_init():
def test_llm_init() -> None:
"""Test request initialization."""
request = LMRequest()
assert request.temperature == 0.7
@ -18,7 +18,7 @@ def test_llm_init():
assert request.prompt == "test"
def test_diff_init():
def test_diff_init() -> None:
"""Test request initialization."""
request = DiffusionRequest()
assert request.height == 512
@ -34,30 +34,30 @@ def test_diff_init():
assert request.prompt == "test"
def test_to_dict():
def test_to_dict() -> None:
"""Test request to dict."""
request = LMRequest()
dct = request.to_dict()
request_lm = LMRequest()
dct = request_lm.to_dict()
assert dct == {k: v for k, v in request.dict().items() if v is not None}
assert dct == {k: v for k, v in request_lm.dict().items() if v is not None}
# Note the second value is a placeholder for the default value
# It's unused in to_dict
keys = {"temperature": ("temp", 0.7)}
dct = request.to_dict(allowable_keys=keys)
dct = request_lm.to_dict(allowable_keys=keys)
assert dct == {"temp": 0.7, "prompt": ""}
dct = request.to_dict(allowable_keys=keys, add_prompt=False)
dct = request_lm.to_dict(allowable_keys=keys, add_prompt=False)
assert dct == {"temp": 0.7}
request = DiffusionRequest()
dct = request.to_dict()
request_diff = DiffusionRequest()
dct = request_diff.to_dict()
assert dct == {k: v for k, v in request.dict().items() if v is not None}
assert dct == {k: v for k, v in request_diff.dict().items() if v is not None}
keys = {"height": ("hgt", 512)}
dct = request.to_dict(allowable_keys=keys)
dct = request_diff.to_dict(allowable_keys=keys)
assert dct == {"hgt": 512, "prompt": ""}
dct = request.to_dict(allowable_keys=keys, add_prompt=False)
dct = request_diff.to_dict(allowable_keys=keys, add_prompt=False)
assert dct == {"hgt": 512}

@ -5,10 +5,10 @@ import pytest
from manifest import Response
def test_init():
def test_init() -> None:
"""Test response initialization."""
with pytest.raises(ValueError) as exc_info:
response = Response(4, False, {})
response = Response(4, False, {}) # type: ignore
assert str(exc_info.value) == "Response must be dict. Response is\n4."
with pytest.raises(ValueError) as exc_info:
response = Response({"test": "hello"}, False, {})
@ -64,7 +64,7 @@ def test_init():
assert response.item_dtype == "int64"
def test_getters():
def test_getters() -> None:
"""Test response cached."""
response = Response({"choices": [{"text": "hello"}]}, False, {})
assert response.get_json_response() == {"choices": [{"text": "hello"}]}
@ -85,7 +85,7 @@ def test_getters():
assert response.get_request() == {"request": "yoyo"}
def test_serialize():
def test_serialize() -> None:
"""Test response serialization."""
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
deserialized_response = Response.deserialize(response.serialize())
@ -116,7 +116,7 @@ def test_serialize():
assert deserialized_response._request_params == {"request": "yoyo"}
def test_get_results():
def test_get_results() -> None:
"""Test response get results."""
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
assert response.get_response() == "hello"

@ -1,12 +1,13 @@
"""Cache test."""
import json
from pathlib import Path
import numpy as np
from manifest.caches.serializers import ArraySerializer
def test_response_to_key(session_cache):
def test_response_to_key(session_cache: Path) -> None:
"""Test array serializer initialization."""
serializer = ArraySerializer()
arr = np.random.rand(4, 4)

@ -1,5 +1,6 @@
"""Test session."""
import sqlite3
from pathlib import Path
import pytest
@ -7,7 +8,7 @@ from manifest.session import Session
@pytest.mark.usefixtures("session_cache")
def test_init(session_cache):
def test_init(session_cache: Path) -> None:
"""Test session initialization."""
session = Session()
assert isinstance(session.conn, sqlite3.Connection)
@ -27,7 +28,7 @@ def test_init(session_cache):
@pytest.mark.usefixtures("session_cache")
def test_log_query(session_cache):
def test_log_query(session_cache: Path) -> None:
"""Test session log_query."""
session = Session()
assert session.get_last_queries(1) == []
@ -51,7 +52,7 @@ def test_log_query(session_cache):
@pytest.mark.usefixtures("session_cache")
def test_resume_query(session_cache):
def test_resume_query(session_cache: Path) -> None:
"""Test session log_query."""
session = Session(session_id="dog_days")
query_key = {"query": "What is your name?", "time": "now"}
@ -64,7 +65,7 @@ def test_resume_query(session_cache):
@pytest.mark.usefixtures("session_cache")
def test_session_keys(session_cache):
def test_session_keys(session_cache: Path) -> None:
"""Test get session keys."""
# Assert empty before queries
assert Session.get_session_keys(session_cache / ".manifest" / "session.db") == []

@ -0,0 +1,10 @@
## Running
In a separate tmux/terminal session, run
```
cd manifest
uvicorn web_app.main:app --reload
```
Change the port by adding `--port <PORT>`.

@ -0,0 +1 @@
"""Web application for Manifest."""

@ -0,0 +1,56 @@
"""Manifest as an app service."""
from typing import Any, Dict, cast
from fastapi import APIRouter, FastAPI, HTTPException
from manifest import Manifest
from manifest.response import Response as ManifestResponse
from web_app import schemas
app = FastAPI()
api_router = APIRouter()
@app.get("/")
async def root() -> Dict:
    """Serve the greeting payload at the application root."""
    greeting = {"message": "Hello to the Manifest App"}
    return greeting
@api_router.post("/prompt/", status_code=201, response_model=schemas.ManifestResponse)
def prompt_manifest(*, manifest_in: schemas.ManifestCreate) -> Dict:
    """Prompt a manifest session and query.

    Builds a Manifest from the client/cache settings in the request, runs the
    prompt, and returns the response text together with cache and request
    metadata.

    Args:
        manifest_in: request body holding the prompt, the sampling
            parameters, and the client/cache configuration.

    Returns:
        Dict with keys ``response``, ``cached`` and ``request_params``.

    Raises:
        HTTPException: with status 500 and the error text as detail when
            running the prompt fails.
    """
    manifest = Manifest(
        client_name=manifest_in.client_name,
        client_connection=manifest_in.client_connection,
        engine=manifest_in.engine,
        cache_name=manifest_in.cache_name,
        cache_connection=manifest_in.cache_connection,
    )
    manifest_prompt_args: Dict[str, Any] = {
        "n": manifest_in.n,
        "max_tokens": manifest_in.max_tokens,
    }
    # Compare against None rather than truthiness: 0 / 0.0 are legitimate
    # values (e.g. temperature=0) and must be forwarded, not dropped.
    for param_name in ("temperature", "top_k", "top_p"):
        param_value = getattr(manifest_in, param_name)
        if param_value is not None:
            manifest_prompt_args[param_name] = param_value

    try:
        response = manifest.run(
            prompt=manifest_in.prompt, return_response=True, **manifest_prompt_args
        )
        response = cast(ManifestResponse, response)
    except Exception as e:
        # Chain the cause so the original traceback is kept in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "response": response.get_response(),
        "cached": response.is_cached(),
        "request_params": response.get_request(),
    }
app.include_router(api_router)

@ -0,0 +1,32 @@
"""Pydantic models."""
from typing import List, Optional, Union
from pydantic import BaseModel
class ManifestCreate(BaseModel):
    """Create manifest Pydantic.

    Request body for the prompt endpoint: the prompt with its sampling
    parameters, plus the client and cache settings used to build the Manifest.
    """

    # Prompt params
    # Text to run; required (no default).
    prompt: str
    # n and max_tokens are always forwarded to the run call.
    n: int = 1
    max_tokens: int = 132
    # Optional sampling parameters; None means "not specified".
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None

    # Manifest client params
    # Passed to the Manifest constructor by the prompt endpoint.
    client_name: str = "openai"
    client_connection: Optional[str] = None
    engine: str = "text-davinci-003"
    cache_name: str = "noop"
    cache_connection: Optional[str] = None
class ManifestResponse(BaseModel):
    """Manifest response Pydantic.

    Response body of the prompt endpoint.
    """

    # Result of the run: a single string or a list of strings.
    response: Union[str, List[str]]
    # Cached flag reported by the manifest response (is_cached()).
    cached: bool
    # Request parameters reported by the manifest response (get_request()).
    request_params: dict
Loading…
Cancel
Save