Laurel Orr 2022-06-11 08:02:31 +00:00
parent 74b9302b1b
commit f568875e57
14 changed files with 406 additions and 1875 deletions

View File

@@ -1,7 +1,6 @@
dev:
poetry install
poetry run pre-commit install
poetry run mypy --install-types
test: dev check
poetry run pytest tests

47
manifest/caches/noop.py Normal file
View File

@@ -0,0 +1,47 @@
"""Noop cache."""
from typing import Any, Dict, Union
from manifest.caches import Cache
class NoopCache(Cache):
"""A Noop cache that caches nothing for request/response pairs."""
def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
"""
Connect to client.
Args:
connection_str: connection string.
cache_args: cache arguments.
"""
pass
def close(self) -> None:
"""Close the client."""
pass
def get_key(self, key: str, table: str = "default") -> Union[str, None]:
"""
Return None; keys are never stored in this cache.
Args:
key: key for cache.
table: table to get key in.
"""
return None
def set_key(self, key: str, value: str, table: str = "default") -> None:
"""
Do nothing, as this cache stores nothing.
Args:
key: key for cache.
value: new value for key.
table: table to set key in.
"""
pass
def commit(self) -> None:
"""Commit any results."""
pass
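
A quick sketch of the contract this class satisfies: every write is dropped and every read misses, so callers always recompute. The constructor call below mirrors the new test at the bottom of this diff; everything else follows directly from the methods above.

from manifest.caches.noop import NoopCache

cache = NoopCache(None)  # connection string is ignored
cache.set_key("abc123", "cached response")  # stored nowhere
assert cache.get_key("abc123") is None  # always a miss
cache.commit()  # no-op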

103
manifest/clients/ai21.py Normal file
View File

@@ -0,0 +1,103 @@
"""OpenAI client."""
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple
import requests
from manifest.clients.client import Client
logger = logging.getLogger(__name__)
AI21_ENGINES = {
"j1-jumbo",
"j1-grande",
"j1-large",
}
class AI21Client(Client):
"""AI21Client client."""
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the AI21 server.
connection_str is used as the default API key when the AI21_API_KEY environment variable is not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
# Taken from https://studio.ai21.com/docs/api/
self.host = "https://api.ai21.com/studio/v1"
self.api_key = os.environ.get("AI21_API_KEY", connection_str)
if self.api_key is None:
raise ValueError(
"AI21 API key not set. Set AI21_API_KEY environment "
"variable or pass through `connection_str`."
)
self.engine = client_args.pop("engine", "j1-large")
if self.engine not in AI21_ENGINES:
raise ValueError(f"Invalid engine {self.engine}. Must be {AI21_ENGINES}.")
self.temperature = client_args.pop("temperature", 0.0)
self.max_tokens = client_args.pop("max_tokens", 10)
self.top_k_return = client_args.pop("topKReturn", 1.0)
self.num_results = client_args.pop("numResults", 1)
self.top_p = client_args.pop("topP", 1.0)
def close(self) -> None:
"""Close the client."""
pass
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add them to the
request and make sure cache keys are unique to each model.
Returns:
model params.
"""
return {"model_name": "ai21", "engine": self.engine}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
Args:
query: query string.
Returns:
request function that takes no input.
request parameters as dict.
"""
request_params = {
"engine": kwargs.get("engine", self.engine),
"prompt": query,
"temperature": kwargs.get("temperature", self.temperature),
"maxTokens": kwargs.get("maxTokens", self.max_tokens),
"topKReturn": kwargs.get("topKReturn", self.top_k_return),
"numResults": kwargs.get("numResults", self.num_results),
"topP": kwargs.get("topP", self.top_p),
}
def _run_completion() -> Dict:
post_str = self.host + "/" + self.engine + "/complete"
res = requests.post(
post_str,
headers={"Authorization": f"Bearer {self.api_key}"},
json=request_params,
)
return res.json()
return _run_completion, request_params
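
Assuming AI21_API_KEY is exported, a minimal sketch of driving this client directly; the (connection_str, client_args) constructor signature is inferred from how Manifest instantiates clients later in this diff, and the prompt text is illustrative.

from manifest.clients.ai21 import AI21Client

client = AI21Client(None, client_args={"engine": "j1-large", "max_tokens": 16})
run_fn, params = client.get_request("Once upon a time")
result = run_fn()  # POSTs to https://api.ai21.com/studio/v1/j1-large/complete
client.close()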

View File

@@ -28,6 +28,19 @@ class Client(ABC):
"""Close the client."""
raise NotImplementedError()
@abstractmethod
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add them to the
request and make sure cache keys are unique to each model.
Returns:
model params.
"""
raise NotImplementedError()
@abstractmethod
def connect(self, connection_str: str, client_args: Dict[str, Any]) -> None:
"""

138
manifest/clients/crfm.py Normal file
View File

@@ -0,0 +1,138 @@
"""OpenAI client."""
import logging
import os
import sys
from typing import Any, Callable, Dict, Optional, Tuple
from manifest.clients.client import Client
crfm_code_dir = os.environ.get("CRFM_CODE_DIR", "/home/code/benchmarking")
sys.path.append(crfm_code_dir)
from src.common.authentication import Authentication # type: ignore
from src.common.request import Request, RequestResult # type: ignore
from src.proxy.remote_service import RemoteService # type: ignore
logger = logging.getLogger(__name__)
CRFM_ENGINES = {
"ai21/j1-jumbo",
"ai21/j1-grande",
"ai21/j1-large",
}
class CRFMClient(Client):
"""CRFMClient client."""
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the CRFM endpoint.
connection_str is used as the default API key when the CRFM_API_KEY environment variable is not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
self.service = RemoteService("https://crfm-models.stanford.edu")
api_key = os.environ.get("CRFM_API_KEY", connection_str)
if api_key is None:
raise ValueError(
"CRFM API key not set. Set CRFM_API_KEY environment "
"variable or pass through `connection_str`."
)
self.auth = Authentication(api_key=api_key)
self.engine = client_args.pop("engine", "ai21/j1-large")
if self.engine not in CRFM_ENGINES:
raise ValueError(f"Invalid engine {self.engine}. Must be {CRFM_ENGINES}.")
self.temperature = client_args.pop("temperature", 0.0)
self.max_tokens = client_args.pop("max_tokens", 10)
self.top_k_per_token = client_args.pop("top_k_per_token", 1)
self.num_completions = client_args.pop("num_completions", 1)
self.stop_sequences = client_args.pop("stop_sequences", [])
self.top_p = client_args.pop("top_p", 1.0)
self.presence_penalty = client_args.pop("presence_penalty", 1.0)
self.frequency_penalty = client_args.pop("frequency_penalty", 1.0)
def close(self) -> None:
"""Close the client."""
pass
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add them to the
request and make sure cache keys are unique to each model.
Returns:
model params.
"""
return {"model_name": "crfm", "engine": self.engine}
def format_response(self, response: RequestResult) -> Dict[str, Any]:
"""
Format RequestResult to dict.
Args:
response: RequestResult.
Returns:
response as dict.
"""
return {
"object": "text_completion",
"model": self.engine,
"choices": [
{
"text": text.text,
# TODO: Add in more metadata for HF models
# "logprobs": {
# "tokens": result["tokens"],
# "token_logprobs": result["token_scores"],
# "text_offset": result["text_offset"],
# "top_logprobs": result["top_logprobs"],
# "finish_reason": "length",
# },
}
for text in response.completions
],
}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
Args:
query: query string.
Returns:
request function that takes no input.
request parameters as dict.
"""
request_params = {
"model": kwargs.get("engine", self.engine),
"prompt": query,
"temperature": kwargs.get("temperature", self.temperature),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"top_k_per_token": kwargs.get("top_k_per_token", self.top_k_per_token),
"num_completions": kwargs.get("num_completions", self.num_completions),
"stop_sequences": kwargs.get("stop_sequences", self.stop_sequences),
"top_p": kwargs.get("top_p", self.top_p),
"presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
"frequency_penalty": kwargs.get(
"frequency_penalty", self.frequency_penalty
),
}
def _run_completion() -> Dict:
request = Request(**request_params)
request_result = self.service.make_request(self.auth, request)
return self.format_response(request_result)
return _run_completion, request_params
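
Because the CRFM SDK is imported from CRFM_CODE_DIR at module import time, this client only loads when that checkout exists; manifest.py below registers it behind a try/except for the same reason. A hedged usage sketch, with placeholder values for the code directory and key:

import os
os.environ["CRFM_CODE_DIR"] = "/path/to/benchmarking"  # hypothetical checkout
os.environ["CRFM_API_KEY"] = "..."  # placeholder key

from manifest.clients.crfm import CRFMClient  # import after env vars are set

client = CRFMClient(None, client_args={"engine": "ai21/j1-large"})
run_fn, params = client.get_request("Hello world")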

View File

@@ -30,6 +30,18 @@ class DummyClient(Client):
"""Close the client."""
pass
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add them to the
request and make sure cache keys are unique to each model.
Returns:
model params.
"""
return {"engine": "dummy"}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.

View File

@@ -49,6 +49,8 @@ class OpenAIClient(Client):
self.temperature = client_args.pop("temperature", 0.0)
self.max_tokens = client_args.pop("max_tokens", 10)
self.top_p = client_args.pop("top_p", 1.0)
self.logprobs = client_args.pop("logprobs", None)
self.best_of = client_args.pop("best_of", 1)
self.frequency_penalty = client_args.pop("frequency_penalty", 0.0)
self.presence_penalty = client_args.pop("presence_penalty", 0.0)
self.n = client_args.pop("n", 1)
@@ -57,6 +59,18 @@ class OpenAIClient(Client):
"""Close the client."""
pass
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add them to the
request and make sure cache keys are unique to each model.
Returns:
model params.
"""
return {"model_name": "openai", "engine": self.engine}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@@ -77,6 +91,8 @@ class OpenAIClient(Client):
"frequency_penalty": kwargs.get(
"frequency_penalty", self.frequency_penalty
),
"logprobs": kwargs.get("logprobs", self.logprobs),
"best_of": kwargs.get("best_of", self.best_of),
"presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
"n": kwargs.get("n", self.n),
}
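
The two new knobs behave like every other client argument: a default fixed at session construction, overridable per call through kwargs. A sketch, assuming an OpenAI key is configured and that Manifest is importable from the package root:

from manifest import Manifest  # package-root export assumed

m = Manifest(client_name="openai", cache_name="noop", logprobs=5, best_of=2)
text = m.run("Say hello", best_of=1)  # per-call override beats the default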

View File

@@ -34,6 +34,18 @@ class OPTClient(Client):
"""Close the client."""
pass
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add them to the
request and make sure cache keys are unique to each model.
Returns:
model params.
"""
return {"model_name": "opt"}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.

View File

@@ -4,8 +4,10 @@ from typing import Any, Iterable, List, Optional, Union
from tqdm.auto import tqdm
from manifest.caches.noop import NoopCache
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.ai21 import AI21Client
from manifest.clients.dummy import DummyClient
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.openai import OpenAIClient
@@ -18,6 +20,7 @@ logger = logging.getLogger(__name__)
CLIENT_CONSTRUCTORS = {
"openai": OpenAIClient,
"ai21": AI21Client,
"huggingface": HuggingFaceClient,
"opt": OPTClient,
"dummy": DummyClient,
@@ -26,8 +29,17 @@ CLIENT_CONSTRUCTORS = {
CACHE_CONSTRUCTORS = {
"redis": RedisCache,
"sqlite": SQLiteCache,
"noop": NoopCache,
}
try:
from manifest.clients.crfm import CRFMClient
CLIENT_CONSTRUCTORS["crfm"] = CRFMClient
except ImportError:
# TODO: remove this when CRFM is public
pass
class Manifest:
"""Manifest session object."""
@@ -36,8 +48,8 @@ class Manifest:
self,
client_name: str = "openai",
client_connection: Optional[str] = None,
cache_name: str = "redis",
cache_connection: str = "localhost:6379",
cache_name: str = "noop",
cache_connection: Optional[str] = None,
stop_token: str = "",
**kwargs: Any,
):
@@ -65,11 +77,13 @@ class Manifest:
f"Choices are {list(CACHE_CONSTRUCTORS.keys())}"
)
self.client_name = client_name
# Must pass kwargs as dict to client "pop" methods removed used arguments
# Must pass kwargs as a dict so the client "pop" methods remove used arguments
self.client = CLIENT_CONSTRUCTORS[client_name]( # type: ignore
client_connection, client_args=kwargs
)
self.cache = CACHE_CONSTRUCTORS[cache_name](cache_connection, cache_args=kwargs)
self.cache = CACHE_CONSTRUCTORS[cache_name]( # type: ignore
cache_connection, cache_args=kwargs
)
if len(kwargs) > 0:
raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")
@@ -82,7 +96,7 @@ class Manifest:
def run(
self,
prompt: Prompt,
prompt: Union[Prompt, str],
input: Optional[Any] = None,
overwrite_cache: bool = False,
stop_token: Optional[str] = None,
@@ -93,7 +107,7 @@
Run the prompt.
Args:
prompt: prompt to run.
prompt: prompt to run. If a string, it is cast to a Prompt.
input: input to prompt.
overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
@@ -103,6 +117,8 @@
Returns:
response from prompt.
"""
if isinstance(prompt, str):
prompt = Prompt(prompt)
stop_token = stop_token if stop_token is not None else self.stop_token
prompt_str = prompt(input)
possible_request, full_kwargs = self.client.get_request(prompt_str, **kwargs)
@@ -143,6 +159,11 @@
Returns:
batch of responses.
"""
if isinstance(prompt, str):
raise ValueError(
"Prompt must be a Prompt object for batch run on data. "
"We only support strings in `manifest.run`."
)
if input is None:
input = [None]
return [
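
Taken together, these changes allow a zero-configuration session: no Redis required (the cache now defaults to noop), and run accepts a bare string. A minimal sketch with the dummy client:

from manifest import Manifest  # package-root export assumed

m = Manifest(client_name="dummy")  # cache_name defaults to "noop"
print(m.run("What is 1+1?"))  # the str is cast to Prompt internally

# Batch runs still require a Prompt object; passing a raw string raises
# the ValueError added in the hunk above.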

View File

@@ -16,7 +16,8 @@ class Response:
not isinstance(self._response["choices"], list)
):
raise ValueError(
"Response must be serialized to a dict with a list of choices"
"Response must be serialized to a dict with a list of choices. "
f"Response is {self._response}."
)
if len(self._response["choices"]) > 0:
if "text" not in self._response["choices"][0]:

1862
poetry.lock generated

File diff suppressed because it is too large.

View File

@@ -22,8 +22,8 @@ openai = "^0.18.1"
redis = "^4.3.1"
dill = "^0.3.5"
Flask = "^2.1.2"
transformers = "^4.19.2"
torch = "^1.11.0"
#transformers = "^4.19.2"
#torch = "^1.11.0"
requests = "^2.27.1"
tqdm = "^4.64.0"
types-redis = "^4.2.6"
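
With transformers and torch commented out of the required dependencies, a bare install can no longer import the HuggingFace client. One defensive pattern (a sketch of a possible follow-up, not something this commit does) is to probe for the optional packages before registering the client, mirroring the try/except used for CRFM:

import importlib.util

# Register HuggingFaceClient only when its heavy dependencies are present.
if importlib.util.find_spec("transformers") and importlib.util.find_spec("torch"):
    from manifest.clients.huggingface import HuggingFaceClient
    CLIENT_CONSTRUCTORS["huggingface"] = HuggingFaceClient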

View File

@@ -3,6 +3,7 @@ import pytest
from redis import Redis
from sqlitedict import SqliteDict
from manifest.caches.noop import NoopCache
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
@@ -69,3 +70,33 @@ def test_get(sqlite_cache, redis_cache, cache_type):
assert response.get_response() == "hello"
assert not response.is_cached()
assert response.get_request() == test_request
def test_noop_cache():
"""Test cache that is a no-op cache."""
cache = NoopCache(None)
cache.set_key("test", "valueA")
cache.set_key("testA", "valueB")
assert cache.get_key("test") is None
assert cache.get_key("testA") is None
cache.set_key("testA", "valueC")
assert cache.get_key("testA") is None
cache.get_key("test", table="prompt") is None
cache.set_key("test", "valueA", table="prompt")
cache.get_key("test", table="prompt") is None
# Assert always not cached
test_request = {"test": "hello", "testA": "world"}
compute = lambda: {"choices": [{"text": "hello"}]}
response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert response.get_response() == "hello"
assert not response.is_cached()
assert response.get_request() == test_request
response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert response.get_response() == "hello"
assert not response.is_cached()
assert response.get_request() == test_request

View File

@@ -11,9 +11,9 @@ def test_init():
assert str(exc_info.value) == "Response must be str or dict"
with pytest.raises(ValueError) as exc_info:
response = Response({"test": "hello"}, False, {})
assert (
str(exc_info.value)
== "Response must be serialized to a dict with a list of choices"
assert str(exc_info.value) == (
"Response must be serialized to a dict with a list of choices. "
"Response is {'test': 'hello'}."
)
with pytest.raises(ValueError) as exc_info:
response = Response({"choices": [{"blah": "hello"}]}, False, {})