mirror of https://github.com/HazyResearch/manifest
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
307 lines
8.9 KiB
Python
307 lines
8.9 KiB
Python
"""Manifest test."""
|
|
import pytest
|
|
|
|
from manifest import Manifest, Prompt, Response
|
|
from manifest.caches.cache import request_to_key
|
|
from manifest.caches.noop import NoopCache
|
|
from manifest.caches.sqlite import SQLiteCache
|
|
from manifest.clients.dummy import DummyClient
|
|
from manifest.session import Session
|
|
|
|
|
|
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("session_cache")
def test_init(sqlite_cache, session_cache):
    """Check Manifest construction with invalid, default and explicit arguments.

    Covers three cases: an unrecognized keyword argument is rejected, a
    SQLite-backed manifest exposes the expected components and defaults, and
    a noop-cache manifest honours explicit ``n`` and ``stop_token`` values.
    """
    sqlite_kwargs = {
        "client_name": "dummy",
        "cache_name": "sqlite",
        "cache_connection": sqlite_cache,
    }

    # An unknown keyword must surface as a ValueError naming the argument.
    with pytest.raises(ValueError) as exc_info:
        Manifest(**sqlite_kwargs, sep_tok="")
    assert str(exc_info.value) == "[('sep_tok', '')] arguments are not recognized."

    # SQLite-backed manifest built with no optional arguments.
    sqlite_manifest = Manifest(**sqlite_kwargs)
    assert sqlite_manifest.client_name == "dummy"
    for component, expected_type in (
        (sqlite_manifest.client, DummyClient),
        (sqlite_manifest.cache, SQLiteCache),
        (sqlite_manifest.session, Session),
    ):
        assert isinstance(component, expected_type)
    assert sqlite_manifest.client.n == 1
    assert sqlite_manifest.stop_token == ""

    # Noop-cache manifest with explicit result count and stop token.
    noop_manifest = Manifest(
        client_name="dummy",
        cache_name="noop",
        n=3,
        stop_token="\n",
    )
    assert noop_manifest.client_name == "dummy"
    for component, expected_type in (
        (noop_manifest.client, DummyClient),
        (noop_manifest.cache, NoopCache),
        (noop_manifest.session, Session),
    ):
        assert isinstance(component, expected_type)
    assert noop_manifest.client.n == 3
    assert noop_manifest.stop_token == "\n"
|
|
|
|
|
|
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("session_cache")
@pytest.mark.parametrize("n", [1, 2])
@pytest.mark.parametrize("return_response", [True, False])
def test_run(sqlite_cache, session_cache, n, return_response):
    """Exercise ``Manifest.run`` with static and callable prompts.

    Runs against the dummy client with a SQLite cache for every combination
    of ``n`` and ``return_response``. Each successful call is expected to
    leave an entry in the cache under the corresponding request key.
    """
    manifest = Manifest(
        client_name="dummy",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
        n=n,
    )

    def assert_cached(prompt_text):
        """Assert the cache has an entry for ``prompt_text`` with ``n`` results."""
        request = {
            "prompt": prompt_text,
            "client_name": "dummy",
            "num_results": n,
        }
        assert manifest.cache.get_key(request_to_key(request)) is not None

    def expected(text):
        """Return ``text`` for a single result, or a two-item list for n == 2."""
        return text if n == 1 else [text, text]

    # Unknown keyword arguments to run() must be rejected.
    prompt = Prompt("This is a prompt")
    with pytest.raises(ValueError) as exc_info:
        manifest.run(prompt, return_response=return_response, bad_input=5)
    assert str(exc_info.value) == "[('bad_input', 5)] arguments are not recognized."

    # Static prompt.
    prompt = Prompt("This is a prompt")
    outcome = manifest.run(prompt, return_response=return_response)
    if return_response:
        assert isinstance(outcome, Response)
        outcome = outcome.get_response(manifest.stop_token)
    assert_cached("This is a prompt")
    assert outcome == expected("hello")

    # Callable prompt formatted with an input value.
    prompt = Prompt(lambda x: f"{x} is a prompt")
    outcome = manifest.run(prompt, "Hello", return_response=return_response)
    if return_response:
        assert isinstance(outcome, Response)
        outcome = outcome.get_response(manifest.stop_token)
    assert_cached("Hello is a prompt")
    assert outcome == expected("hello")

    # Same callable prompt, now with a stop token; the expected text is "he".
    prompt = Prompt(lambda x: f"{x} is a prompt")
    outcome = manifest.run(
        prompt, "Hello", stop_token="ll", return_response=return_response
    )
    if return_response:
        assert isinstance(outcome, Response)
        outcome = outcome.get_response(stop_token="ll")
    assert_cached("Hello is a prompt")
    assert outcome == expected("he")
