fix: dummy client to output tokens and random responses (#106)

pull/109/head
Laurel Orr 11 months ago committed by GitHub
parent b775d15f2e
commit 49f51952df

@ -1,6 +1,10 @@
"""Dummy client.""" """Dummy client."""
import hashlib
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import tiktoken
from manifest.clients.client import Client from manifest.clients.client import Client
from manifest.request import LMChatRequest, LMRequest, LMScoreRequest, Request from manifest.request import LMChatRequest, LMRequest, LMScoreRequest, Request
@ -14,7 +18,13 @@ class DummyClient(Client):
# User param -> (client param, default value) # User param -> (client param, default value)
PARAMS = { PARAMS = {
"n": ("num_results", 1), "engine": ("model", "text-davinci-003"),
"temperature": ("temperature", 0.0),
"max_tokens": ("max_tokens", 10),
"n": ("n", 1),
"top_p": ("top_p", 1.0),
"top_k": ("best_of", 1),
"batch_size": ("batch_size", 20),
} }
REQUEST_CLS = LMRequest REQUEST_CLS = LMRequest
NAME = "dummy" NAME = "dummy"
@ -33,6 +43,9 @@ class DummyClient(Client):
connection_str: connection string. connection_str: connection string.
client_args: client arguments. client_args: client arguments.
""" """
# We use tiktoken as it is faster than HF for tokenizing
# Use any model to create the tokenizer
self.encoder = tiktoken.get_encoding("cl100k_base")
for key in self.PARAMS: for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1])) setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
@ -74,7 +87,65 @@ class DummyClient(Client):
Returns: Returns:
model params. model params.
""" """
return {"engine": "dummy"} return {"engine": "dummy", "model": getattr(self, "engine")}
def get_mock_output(
self, output_toks: int, is_completion: bool, seed: Optional[int] = None
) -> LMModelChoice:
"""Return mock model output by generating random tokens."""
np.random.seed(seed)
random_tokens = np.random.randint(
0, self.encoder.max_token_value + 1, output_toks
)
response = self.encoder.decode(random_tokens) # type: ignore
if is_completion:
np.random.seed(seed)
random_logprobs = np.random.uniform(
low=-2, high=-0.00001, size=output_toks
).tolist()
else:
# Return all Nones to mimic chat models
# OpenAI chat models do not return logprobs
random_logprobs = [None] * output_toks
return LMModelChoice(
text=response,
token_logprobs=random_logprobs,
tokens=random_tokens.tolist(),
)
def get_mock_choices(
self,
prompt_list: List[str],
request_params: Dict,
is_completion: bool,
) -> Tuple[List[LMModelChoice], List[Usage]]:
"""Get choices and usages of mock output."""
choices = []
usages = []
for prompt in prompt_list:
num_prompt_tokens = len(self.encoder.encode(prompt))
if request_params["temperature"] == 0:
# Get integer seed from hash of prompt
seed = (
int(hashlib.sha256(prompt.encode("utf-8")).hexdigest(), 16)
% 10**8
)
else:
# Get random seed
seed = None
for _ in range(int(request_params["n"])):
choice = self.get_mock_output(
request_params["max_tokens"], is_completion=is_completion, seed=seed
)
choices.append(choice)
usages.append(
Usage(
prompt_tokens=num_prompt_tokens,
completion_tokens=request_params["max_tokens"],
total_tokens=num_prompt_tokens + request_params["max_tokens"],
)
)
return choices, usages
def run_request(self, request: Request) -> Response: def run_request(self, request: Request) -> Response:
""" """
@ -88,32 +159,19 @@ class DummyClient(Client):
request parameters as dict. request parameters as dict.
""" """
if isinstance(request.prompt, list): if isinstance(request.prompt, list):
num_results = len(request.prompt) prompt_list = request.prompt
else: else:
num_results = 1 prompt_list = [request.prompt]
request_params = request.to_dict(self.PARAMS) request_params = request.to_dict(self.PARAMS)
choices, usages = self.get_mock_choices(
prompt_list, request_params, is_completion=True
)
return Response( return Response(
response=ModelChoices( response=ModelChoices(choices=choices), # type: ignore
choices=[LMModelChoice(text="hello")] # type: ignore
* int(request_params["num_results"])
* num_results
),
cached=False, cached=False,
request=request, request=request,
usages=Usages( usages=Usages(usages=usages),
usages=[
Usage(
**{
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,
}
)
]
* int(request_params["num_results"])
* num_results
),
response_type="text", response_type="text",
request_type=self.REQUEST_CLS, request_type=self.REQUEST_CLS,
) )
@ -145,35 +203,17 @@ class DummyClient(Client):
Returns: Returns:
response. response.
""" """
num_results = 1 prompt_list = ["_".join(pmp["content"] for pmp in request.prompt)]
response_dict = { request_params = request.to_dict(self.PARAMS)
"choices": [
{ choices, usages = self.get_mock_choices(
"text": request.prompt[0]["content"], prompt_list, request_params, is_completion=False
} )
for i in range(num_results)
]
}
return Response( return Response(
response=ModelChoices( response=ModelChoices(choices=choices), # type: ignore
choices=[
LMModelChoice(**choice) # type: ignore
for choice in response_dict["choices"]
]
),
cached=False, cached=False,
request=request, request=request,
usages=Usages( usages=Usages(usages=usages),
usages=[
Usage(
**{
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,
}
)
]
),
response_type="text", response_type="text",
request_type=LMChatRequest, request_type=LMChatRequest,
) )
@ -193,30 +233,19 @@ class DummyClient(Client):
request parameters as dict. request parameters as dict.
""" """
if isinstance(request.prompt, list): if isinstance(request.prompt, list):
num_results = len(request.prompt) prompt_list = request.prompt
else: else:
num_results = 1 prompt_list = [request.prompt]
response_dict = { request_params = request.to_dict(self.PARAMS)
"choices": [
{ choices, usages = self.get_mock_choices(
"text": request.prompt prompt_list, request_params, is_completion=True
if isinstance(request.prompt, str) )
else request.prompt[i],
"token_logprobs": [0.3],
}
for i in range(num_results)
]
}
return Response( return Response(
response=ModelChoices( response=ModelChoices(choices=choices), # type: ignore
choices=[
LMModelChoice(**choice) # type: ignore
for choice in response_dict["choices"]
]
),
cached=False, cached=False,
request=request, request=request,
usages=None, usages=Usages(usages=usages),
response_type="text", response_type="text",
request_type=LMScoreRequest, request_type=LMScoreRequest,
) )
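
The new mock generation is deterministic at temperature 0: the prompt is hashed into an integer seed, and that seed drives both the random token draw and the random logprobs, so repeated identical requests reproduce the same "completion". Below is a minimal standalone sketch of that scheme, mirroring the committed `get_mock_output` / `get_mock_choices` helpers; the function name is illustrative and it assumes numpy and tiktoken are installed.

```python
# Minimal sketch of the deterministic mock-generation scheme added above.
# Assumes numpy and tiktoken are available; mock_completion is an illustrative name.
import hashlib

import numpy as np
import tiktoken

encoder = tiktoken.get_encoding("cl100k_base")


def mock_completion(prompt: str, max_tokens: int = 10, temperature: float = 0.0):
    """Return (text, tokens, logprobs) for a fake completion."""
    if temperature == 0:
        # Same prompt -> same seed -> identical output on every call.
        seed = int(hashlib.sha256(prompt.encode("utf-8")).hexdigest(), 16) % 10**8
    else:
        seed = None  # non-zero temperature falls back to a random seed
    np.random.seed(seed)
    tokens = np.random.randint(0, encoder.max_token_value + 1, max_tokens)
    np.random.seed(seed)
    logprobs = np.random.uniform(low=-2, high=-0.00001, size=max_tokens).tolist()
    return encoder.decode(tokens.tolist()), tokens.tolist(), logprobs


text, toks, lps = mock_completion("This is a prompt")
assert mock_completion("This is a prompt")[0] == text  # reproducible at temperature 0
```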

@ -53,7 +53,7 @@ class LMModelChoice(BaseModel):
"""Model single completion.""" """Model single completion."""
text: str text: str
token_logprobs: Optional[List[float]] = None token_logprobs: Optional[List[Optional[float]]] = None
tokens: Optional[List[str]] = None tokens: Optional[List[str]] = None
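
The `token_logprobs` type is loosened because the chat path now returns `None` per token (OpenAI chat models do not expose logprobs). A quick standalone pydantic check of the relaxed field, reproduced here only for illustration:

```python
# Standalone illustration of the relaxed typing; the model mirrors the diff above.
from typing import List, Optional

from pydantic import BaseModel


class LMModelChoice(BaseModel):
    """Model single completion."""

    text: str
    token_logprobs: Optional[List[Optional[float]]] = None
    tokens: Optional[List[str]] = None


# Chat-style output: per-token logprobs are unknown, so each entry is None.
LMModelChoice(text="hello", token_logprobs=[None, None], tokens=["15339", "0"])
```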

@ -19,8 +19,19 @@ def test_init() -> None:
def test_get_params() -> None: def test_get_params() -> None:
"""Test get param functions.""" """Test get param functions."""
client = DummyClient(connection_str=None) client = DummyClient(connection_str=None)
assert client.get_model_params() == {"engine": "dummy"} assert client.get_model_params() == {
assert client.get_model_inputs() == ["n"] "engine": "dummy",
"model": "text-davinci-003",
}
assert client.get_model_inputs() == [
"engine",
"temperature",
"max_tokens",
"n",
"top_p",
"top_k",
"batch_size",
]
def test_get_request() -> None: def test_get_request() -> None:
@ -31,43 +42,148 @@ def test_get_request() -> None:
response = client.run_request(request_params) response = client.run_request(request_params)
assert client.get_cache_key(request_params) == { assert client.get_cache_key(request_params) == {
"prompt": "hello", "prompt": "hello",
"num_results": 3, "model": "text-davinci-003",
"n": 3,
"temperature": 0.0,
"max_tokens": 10,
"top_p": 1.0,
"best_of": 1,
"engine": "dummy", "engine": "dummy",
"request_cls": "LMRequest", "request_cls": "LMRequest",
} }
assert response.get_json_response() == { assert response.get_json_response() == {
"choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 3, "choices": [
{
"text": " probsuib.FirstName>- commodityting segunda inserted signals Religious", # noqa: E501
"token_logprobs": [
-0.2649905035732101,
-1.210794839387105,
-1.2173929801003434,
-0.7758233850171001,
-0.7165940659570416,
-1.7430328887209088,
-1.5379414228820203,
-1.7838011423472508,
-1.139095076944217,
-0.6321855879833425,
],
"tokens": [
"70470",
"80723",
"52693",
"39743",
"38983",
"1303",
"56072",
"22306",
"17738",
"53176",
],
}
]
* 3
} }
assert response.get_usage_obj().dict() == { assert response.get_usage_obj().dict() == {
"usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 3, "usages": [{"prompt_tokens": 1, "completion_tokens": 10, "total_tokens": 11}]
* 3,
} }
request_params = client.get_request("hello", {"n": 5}) request_params = client.get_request("hello", {"n": 5})
response = client.run_request(request_params) response = client.run_request(request_params)
assert client.get_cache_key(request_params) == { assert client.get_cache_key(request_params) == {
"prompt": "hello", "prompt": "hello",
"num_results": 5, "model": "text-davinci-003",
"n": 5,
"temperature": 0.0,
"max_tokens": 10,
"top_p": 1.0,
"best_of": 1,
"engine": "dummy", "engine": "dummy",
"request_cls": "LMRequest", "request_cls": "LMRequest",
} }
assert response.get_json_response() == { assert response.get_json_response() == {
"choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 5, "choices": [
{
"text": " probsuib.FirstName>- commodityting segunda inserted signals Religious", # noqa: E501
"token_logprobs": [
-0.2649905035732101,
-1.210794839387105,
-1.2173929801003434,
-0.7758233850171001,
-0.7165940659570416,
-1.7430328887209088,
-1.5379414228820203,
-1.7838011423472508,
-1.139095076944217,
-0.6321855879833425,
],
"tokens": [
"70470",
"80723",
"52693",
"39743",
"38983",
"1303",
"56072",
"22306",
"17738",
"53176",
],
}
]
* 5
} }
assert response.get_usage_obj().dict() == { assert response.get_usage_obj().dict() == {
"usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5, "usages": [{"prompt_tokens": 1, "completion_tokens": 10, "total_tokens": 11}]
* 5,
} }
request_params = client.get_request(["hello"] * 5, {"n": 1}) request_params = client.get_request(["hello"] * 5, {"n": 1})
response = client.run_request(request_params) response = client.run_request(request_params)
assert client.get_cache_key(request_params) == { assert client.get_cache_key(request_params) == {
"prompt": ["hello"] * 5, "prompt": ["hello"] * 5,
"num_results": 1, "model": "text-davinci-003",
"n": 1,
"temperature": 0.0,
"max_tokens": 10,
"top_p": 1.0,
"best_of": 1,
"engine": "dummy", "engine": "dummy",
"request_cls": "LMRequest", "request_cls": "LMRequest",
} }
assert response.get_json_response() == { assert response.get_json_response() == {
"choices": [{"text": "hello", "token_logprobs": None, "tokens": None}] * 5, "choices": [
{
"text": " probsuib.FirstName>- commodityting segunda inserted signals Religious", # noqa: E501
"token_logprobs": [
-0.2649905035732101,
-1.210794839387105,
-1.2173929801003434,
-0.7758233850171001,
-0.7165940659570416,
-1.7430328887209088,
-1.5379414228820203,
-1.7838011423472508,
-1.139095076944217,
-0.6321855879833425,
],
"tokens": [
"70470",
"80723",
"52693",
"39743",
"38983",
"1303",
"56072",
"22306",
"17738",
"53176",
],
}
]
* 5
} }
assert response.get_usage_obj().dict() == { assert response.get_usage_obj().dict() == {
"usages": [{"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}] * 5, "usages": [{"prompt_tokens": 1, "completion_tokens": 10, "total_tokens": 11}]
* 5,
} }
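
The cache-key assertions above show the practical effect of the expanded PARAMS map: every generation parameter and its default (model, n, temperature, max_tokens, top_p, best_of) is folded into the key, so requests that differ only in, say, `max_tokens` no longer collide in the cache. A hypothetical comparison, using the key layout taken from the assertions rather than computed here:

```python
# Hypothetical illustration of why the expanded cache keys matter; the literal
# key layout is copied from the test assertions above, not derived here.
base_key = {
    "prompt": "hello",
    "model": "text-davinci-003",
    "n": 1,
    "temperature": 0.0,
    "max_tokens": 10,
    "top_p": 1.0,
    "best_of": 1,
    "engine": "dummy",
    "request_cls": "LMRequest",
}
longer_key = {**base_key, "max_tokens": 50}
assert base_key != longer_key  # distinct cache entries under the new scheme
```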

@ -73,6 +73,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
cache_name="sqlite", cache_name="sqlite",
cache_connection=sqlite_cache, cache_connection=sqlite_cache,
n=n, n=n,
temperature=0.0,
) )
prompt = "This is a prompt" prompt = "This is a prompt"
@ -80,8 +81,6 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
result = manifest.run(prompt, return_response=return_response, bad_input=5) result = manifest.run(prompt, return_response=return_response, bad_input=5)
assert str(exc_info.value) == "[('bad_input', 5)] arguments are not recognized." assert str(exc_info.value) == "[('bad_input', 5)] arguments are not recognized."
# Allow params in the request object but not in the client to go through
assert "top_k" not in manifest.client_pool.get_next_client().PARAMS
result = manifest.run(prompt, return_response=return_response, top_k=5) result = manifest.run(prompt, return_response=return_response, top_k=5)
assert result is not None assert result is not None
@ -96,21 +95,30 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
res = result.get_response(manifest.stop_token) res = result.get_response(manifest.stop_token)
else: else:
res = cast(str, result) res = cast(str, result)
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"prompt": "This is a prompt", "best_of": 1,
"engine": "dummy", "engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": n,
"prompt": "This is a prompt",
"request_cls": "LMRequest", "request_cls": "LMRequest",
"num_results": n, "temperature": 0.0,
}, "top_p": 1.0,
}
) )
is not None is not None
) )
if n == 1: if n == 1:
assert res == "hello" assert res == "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines"
else: else:
assert res == ["hello", "hello"] assert res == [
"Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
"Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
]
prompt = "This is a prompt" prompt = "This is a prompt"
result = manifest.run(prompt, run_id="34", return_response=return_response) result = manifest.run(prompt, run_id="34", return_response=return_response)
@ -126,19 +134,27 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"prompt": "This is a prompt", "best_of": 1,
"engine": "dummy", "engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": n,
"prompt": "This is a prompt",
"request_cls": "LMRequest", "request_cls": "LMRequest",
"num_results": n, "temperature": 0.0,
"top_p": 1.0,
"run_id": "34", "run_id": "34",
} }
) )
is not None is not None
) )
if n == 1: if n == 1:
assert res == "hello" assert res == "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines"
else: else:
assert res == ["hello", "hello"] assert res == [
"Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
"Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
]
prompt = "Hello is a prompt" prompt = "Hello is a prompt"
result = manifest.run(prompt, return_response=return_response) result = manifest.run(prompt, return_response=return_response)
@ -154,45 +170,60 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"prompt": "Hello is a prompt", "best_of": 1,
"engine": "dummy", "engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": n,
"prompt": "Hello is a prompt",
"request_cls": "LMRequest", "request_cls": "LMRequest",
"num_results": n, "temperature": 0.0,
}, "top_p": 1.0,
}
) )
is not None is not None
) )
if n == 1: if n == 1:
assert res == "hello" assert res == "appersstoff210 currentNodeleh norm unified_voice DIYHam"
else: else:
assert res == ["hello", "hello"] assert res == [
"appersstoff210 currentNodeleh norm unified_voice DIYHam",
"appersstoff210 currentNodeleh norm unified_voice DIYHam",
]
prompt = "Hello is a prompt" prompt = "Hello is a prompt"
result = manifest.run(prompt, stop_token="ll", return_response=return_response) result = manifest.run(
prompt, stop_token=" current", return_response=return_response
)
if return_response: if return_response:
assert isinstance(result, Response) assert isinstance(result, Response)
result = cast(Response, result) result = cast(Response, result)
assert len(result.get_usage_obj().usages) == len( assert len(result.get_usage_obj().usages) == len(
result.get_response_obj().choices result.get_response_obj().choices
) )
res = result.get_response(stop_token="ll") res = result.get_response(stop_token=" current")
else: else:
res = cast(str, result) res = cast(str, result)
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"prompt": "Hello is a prompt", "best_of": 1,
"engine": "dummy", "engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": n,
"prompt": "Hello is a prompt",
"request_cls": "LMRequest", "request_cls": "LMRequest",
"num_results": n, "temperature": 0.0,
}, "top_p": 1.0,
}
) )
is not None is not None
) )
if n == 1: if n == 1:
assert res == "he" assert res == "appersstoff210"
else: else:
assert res == ["he", "he"] assert res == ["appersstoff210", "appersstoff210"]
@pytest.mark.usefixtures("sqlite_cache") @pytest.mark.usefixtures("sqlite_cache")
@ -205,6 +236,7 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
cache_name="sqlite", cache_name="sqlite",
cache_connection=sqlite_cache, cache_connection=sqlite_cache,
n=n, n=n,
temperature=0.0,
) )
prompt = ["This is a prompt"] prompt = ["This is a prompt"]
if n == 2: if n == 2:
@ -222,15 +254,20 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
res = result.get_response(manifest.stop_token, is_batch=True) res = result.get_response(manifest.stop_token, is_batch=True)
else: else:
res = cast(str, result) res = cast(str, result)
assert res == ["hello"] assert res == ["Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines"]
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"prompt": "This is a prompt", "best_of": 1,
"engine": "dummy", "engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": n,
"prompt": "This is a prompt",
"request_cls": "LMRequest", "request_cls": "LMRequest",
"num_results": n, "temperature": 0.0,
}, "top_p": 1.0,
}
) )
is not None is not None
) )
@ -246,15 +283,23 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
res = result.get_response(manifest.stop_token, is_batch=True) res = result.get_response(manifest.stop_token, is_batch=True)
else: else:
res = cast(str, result) res = cast(str, result)
assert res == ["hello", "hello"] assert res == [
"appersstoff210 currentNodeleh norm unified_voice DIYHam",
"appersstoff210 currentNodeleh norm unified_voice DIYHam",
]
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"prompt": "Hello is a prompt", "best_of": 1,
"engine": "dummy", "engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": n,
"prompt": "Hello is a prompt",
"request_cls": "LMRequest", "request_cls": "LMRequest",
"num_results": n, "temperature": 0.0,
}, "top_p": 1.0,
}
) )
is not None is not None
) )
@ -266,11 +311,16 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"prompt": "New prompt", "best_of": 1,
"engine": "dummy", "engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": n,
"prompt": "New prompt",
"request_cls": "LMRequest", "request_cls": "LMRequest",
"num_results": n, "temperature": 0.0,
}, "top_p": 1.0,
}
) )
is None is None
) )
@ -287,20 +337,25 @@ def test_batch_run(sqlite_cache: str, n: int, return_response: bool) -> None:
assert result.is_cached() assert result.is_cached()
else: else:
res = cast(str, result) res = cast(str, result)
assert res == ["hello", "hello"] assert res == [
"Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
".vol.deserializebigmnchantment ROTıl='')\najsС",
]
prompt = ["Hello is a prompt", "Hello is a prompt"] prompt = ["Hello is a prompt", "Hello is a prompt"]
result = manifest.run(prompt, stop_token="ll", return_response=return_response) result = manifest.run(
prompt, stop_token=" current", return_response=return_response
)
if return_response: if return_response:
assert isinstance(result, Response) assert isinstance(result, Response)
result = cast(Response, result) result = cast(Response, result)
assert len(result.get_usage_obj().usages) == len( assert len(result.get_usage_obj().usages) == len(
result.get_response_obj().choices result.get_response_obj().choices
) )
res = result.get_response(stop_token="ll", is_batch=True) res = result.get_response(stop_token=" current", is_batch=True)
else: else:
res = cast(str, result) res = cast(str, result)
assert res == ["he", "he"] assert res == ["appersstoff210", "appersstoff210"]
@pytest.mark.usefixtures("sqlite_cache") @pytest.mark.usefixtures("sqlite_cache")
@ -310,6 +365,7 @@ def test_abatch_run(sqlite_cache: str) -> None:
client_name="dummy", client_name="dummy",
cache_name="sqlite", cache_name="sqlite",
cache_connection=sqlite_cache, cache_connection=sqlite_cache,
temperature=0.0,
) )
prompt = ["This is a prompt"] prompt = ["This is a prompt"]
result = cast( result = cast(
@ -318,15 +374,20 @@ def test_abatch_run(sqlite_cache: str) -> None:
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices) assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
res = result.get_response(manifest.stop_token, is_batch=True) res = result.get_response(manifest.stop_token, is_batch=True)
assert res == ["hello"] assert res == ["Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines"]
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"prompt": "This is a prompt", "best_of": 1,
"engine": "dummy", "engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": "This is a prompt",
"request_cls": "LMRequest", "request_cls": "LMRequest",
"num_results": 1, "temperature": 0.0,
}, "top_p": 1.0,
}
) )
is not None is not None
) )
@ -338,15 +399,23 @@ def test_abatch_run(sqlite_cache: str) -> None:
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices) assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
res = result.get_response(manifest.stop_token, is_batch=True) res = result.get_response(manifest.stop_token, is_batch=True)
assert res == ["hello", "hello"] assert res == [
"appersstoff210 currentNodeleh norm unified_voice DIYHam",
"appersstoff210 currentNodeleh norm unified_voice DIYHam",
]
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"prompt": "Hello is a prompt", "best_of": 1,
"engine": "dummy", "engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": "Hello is a prompt",
"request_cls": "LMRequest", "request_cls": "LMRequest",
"num_results": 1, "temperature": 0.0,
}, "top_p": 1.0,
}
) )
is not None is not None
) )
@ -362,11 +431,16 @@ def test_abatch_run(sqlite_cache: str) -> None:
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"prompt": "New prompt", "best_of": 1,
"engine": "dummy", "engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": "New prompt",
"request_cls": "LMRequest", "request_cls": "LMRequest",
"num_results": 1, "temperature": 0.0,
}, "top_p": 1.0,
}
) )
is None is None
) )
@ -379,7 +453,10 @@ def test_abatch_run(sqlite_cache: str) -> None:
res = result.get_response(manifest.stop_token, is_batch=True) res = result.get_response(manifest.stop_token, is_batch=True)
# Cached because one item is in cache # Cached because one item is in cache
assert result.is_cached() assert result.is_cached()
assert res == ["hello", "hello"] assert res == [
"Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
".vol.deserializebigmnchantment ROTıl='')\najsС",
]
prompt = ["Hello is a prompt", "Hello is a prompt"] prompt = ["Hello is a prompt", "Hello is a prompt"]
result = cast( result = cast(
@ -387,8 +464,8 @@ def test_abatch_run(sqlite_cache: str) -> None:
) )
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices) assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
res = result.get_response(stop_token="ll", is_batch=True) res = result.get_response(stop_token=" current", is_batch=True)
assert res == ["he", "he"] assert res == ["appersstoff210", "appersstoff210"]
@pytest.mark.usefixtures("sqlite_cache") @pytest.mark.usefixtures("sqlite_cache")
@ -398,6 +475,7 @@ def test_run_chat(sqlite_cache: str) -> None:
client_name="dummy", client_name="dummy",
cache_name="sqlite", cache_name="sqlite",
cache_connection=sqlite_cache, cache_connection=sqlite_cache,
temperature=0.0,
) )
# Set CHAT to be true for this model # Set CHAT to be true for this model
manifest.client_pool.client_pool[0].IS_CHAT = True manifest.client_pool.client_pool[0].IS_CHAT = True
@ -406,15 +484,23 @@ def test_run_chat(sqlite_cache: str) -> None:
{"role": "system", "content": "Hello."}, {"role": "system", "content": "Hello."},
] ]
result = manifest.run(prompt, return_response=False) result = manifest.run(prompt, return_response=False)
assert result == "Hello." assert (
result
== "ectors WortGo ré_sg|--------------------------------------------------------------------------\n contradictory Aad \u200b getUserId" # noqa: E501
)
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"prompt": [{"content": "Hello.", "role": "system"}], "best_of": 1,
"engine": "dummy", "engine": "dummy",
"num_results": 1, "max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": [{"content": "Hello.", "role": "system"}],
"request_cls": "LMChatRequest", "request_cls": "LMChatRequest",
}, "temperature": 0.0,
"top_p": 1.0,
}
) )
is not None is not None
) )
@ -428,18 +514,23 @@ def test_run_chat(sqlite_cache: str) -> None:
result = cast(Response, result) result = cast(Response, result)
assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices) assert len(result.get_usage_obj().usages) == len(result.get_response_obj().choices)
res = result.get_response() res = result.get_response()
assert res == "Hello." assert res == "_deploy_age_gp hora Plus Scheduler EisenhowerRF视 chemotherapy"
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"best_of": 1,
"engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": [ "prompt": [
{"role": "system", "content": "Hello."}, {"role": "system", "content": "Hello."},
{"role": "user", "content": "Goodbye?"}, {"role": "user", "content": "Goodbye?"},
], ],
"engine": "dummy",
"num_results": 1,
"request_cls": "LMChatRequest", "request_cls": "LMChatRequest",
}, "temperature": 0.0,
"top_p": 1.0,
}
) )
is not None is not None
) )
@ -452,6 +543,7 @@ def test_score_run(sqlite_cache: str) -> None:
client_name="dummy", client_name="dummy",
cache_name="sqlite", cache_name="sqlite",
cache_connection=sqlite_cache, cache_connection=sqlite_cache,
temperature=0.0,
) )
prompt = "This is a prompt" prompt = "This is a prompt"
@ -459,33 +551,68 @@ def test_score_run(sqlite_cache: str) -> None:
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"prompt": "This is a prompt", "best_of": 1,
"engine": "dummy", "engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": "This is a prompt",
"request_cls": "LMScoreRequest", "request_cls": "LMScoreRequest",
"num_results": 1, "temperature": 0.0,
}, "top_p": 1.0,
}
) )
is not None is not None
) )
assert result == { assert result == {
"response": { "response": {
"choices": [ "choices": [
{"text": "This is a prompt", "token_logprobs": [0.3], "tokens": None} {
"text": "Nice Employ NFCYouryms“Inwarn\ttemplate europ Moines",
"token_logprobs": [
-1.827188890438529,
-1.6981601736417915,
-0.24606708391178755,
-1.9209383499010613,
-0.8833563758318617,
-1.4121369466920703,
-0.376352908076236,
-1.3200064558188096,
-0.813028447207917,
-0.5977255311239729,
],
"tokens": [
"46078",
"21445",
"48305",
"7927",
"76125",
"46233",
"34581",
"23679",
"63021",
"78158",
],
}
]
},
"usages": {
"usages": [
{"completion_tokens": 10, "prompt_tokens": 4, "total_tokens": 14}
] ]
}, },
"usages": {"usages": []},
"cached": False, "cached": False,
"request": { "request": {
"prompt": "This is a prompt", "prompt": "This is a prompt",
"engine": "text-ada-001", "engine": "text-davinci-003",
"n": 1, "n": 1,
"client_timeout": 60, "client_timeout": 60,
"run_id": None, "run_id": None,
"batch_size": 8, "batch_size": 20,
"temperature": 0.7, "temperature": 0.0,
"max_tokens": 100, "max_tokens": 10,
"top_p": 1.0, "top_p": 1.0,
"top_k": 50, "top_k": 1,
"logprobs": None, "logprobs": None,
"stop_sequences": None, "stop_sequences": None,
"num_beams": 1, "num_beams": 1,
@ -505,49 +632,112 @@ def test_score_run(sqlite_cache: str) -> None:
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"prompt": "Hello is a prompt", "best_of": 1,
"engine": "dummy", "engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": "Hello is a prompt",
"request_cls": "LMScoreRequest", "request_cls": "LMScoreRequest",
"num_results": 1, "temperature": 0.0,
}, "top_p": 1.0,
}
) )
is not None is not None
) )
assert ( assert (
manifest.cache.get( manifest.cache.get(
{ {
"prompt": "Hello is another prompt", "best_of": 1,
"engine": "dummy", "engine": "dummy",
"max_tokens": 10,
"model": "text-davinci-003",
"n": 1,
"prompt": "Hello is another prompt",
"request_cls": "LMScoreRequest", "request_cls": "LMScoreRequest",
"num_results": 1, "temperature": 0.0,
}, "top_p": 1.0,
}
) )
is not None is not None
) )
assert result == { assert result == {
"response": { "response": {
"choices": [ "choices": [
{"text": "Hello is a prompt", "token_logprobs": [0.3], "tokens": None},
{ {
"text": "Hello is another prompt", "text": "appersstoff210 currentNodeleh norm unified_voice DIYHam",
"token_logprobs": [0.3], "token_logprobs": [
"tokens": None, -0.5613340599860608,
-1.2822870706137146,
-1.9909319620162806,
-0.6312373658222814,
-1.9066239705571664,
-1.2420939968397082,
-0.7208735169940805,
-1.9144266963723062,
-0.041181937860757856,
-0.5356282450367043,
],
"tokens": [
"28921",
"81056",
"8848",
"47399",
"74890",
"7617",
"43790",
"77865",
"32558",
"41041",
],
}, },
{
"text": ".addAttribute_size DE imageUrl_datas\tapFixed(hour setups\tcomment", # noqa: E501
"token_logprobs": [
-1.1142500072582333,
-0.819706434396527,
-1.9956443391600693,
-0.8425896744807639,
-1.8398050571245623,
-1.912564137256891,
-1.6677665162080606,
-1.1579612203844727,
-1.9876114502998343,
-0.2698297864722319,
],
"tokens": [
"26300",
"2424",
"3467",
"40749",
"47630",
"70998",
"13829",
"72135",
"84823",
"97368",
],
},
]
},
"usages": {
"usages": [
{"completion_tokens": 10, "prompt_tokens": 4, "total_tokens": 14},
{"completion_tokens": 10, "prompt_tokens": 4, "total_tokens": 14},
] ]
}, },
"usages": {"usages": []},
"cached": False, "cached": False,
"request": { "request": {
"prompt": ["Hello is a prompt", "Hello is another prompt"], "prompt": ["Hello is a prompt", "Hello is another prompt"],
"engine": "text-ada-001", "engine": "text-davinci-003",
"n": 1, "n": 1,
"client_timeout": 60, "client_timeout": 60,
"run_id": None, "run_id": None,
"batch_size": 8, "batch_size": 20,
"temperature": 0.7, "temperature": 0.0,
"max_tokens": 100, "max_tokens": 10,
"top_p": 1.0, "top_p": 1.0,
"top_k": 50, "top_k": 1,
"logprobs": None, "logprobs": None,
"stop_sequences": None, "stop_sequences": None,
"num_beams": 1, "num_beams": 1,
