You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
manifest/tests/test_client.py

74 lines
2.4 KiB
Python

"""
Test client.
We just test the dummy client.
"""
from typing import Any, Dict, List, Union

from manifest.clients.dummy import DummyClient
def test_init() -> None:
    """Check how DummyClient sets ``n`` from its client args.

    With no client args the client is expected to report ``n == 1``; passing
    ``{"n": 3}`` as client args is expected to set ``n`` to 3.
    """
    bare_client = DummyClient(connection_str=None)
    assert bare_client.n == 1  # type: ignore

    configured_client = DummyClient(connection_str=None, client_args={"n": 3})
    assert configured_client.n == 3  # type: ignore
def test_get_params() -> None:
    """Check the model params and model inputs that DummyClient reports."""
    dummy = DummyClient(connection_str=None)
    expected_params = {"engine": "dummy"}
    expected_inputs = ["n"]
    assert dummy.get_model_params() == expected_params
    assert dummy.get_model_inputs() == expected_inputs
def _assert_dummy_request(
    client: "DummyClient",
    prompt: Union[str, List[str]],
    request_args: Dict[str, Any],
    num_results: int,
    num_choices: int,
) -> None:
    """Build and run one dummy request, then check its cache key and response.

    Args:
        client: The dummy client under test.
        prompt: Prompt (or list of prompts) passed to ``get_request``.
        request_args: Per-request arguments passed to ``get_request``.
        num_results: Expected ``num_results`` entry in the cache key.
        num_choices: Expected number of choices and usage records in the response.

    The expected choice text is hard-coded to ``"hello"``, which matches every
    prompt sent by ``test_get_request``.
    """
    request_params = client.get_request(prompt, request_args)
    response = client.run_request(request_params)

    expected_choice = {"text": "hello", "token_logprobs": None, "tokens": None}
    expected_usage = {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}

    assert client.get_cache_key(request_params) == {
        "prompt": prompt,
        "num_results": num_results,
        "engine": "dummy",
        "request_cls": "LMRequest",
    }
    assert response.get_json_response() == {
        "choices": [expected_choice] * num_choices,
    }
    assert response.get_usage_obj().dict() == {
        "usages": [expected_usage] * num_choices,
    }


def test_get_request() -> None:
    """Test client get request.

    A single client created with ``n=3`` is exercised with a single prompt and no
    request args, a single prompt with ``n=5``, and a batch of five prompts with
    ``n=1``.
    """
    args = {"n": 3}
    client = DummyClient(connection_str=None, client_args=args)

    # No per-request args: the cache key is expected to carry the client's n=3.
    _assert_dummy_request(client, "hello", {}, num_results=3, num_choices=3)

    # A per-request n=5 is expected to take precedence over the client's n=3.
    _assert_dummy_request(client, "hello", {"n": 5}, num_results=5, num_choices=5)

    # Batched prompts with n=1: one choice (and usage record) per prompt.
    _assert_dummy_request(
        client, ["hello"] * 5, {"n": 1}, num_results=1, num_choices=5
    )