|
|
|
@ -1,5 +1,6 @@
|
|
|
|
|
"""Manifest test."""
|
|
|
|
|
import json
|
|
|
|
|
from typing import cast
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
@ -12,7 +13,7 @@ from manifest.session import Session
|
|
|
|
|
|
|
|
|
|
@pytest.mark.usefixtures("sqlite_cache")
|
|
|
|
|
@pytest.mark.usefixtures("session_cache")
|
|
|
|
|
def test_init(sqlite_cache, session_cache):
|
|
|
|
|
def test_init(sqlite_cache: str, session_cache: str) -> None:
|
|
|
|
|
"""Test manifest initialization."""
|
|
|
|
|
with pytest.raises(ValueError) as exc_info:
|
|
|
|
|
Manifest(
|
|
|
|
@ -32,7 +33,7 @@ def test_init(sqlite_cache, session_cache):
|
|
|
|
|
assert isinstance(manifest.client, DummyClient)
|
|
|
|
|
assert isinstance(manifest.cache, SQLiteCache)
|
|
|
|
|
assert manifest.session is None
|
|
|
|
|
assert manifest.client.n == 1
|
|
|
|
|
assert manifest.client.n == 1 # type: ignore
|
|
|
|
|
assert manifest.stop_token == ""
|
|
|
|
|
|
|
|
|
|
manifest = Manifest(
|
|
|
|
@ -46,7 +47,34 @@ def test_init(sqlite_cache, session_cache):
|
|
|
|
|
assert isinstance(manifest.client, DummyClient)
|
|
|
|
|
assert isinstance(manifest.cache, NoopCache)
|
|
|
|
|
assert isinstance(manifest.session, Session)
|
|
|
|
|
assert manifest.client.n == 3
|
|
|
|
|
assert manifest.client.n == 3 # type: ignore
|
|
|
|
|
assert manifest.stop_token == "\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.usefixtures("sqlite_cache")
@pytest.mark.usefixtures("session_cache")
def test_change_manifest(sqlite_cache: str, session_cache: str) -> None:
    """Verify the manifest state after successive ``change_client`` calls.

    A manifest is built with the dummy client and a sqlite cache, then
    ``change_client`` is invoked twice: first with no arguments and then
    with a newline stop token. After each call the client name, client
    type, cache type, session and ``n`` are asserted to be the same; only
    the expected stop token differs between the two checks.
    """
    manifest = Manifest(
        client_name="dummy",
        cache_name="sqlite",
        cache_connection=sqlite_cache,
    )

    def _assert_state(expected_stop_token: str) -> None:
        # Shared assertions: everything but the stop token is expected to
        # be identical after each change_client call.
        assert manifest.client_name == "dummy"
        assert isinstance(manifest.client, DummyClient)
        assert isinstance(manifest.cache, SQLiteCache)
        assert manifest.session is None
        assert manifest.client.n == 1  # type: ignore
        assert manifest.stop_token == expected_stop_token

    # Changing the client without arguments leaves the stop token empty.
    manifest.change_client()
    _assert_state("")

    # Changing the client with an explicit stop token updates only that field.
    manifest.change_client(stop_token="\n")
    _assert_state("\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -54,7 +82,9 @@ def test_init(sqlite_cache, session_cache):
|
|
|
|
|
@pytest.mark.usefixtures("session_cache")
|
|
|
|
|
@pytest.mark.parametrize("n", [1, 2])
|
|
|
|
|
@pytest.mark.parametrize("return_response", [True, False])
|
|
|
|
|
def test_run(sqlite_cache, session_cache, n, return_response):
|
|
|
|
|
def test_run(
|
|
|
|
|
sqlite_cache: str, session_cache: str, n: int, return_response: bool
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Test manifest run."""
|
|
|
|
|
manifest = Manifest(
|
|
|
|
|
client_name="dummy",
|
|
|
|
@ -77,9 +107,9 @@ def test_run(sqlite_cache, session_cache, n, return_response):
|
|
|
|
|
result = manifest.run(prompt, return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
res = result.get_response(manifest.stop_token)
|
|
|
|
|
res = cast(Response, result).get_response(manifest.stop_token)
|
|
|
|
|
else:
|
|
|
|
|
res = result
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert (
|
|
|
|
|
manifest.cache.get_key(
|
|
|
|
|
json.dumps(
|
|
|
|
@ -102,9 +132,9 @@ def test_run(sqlite_cache, session_cache, n, return_response):
|
|
|
|
|
result = manifest.run(prompt, run_id="34", return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
res = result.get_response(manifest.stop_token)
|
|
|
|
|
res = cast(Response, result).get_response(manifest.stop_token)
|
|
|
|
|
else:
|
|
|
|
|
res = result
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert (
|
|
|
|
|
manifest.cache.get_key(
|
|
|
|
|
json.dumps(
|
|
|
|
@ -128,9 +158,9 @@ def test_run(sqlite_cache, session_cache, n, return_response):
|
|
|
|
|
result = manifest.run(prompt, return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
res = result.get_response(manifest.stop_token)
|
|
|
|
|
res = cast(Response, result).get_response(manifest.stop_token)
|
|
|
|
|
else:
|
|
|
|
|
res = result
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert (
|
|
|
|
|
manifest.cache.get_key(
|
|
|
|
|
json.dumps(
|
|
|
|
@ -153,9 +183,9 @@ def test_run(sqlite_cache, session_cache, n, return_response):
|
|
|
|
|
result = manifest.run(prompt, stop_token="ll", return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
res = result.get_response(stop_token="ll")
|
|
|
|
|
res = cast(Response, result).get_response(stop_token="ll")
|
|
|
|
|
else:
|
|
|
|
|
res = result
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert (
|
|
|
|
|
manifest.cache.get_key(
|
|
|
|
|
json.dumps(
|
|
|
|
@ -179,7 +209,9 @@ def test_run(sqlite_cache, session_cache, n, return_response):
|
|
|
|
|
@pytest.mark.usefixtures("session_cache")
|
|
|
|
|
@pytest.mark.parametrize("n", [1, 2])
|
|
|
|
|
@pytest.mark.parametrize("return_response", [True, False])
|
|
|
|
|
def test_batch_run(sqlite_cache, session_cache, n, return_response):
|
|
|
|
|
def test_batch_run(
|
|
|
|
|
sqlite_cache: str, session_cache: str, n: int, return_response: bool
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Test manifest run."""
|
|
|
|
|
manifest = Manifest(
|
|
|
|
|
client_name="dummy",
|
|
|
|
@ -195,32 +227,38 @@ def test_batch_run(sqlite_cache, session_cache, n, return_response):
|
|
|
|
|
else:
|
|
|
|
|
result = manifest.run(prompt, return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
res = result.get_response(manifest.stop_token, is_batch=True)
|
|
|
|
|
res = cast(Response, result).get_response(
|
|
|
|
|
manifest.stop_token, is_batch=True
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
res = result
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert res == ["hello"]
|
|
|
|
|
|
|
|
|
|
prompt = ["Hello is a prompt", "Hello is a prompt"]
|
|
|
|
|
result = manifest.run(prompt, return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
res = result.get_response(manifest.stop_token, is_batch=True)
|
|
|
|
|
res = cast(Response, result).get_response(
|
|
|
|
|
manifest.stop_token, is_batch=True
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
res = result
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert res == ["hello", "hello"]
|
|
|
|
|
|
|
|
|
|
prompt = ["Hello is a prompt", "Hello is a prompt"]
|
|
|
|
|
result = manifest.run(prompt, stop_token="ll", return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
res = result.get_response(stop_token="ll", is_batch=True)
|
|
|
|
|
res = cast(Response, result).get_response(stop_token="ll", is_batch=True)
|
|
|
|
|
else:
|
|
|
|
|
res = result
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert res == ["he", "he"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.usefixtures("sqlite_cache")
|
|
|
|
|
@pytest.mark.usefixtures("session_cache")
|
|
|
|
|
@pytest.mark.parametrize("return_response", [True, False])
|
|
|
|
|
def test_choices_run(sqlite_cache, session_cache, return_response):
|
|
|
|
|
def test_choices_run(
|
|
|
|
|
sqlite_cache: str, session_cache: str, return_response: bool
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Test manifest run."""
|
|
|
|
|
manifest = Manifest(
|
|
|
|
|
client_name="dummy",
|
|
|
|
@ -234,9 +272,9 @@ def test_choices_run(sqlite_cache, session_cache, return_response):
|
|
|
|
|
result = manifest.run(prompt, gold_choices=choices, return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
res = result.get_response(manifest.stop_token)
|
|
|
|
|
res = cast(Response, result).get_response(manifest.stop_token)
|
|
|
|
|
else:
|
|
|
|
|
res = result
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert (
|
|
|
|
|
manifest.cache.get_key(
|
|
|
|
|
json.dumps(
|
|
|
|
@ -257,9 +295,9 @@ def test_choices_run(sqlite_cache, session_cache, return_response):
|
|
|
|
|
result = manifest.run(prompt, gold_choices=choices, return_response=return_response)
|
|
|
|
|
if return_response:
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
res = result.get_response(manifest.stop_token)
|
|
|
|
|
res = cast(Response, result).get_response(manifest.stop_token)
|
|
|
|
|
else:
|
|
|
|
|
res = result
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert (
|
|
|
|
|
manifest.cache.get_key(
|
|
|
|
|
json.dumps(
|
|
|
|
@ -285,9 +323,9 @@ def test_choices_run(sqlite_cache, session_cache, return_response):
|
|
|
|
|
)
|
|
|
|
|
if return_response:
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
res = result.get_response(stop_token="ll")
|
|
|
|
|
res = cast(Response, result).get_response(stop_token="ll")
|
|
|
|
|
else:
|
|
|
|
|
res = result
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert (
|
|
|
|
|
manifest.cache.get_key(
|
|
|
|
|
json.dumps(
|
|
|
|
@ -303,19 +341,19 @@ def test_choices_run(sqlite_cache, session_cache, return_response):
|
|
|
|
|
)
|
|
|
|
|
assert res == "ca"
|
|
|
|
|
|
|
|
|
|
prompt = ["Hello is a prompt", "Hello is a prompt"]
|
|
|
|
|
prompt_lst = ["Hello is a prompt", "Hello is a prompt"]
|
|
|
|
|
choices = ["callt", "dog"]
|
|
|
|
|
result = manifest.run(
|
|
|
|
|
prompt,
|
|
|
|
|
prompt_lst,
|
|
|
|
|
gold_choices=choices,
|
|
|
|
|
stop_token="ll",
|
|
|
|
|
return_response=return_response,
|
|
|
|
|
)
|
|
|
|
|
if return_response:
|
|
|
|
|
assert isinstance(result, Response)
|
|
|
|
|
res = result.get_response(stop_token="ll", is_batch=True)
|
|
|
|
|
res = cast(Response, result).get_response(stop_token="ll", is_batch=True)
|
|
|
|
|
else:
|
|
|
|
|
res = result
|
|
|
|
|
res = cast(str, result)
|
|
|
|
|
assert (
|
|
|
|
|
manifest.cache.get_key(
|
|
|
|
|
json.dumps(
|
|
|
|
@ -333,7 +371,7 @@ def test_choices_run(sqlite_cache, session_cache, return_response):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.usefixtures("session_cache")
|
|
|
|
|
def test_log_query(session_cache):
|
|
|
|
|
def test_log_query(session_cache: str) -> None:
|
|
|
|
|
"""Test manifest session logging."""
|
|
|
|
|
manifest = Manifest(client_name="dummy", cache_name="noop", session_id="_default")
|
|
|
|
|
prompt = "This is a prompt"
|
|
|
|
@ -361,8 +399,8 @@ def test_log_query(session_cache):
|
|
|
|
|
]
|
|
|
|
|
prior_cache_item = (query_key, response_key)
|
|
|
|
|
|
|
|
|
|
prompt = ["This is a prompt", "This is a prompt2"]
|
|
|
|
|
_ = manifest.run(prompt, return_response=False)
|
|
|
|
|
prompt_lst = ["This is a prompt", "This is a prompt2"]
|
|
|
|
|
_ = manifest.run(prompt_lst, return_response=False)
|
|
|
|
|
query_key = {
|
|
|
|
|
"prompt": ["This is a prompt", "This is a prompt2"],
|
|
|
|
|
"engine": "dummy",
|
|
|
|
|