"""Response test."""
from typing import List, cast

import numpy as np
import pytest

from manifest import Response
from manifest.request import EmbeddingRequest, LMRequest
from manifest.response import (
    ArrayModelChoice,
    LMModelChoice,
    ModelChoices,
    Usage,
    Usages,
)


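# NOTE: the fixtures used throughout this module (model_choice,
# model_choice_single, model_choice_arr, model_choice_arr_int, request_lm,
# request_lm_single, request_array) are assumed to come from a conftest.py
# that is not shown here. As an illustrative sketch only, inferred from the
# assertions below (the real conftest definitions may differ), model_choice
# plausibly looks like:
#
#     @pytest.fixture
#     def model_choice() -> ModelChoices:
#         return ModelChoices(
#             choices=[
#                 LMModelChoice(
#                     text="hello", token_logprobs=[0.1, 0.2], tokens=["hel", "lo"]
#                 ),
#                 LMModelChoice(text="bye", token_logprobs=[0.3], tokens=["bye"]),
#             ]
#         )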
def test_init(
    model_choice: ModelChoices,
    model_choice_arr: ModelChoices,
    model_choice_arr_int: ModelChoices,
    request_lm: LMRequest,
    request_array: EmbeddingRequest,
) -> None:
    """Test response initialization."""
    response = Response(
        response=model_choice,
        cached=False,
        request=request_lm,
        usages=None,
        request_type=LMRequest,
        response_type="text",
    )
    assert response._response == model_choice
    assert response._cached is False
    assert response._request == request_lm
    assert response._usages == Usages(usages=[])
    assert response._request_type == LMRequest
    assert response._response_type == "text"
    assert response._item_dtype is None

    response = Response(
        response=model_choice_arr_int,
        cached=False,
        request=request_array,
        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
        request_type=EmbeddingRequest,
        response_type="array",
    )
    assert response._cached is False
    assert response._request == request_array
    assert sum(usg.total_tokens for usg in response._usages.usages) == 10
    assert response._request_type == EmbeddingRequest
    assert response._response_type == "array"
    assert response._item_dtype == "int64"

    # An unrecognized response_type is rejected outright
    with pytest.raises(ValueError) as excinfo:
        Response(
            response=model_choice,
            cached=False,
            request=request_lm,
            usages=None,
            request_type=LMRequest,
            response_type="blah",
        )
    assert "blah" in str(excinfo.value)

    # A text payload can't be used with response_type="array"
    with pytest.raises(ValueError) as excinfo:
        Response(
            response=model_choice,
            cached=False,
            request=request_lm,
            usages=None,
            request_type=LMRequest,
            response_type="array",
        )
    assert str(excinfo.value) == (
        "response_type is array but response is "
        "<class 'manifest.response.LMModelChoice'>"
    )

    # An array payload can't be used with response_type="text"
    with pytest.raises(ValueError) as excinfo:
        Response(
            response=model_choice_arr,
            cached=False,
            request=request_array,
            usages=None,
            request_type=LMRequest,
            response_type="text",
        )
    assert str(excinfo.value) == (
        "response_type is text but response is "
        "<class 'manifest.response.ArrayModelChoice'>"
    )


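# The getters exercised below are the public counterparts of the private
# attributes asserted in test_init: get_response_obj() wraps _response,
# is_cached() wraps _cached, get_request_obj() wraps _request, and
# get_usage_obj() wraps _usages.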
def test_getters(model_choice: ModelChoices, request_lm: LMRequest) -> None:
    """Test response getters."""
    response = Response(
        response=model_choice,
        cached=False,
        request=request_lm,
        usages=None,
        request_type=LMRequest,
        response_type="text",
    )
    assert response.get_response_obj() == model_choice
    assert response.is_cached() is False
    assert response.get_request_obj() == request_lm
    assert response.get_usage_obj() == Usages(usages=[])
    assert response.get_json_response() == model_choice.dict()
    assert response.get_response() == ["hello", "bye"]


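# Three round-trip paths are exercised below: serialize()/deserialize()
# (string form), to_dict()/from_dict(), and to_dict(drop_request=True)
# combined with an explicit request_dict. For array responses, the dtype is
# recorded in _item_dtype so that int64/float64 arrays survive the trip.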
def test_serialize(
    model_choice: ModelChoices,
    model_choice_arr: ModelChoices,
    model_choice_arr_int: ModelChoices,
    request_lm: LMRequest,
    request_array: EmbeddingRequest,
) -> None:
    """Test response serialization."""
    response = Response(
        response=model_choice,
        cached=False,
        request=request_lm,
        usages=None,
        request_type=LMRequest,
        response_type="text",
    )
    deserialized_response = Response.deserialize(response.serialize())
    assert deserialized_response.get_response_obj() == model_choice
    assert deserialized_response.is_cached() is False
    assert deserialized_response.get_request_obj() == request_lm
    assert deserialized_response.get_usage_obj() == Usages(usages=[])
    assert deserialized_response.get_json_response() == model_choice.dict()
    assert deserialized_response.get_response() == ["hello", "bye"]

    deserialized_response = Response.from_dict(response.to_dict())
    assert deserialized_response.get_response_obj() == model_choice
    assert deserialized_response.is_cached() is False
    assert deserialized_response.get_request_obj() == request_lm
    assert deserialized_response.get_usage_obj() == Usages(usages=[])
    assert deserialized_response.get_json_response() == model_choice.dict()
    assert deserialized_response.get_response() == ["hello", "bye"]

    deserialized_response = Response.from_dict(
        response.to_dict(drop_request=True), request_dict={"prompt": "blahhhh"}
    )
    assert deserialized_response.get_response_obj() == model_choice
    assert deserialized_response.is_cached() is False
    assert deserialized_response.get_request_obj().prompt == "blahhhh"
    assert deserialized_response.get_usage_obj() == Usages(usages=[])
    assert deserialized_response.get_json_response() == model_choice.dict()
    assert deserialized_response.get_response() == ["hello", "bye"]

    # Int type
    response = Response(
        response=model_choice_arr_int,
        cached=False,
        request=request_array,
        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
        request_type=EmbeddingRequest,
        response_type="array",
    )
    deserialized_response = Response.deserialize(response.serialize())
    assert deserialized_response._item_dtype == "int64"
    assert (
        cast(
            ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
        ).array.dtype
        == np.int64
    )
    assert np.array_equal(
        cast(
            ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
        ).array,
        cast(ArrayModelChoice, model_choice_arr_int.choices[0]).array,
    )

    # Float type
    response = Response(
        response=model_choice_arr,
        cached=False,
        request=request_array,
        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
        request_type=EmbeddingRequest,
        response_type="array",
    )
    deserialized_response = Response.deserialize(response.serialize())
    assert deserialized_response._item_dtype == "float64"
    assert (
        cast(
            ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
        ).array.dtype
        == np.float64
    )
    assert np.array_equal(
        cast(
            ArrayModelChoice, deserialized_response.get_response_obj().choices[0]
        ).array,
        cast(ArrayModelChoice, model_choice_arr.choices[0]).array,
    )


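# get_response() semantics checked below: a single choice comes back as a
# bare string while multiple choices come back as a list; stop_token
# truncates each text at its first occurrence (and leaves texts without the
# token untouched); is_batch=True forces list output even for a single
# choice; and stop_token is a no-op for array responses.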
def test_get_results(
    model_choice: ModelChoices,
    model_choice_single: ModelChoices,
    model_choice_arr: ModelChoices,
    request_lm: LMRequest,
    request_array: EmbeddingRequest,
) -> None:
    """Test response get results."""
    response = Response(
        response=model_choice_single,
        cached=False,
        request=request_lm,
        usages=None,
        request_type=LMRequest,
        response_type="text",
    )
    assert response.get_response() == "helloo"
    assert response.get_response(stop_token="ll") == "he"
    assert response.get_response(stop_token="ll", is_batch=True) == ["he"]

    response = Response(
        response=model_choice,
        cached=False,
        request=request_lm,
        usages=None,
        request_type=LMRequest,
        response_type="text",
    )
    assert response.get_response() == ["hello", "bye"]
    assert response.get_response(stop_token="b") == ["hello", ""]
    assert response.get_response(stop_token="y", is_batch=True) == ["hello", "b"]

    float_arr1 = cast(ArrayModelChoice, model_choice_arr.choices[0]).array
    float_arr2 = cast(ArrayModelChoice, model_choice_arr.choices[1]).array
    response = Response(
        response=model_choice_arr,
        cached=False,
        request=request_array,
        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
        request_type=EmbeddingRequest,
        response_type="array",
    )
    assert np.array_equal(response.get_response()[0], float_arr1)
    assert np.array_equal(response.get_response()[1], float_arr2)
    assert np.array_equal(response.get_response(stop_token="t")[0], float_arr1)
    assert np.array_equal(response.get_response(stop_token="t")[1], float_arr2)


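# union_all() concatenates choices and prompts across responses, pads
# missing usage entries with empty Usage() objects, and (judging by the
# engine assertion below) joins distinct engines with "::". With
# as_single_lmchoice=True it instead collapses everything into one choice,
# concatenating text, token_logprobs, and tokens.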
def test_union_all(
    model_choice: ModelChoices,
    model_choice_single: ModelChoices,
    request_lm: LMRequest,
    request_lm_single: LMRequest,
) -> None:
    """Test union all."""
    response1 = Response(
        response=model_choice,
        cached=False,
        request=request_lm,
        usages=None,
        request_type=LMRequest,
        response_type="text",
    )

    response2 = Response(
        response=model_choice_single,
        cached=False,
        request=request_lm_single,
        usages=None,
        request_type=LMRequest,
        response_type="text",
    )

    final_response = Response.union_all([response1, response2])
    assert final_response.get_json_response() == {
        "choices": [
            {"text": "hello", "token_logprobs": [0.1, 0.2], "tokens": ["hel", "lo"]},
            {"text": "bye", "token_logprobs": [0.3], "tokens": ["bye"]},
            {"text": "helloo", "token_logprobs": [0.1, 0.2], "tokens": ["hel", "loo"]},
        ]
    }
    assert final_response.get_usage_obj() == Usages(usages=[Usage(), Usage(), Usage()])
    merged_prompts: List[str] = request_lm.prompt + [request_lm_single.prompt]  # type: ignore # noqa: E501
    assert final_response.get_request_obj().prompt == merged_prompts
    assert final_response.get_request_obj().engine == "dummy::text-ada-001"

    # Rebuild response1 with explicit usage information
    response1 = Response(
        response=model_choice,
        cached=False,
        request=request_lm,
        usages=Usages(usages=[Usage(total_tokens=4), Usage(total_tokens=6)]),
        request_type=LMRequest,
        response_type="text",
    )

    final_response = Response.union_all([response1, response2])
    assert final_response.get_usage_obj() == Usages(
        usages=[Usage(total_tokens=4), Usage(total_tokens=6), Usage()]
    )

    # Test merge to single
    model_choices = ModelChoices(
        choices=[
            LMModelChoice(
                text=" helloo this is a bug",
                token_logprobs=[0.1, 0.2, 0.3],
                tokens=[" helloo", " this is", " a bug"],
            ),
        ]
    )
    request = LMRequest(prompt="monkey", engine="dummy")
    response1 = Response(
        response=model_choices,
        cached=False,
        request=request,
        usages=None,
        request_type=LMRequest,
        response_type="text",
    )
    final_response = Response.union_all([response1, response1], as_single_lmchoice=True)
    assert final_response.get_json_response() == {
        "choices": [
            {
                "text": " helloo this is a bug helloo this is a bug",
                "token_logprobs": [0.1, 0.2, 0.3, 0.1, 0.2, 0.3],
                "tokens": [
                    " helloo",
                    " this is",
                    " a bug",
                    " helloo",
                    " this is",
                    " a bug",
                ],
            },
        ]
    }
    assert final_response.get_usage_obj() == Usages(usages=[Usage()])
    assert final_response.get_request_obj().prompt == "monkey"
    assert final_response.get_request_obj().engine == "dummy"


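# as_iter() yields one single-choice Response per token. When the choice
# carries token metadata, the stored tokens are used ("hel" and "loo" below);
# otherwise the text is apparently split on whitespace, which is why the
# un-tokenized "helloo this is a bug" yields five pieces.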
def test_as_iter(
    model_choice_single: ModelChoices, request_lm_single: LMRequest
) -> None:
    """Test as iter."""
    response = Response(
        response=model_choice_single,
        cached=False,
        request=request_lm_single,
        usages=None,
        request_type=LMRequest,
        response_type="text",
    )
    response_iter_list = list(response.as_iter())
    assert len(response_iter_list) == 2
    assert response_iter_list[0].get_response() == "hel"
    assert response_iter_list[1].get_response() == "loo"

    model_choices = ModelChoices(
        choices=[
            LMModelChoice(text="helloo this is a bug"),
        ]
    )
    request = LMRequest(prompt="monkey", engine="dummy")
    response = Response(
        response=model_choices,
        cached=False,
        request=request,
        usages=None,
        request_type=LMRequest,
        response_type="text",
    )
    response_iter_list = list(response.as_iter())
    assert len(response_iter_list) == 5
    assert response_iter_list[0].get_response() == "helloo"
    assert response_iter_list[1].get_response() == " this"
    assert response_iter_list[2].get_response() == " is"
    assert response_iter_list[3].get_response() == " a"
    assert response_iter_list[4].get_response() == " bug"