|
|
|
|
|
|
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("session_cache")
@pytest.mark.parametrize("n", [1, 2])
@pytest.mark.parametrize("return_response", [True, False])
def test_batch_run(sqlite_cache, session_cache, n, return_response):
    """Exercise ``Manifest.run_batch`` with static and callable prompts.

    Runs against the dummy client with a SQLite cache for every combination
    of ``n`` and ``return_response``, checking the per-input results with and
    without a stop token.
    """
    manifest = Manifest(
        client_name="dummy",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
        n=n,
    )

    def per_input(text):
        """Return the expected result for one input: a string or a pair."""
        return text if n == 1 else [text, text]

    # Static prompt with no inputs: one result entry is expected.
    prompt = Prompt("This is a prompt")
    outputs = manifest.run_batch(prompt, return_response=return_response)
    if return_response:
        outputs = [item.get_response(manifest.stop_token) for item in outputs]
    assert outputs == [per_input("hello")]

    # Callable prompt over two inputs: one result entry per input.
    prompt = Prompt(lambda x: f"{x} is a prompt")
    outputs = manifest.run_batch(
        prompt, ["Hello", "Hello"], return_response=return_response
    )
    if return_response:
        outputs = [item.get_response(manifest.stop_token) for item in outputs]
    assert outputs == [per_input("hello"), per_input("hello")]

    # Same batch with a stop token; each expected text is "he".
    prompt = Prompt(lambda x: f"{x} is a prompt")
    outputs = manifest.run_batch(
        prompt, ["Hello", "Hello"], stop_token="ll", return_response=return_response
    )
    if return_response:
        outputs = [item.get_response(stop_token="ll") for item in outputs]
    assert outputs == [per_input("he"), per_input("he")]
|
|
|
|
|
|
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("session_cache")
@pytest.mark.parametrize("return_response", [True, False])
def test_choices_run(sqlite_cache, session_cache, return_response):
    """Exercise ``Manifest.run`` when ``gold_choices`` are supplied.

    Runs against the dummy client with a SQLite cache, with and without a
    full ``Response`` object being returned. Each call is expected to return
    the first choice and to leave a cache entry keyed by the prompt, the
    choices passed to that call, and the client name.
    """
    manifest = Manifest(
        client_name="dummy",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
    )

    def assert_cached(prompt_text, gold_choices):
        """Assert the cache has an entry for this prompt/choices request."""
        request = {
            "prompt": prompt_text,
            "gold_choices": gold_choices,
            "client_name": "dummy",
        }
        assert manifest.cache.get_key(request_to_key(request)) is not None

    # Static prompt.
    prompt = Prompt("This is a prompt")
    # Dummy client will always return first choice
    choices = ["cat", "dog"]
    result = manifest.run(prompt, gold_choices=choices, return_response=return_response)
    if return_response:
        assert isinstance(result, Response)
        res = result.get_response(manifest.stop_token)
    else:
        res = result
    assert_cached("This is a prompt", choices)
    assert res == "cat"

    # Callable prompt formatted with an input value.
    prompt = Prompt(lambda x: f"{x} is a prompt")
    choices = ["cat", "dog"]
    result = manifest.run(
        prompt, "Hello", gold_choices=choices, return_response=return_response
    )
    if return_response:
        assert isinstance(result, Response)
        res = result.get_response(manifest.stop_token)
    else:
        res = result
    assert_cached("Hello is a prompt", choices)
    assert res == "cat"

    # Callable prompt with a stop token; the first choice "callt" is expected
    # to come back as "ca".
    prompt = Prompt(lambda x: f"{x} is a prompt")
    choices = ["callt", "dog"]
    result = manifest.run(
        prompt,
        "Hello",
        gold_choices=choices,
        stop_token="ll",
        return_response=return_response,
    )
    if return_response:
        assert isinstance(result, Response)
        res = result.get_response(stop_token="ll")
    else:
        res = result
    # Look up the key for the choices used in THIS call. The ["cat", "dog"]
    # key for the same prompt was already stored by the previous scenario, so
    # checking it here would pass without this call caching anything.
    # NOTE(review): assumes stop_token is not part of the cache key - confirm.
    assert_cached("Hello is a prompt", choices)
    assert res == "ca"
|
|
|
|
|
|
@pytest.mark.usefixtures("session_cache")
def test_log_query(session_cache):
    """Check that a run is recorded and retrievable via ``get_last_queries``.

    Issues one query through a dummy-client, noop-cache manifest and compares
    the logged history in both its plain and raw forms.
    """
    manifest = Manifest(client_name="dummy", cache_name="noop")
    manifest.run(Prompt("This is a prompt"), return_response=False)

    expected_request = {
        "prompt": "This is a prompt",
        "client_name": "dummy",
        "num_results": 1,
    }
    expected_response = {
        "cached": False,
        "request_params": expected_request,
        "response": {"choices": [{"text": "hello"}]},
    }
    raw_history = [(expected_request, expected_response)]

    # Plain form: (prompt text, response text) pairs.
    assert manifest.get_last_queries(1) == [("This is a prompt", "hello")]
    # Raw form: (request params, response record) pairs.
    assert manifest.get_last_queries(1, return_raw_values=True) == raw_history
    # Asking for more entries than were logged still yields the single query.
    assert manifest.get_last_queries(3, return_raw_values=True) == raw_history
|