mirror of https://github.com/HazyResearch/manifest
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
37 lines
1.1 KiB
Python
37 lines
1.1 KiB
Python
"""
Test client.

We only test the dummy client, so the tests neither load a model nor consume OpenAI tokens.
"""
|
|
from manifest.clients.dummy import DummyClient
|
|
|
|
|
|
def test_init():
    """Test client initialization.

    A DummyClient built without ``client_args`` reports ``n == 1``;
    supplying ``client_args={"n": 3}`` makes it report ``n == 3``.
    """
    default_client = DummyClient(connection_str=None)
    assert default_client.n == 1

    custom_client = DummyClient(connection_str=None, client_args={"n": 3})
    assert custom_client.n == 3
|
|
|
|
|
|
def test_get_params():
    """Test get param functions.

    Checks the model parameters and the list of model inputs that a
    default DummyClient exposes.
    """
    expected_params = {"engine": "dummy"}
    expected_inputs = ["n"]

    dummy = DummyClient(connection_str=None)
    assert dummy.get_model_params() == expected_params
    assert dummy.get_model_inputs() == expected_inputs
|
|
|
|
|
|
def test_get_request():
    """Test client get request.

    The client is created with ``n = 3``. A plain ``get_request`` call uses
    that value, and a per-request ``{"n": 5}`` override replaces it. In both
    cases the returned callable echoes the prompt once per requested result.
    """
    client = DummyClient(connection_str=None, client_args={"n": 3})

    # Each case: (positional args for get_request, expected number of results).
    cases = (
        (("hello",), 3),
        (("hello", {"n": 5}), 5),
    )
    for call_args, expected_n in cases:
        request_fn, params = client.get_request(*call_args)
        assert params == {"prompt": "hello", "num_results": expected_n}
        assert request_fn() == {"choices": [{"text": "hello"}] * expected_n}
|