Mirror of https://github.com/HazyResearch/manifest (synced 2024-11-04 12:00:14 +00:00)
fix: fixing https://github.com/HazyResearch/manifest/issues/7 and https://github.com/HazyResearch/manifest/issues/5
This commit is contained in:
parent 74b9302b1b
commit f568875e57
1 Makefile
@@ -1,7 +1,6 @@
dev:
	poetry install
	poetry run pre-commit install
	poetry run mypy --install-types

test: dev check
	poetry run pytest tests
47 manifest/caches/noop.py Normal file
@@ -0,0 +1,47 @@
"""Noop cache."""
from typing import Any, Dict, Union

from manifest.caches import Cache


class NoopCache(Cache):
    """A Noop cache that caches nothing for request/response pairs."""

    def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
        """
        Connect to client.

        Args:
            connection_str: connection string.
            cache_args: cache arguments.
        """
        pass

    def close(self) -> None:
        """Close the client."""
        pass

    def get_key(self, key: str, table: str = "default") -> Union[str, None]:
        """
        Return None key for never in cache.

        Args:
            key: key for cache.
            table: table to get key in.
        """
        return None

    def set_key(self, key: str, value: str, table: str = "default") -> None:
        """
        Do not set anything as no cache.

        Args:
            key: key for cache.
            value: new value for key.
            table: table to set key in.
        """
        pass

    def commit(self) -> None:
        """Commit any results."""
        pass
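As a reading aid (not part of the diff): because every lookup in this cache misses, set_key is a no-op and results are always recomputed. A minimal sketch, mirroring the test_noop_cache test added later in this commit:

# Sketch only: NoopCache stores nothing, so every lookup misses.
from manifest.caches.noop import NoopCache

cache = NoopCache(None)
cache.set_key("test", "valueA")
assert cache.get_key("test") is None  # never cached, always recomputed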
103 manifest/clients/ai21.py Normal file
@@ -0,0 +1,103 @@
"""AI21 client."""
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple

import requests

from manifest.clients.client import Client

logger = logging.getLogger(__name__)

AI21_ENGINES = {
    "j1-jumbo",
    "j1-grande",
    "j1-large",
}


class AI21Client(Client):
    """AI21Client client."""

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Dict[str, Any] = {},
    ) -> None:
        """
        Connect to the AI21 server.

        connection_str is passed as default AI21_API_KEY if variable not set.

        Args:
            connection_str: connection string.
            client_args: client arguments.
        """
        # Taken from https://studio.ai21.com/docs/api/
        self.host = "https://api.ai21.com/studio/v1"
        self.api_key = os.environ.get("AI21_API_KEY", connection_str)
        if self.api_key is None:
            raise ValueError(
                "AI21 API key not set. Set AI21_API_KEY environment "
                "variable or pass through `connection_str`."
            )
        self.engine = client_args.pop("engine", "j1-large")
        if self.engine not in AI21_ENGINES:
            raise ValueError(f"Invalid engine {self.engine}. Must be {AI21_ENGINES}.")
        self.temperature = client_args.pop("temperature", 0.0)
        self.max_tokens = client_args.pop("max_tokens", 10)
        self.top_k_return = client_args.pop("topKReturn", 1.0)
        self.num_results = client_args.pop("numResults", 1)
        self.top_p = client_args.pop("topP", 1.0)

    def close(self) -> None:
        """Close the client."""
        pass

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params.
        """
        return {"model_name": "ai21", "engine": self.engine}

    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
        """
        Get request string function.

        Args:
            query: query string.

        Returns:
            request function that takes no input.
            request parameters as dict.
        """
        request_params = {
            "engine": kwargs.get("engine", self.engine),
            "prompt": query,
            "temperature": kwargs.get("temperature", self.temperature),
            "maxTokens": kwargs.get("maxTokens", self.max_tokens),
            "topKReturn": kwargs.get("topKReturn", self.top_k_return),
            "numResults": kwargs.get("numResults", self.num_results),
            "topP": kwargs.get("topP", self.top_p),
        }

        def _run_completion() -> Dict:
            # Build the per-engine completion endpoint and post the request params.
            post_str = self.host + "/" + self.engine + "/complete"
            res = requests.post(
                post_str,
                headers={"Authorization": f"Bearer {self.api_key}"},
                json=request_params,
            )
            return res.json()

        return _run_completion, request_params
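For orientation (not part of the diff), a hedged sketch of how the new client is reached once "ai21" is registered in CLIENT_CONSTRUCTORS later in this commit; it assumes AI21_API_KEY is set in the environment (or passed as client_connection) and uses the defaults shown above:

# Sketch, assuming the package exports Manifest as in the project README.
from manifest import Manifest

manifest = Manifest(client_name="ai21", engine="j1-large", max_tokens=10)
print(manifest.run("What is the meaning of life?"))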
@@ -28,6 +28,19 @@ class Client(ABC):
         """Close the client."""
         raise NotImplementedError()

+    @abstractmethod
+    def get_model_params(self) -> Dict:
+        """
+        Get model params.
+
+        By getting model params from the server, we can add to request
+        and make sure cache keys are unique to model.
+
+        Returns:
+            model params.
+        """
+        raise NotImplementedError()
+
     @abstractmethod
     def connect(self, connection_str: str, client_args: Dict[str, Any]) -> None:
         """
138 manifest/clients/crfm.py Normal file
@@ -0,0 +1,138 @@
"""CRFM client."""
import logging
import os
import sys
from typing import Any, Callable, Dict, Optional, Tuple

from manifest.clients.client import Client

crfm_code_dir = os.environ.get("CRFM_CODE_DIR", "/home/code/benchmarking")
sys.path.append(crfm_code_dir)

from src.common.authentication import Authentication  # type: ignore
from src.common.request import Request, RequestResult  # type: ignore
from src.proxy.remote_service import RemoteService  # type: ignore

logger = logging.getLogger(__name__)

CRFM_ENGINES = {
    "ai21/j1-jumbo",
    "ai21/j1-grande",
    "ai21/j1-large",
}


class CRFMClient(Client):
    """CRFMClient client."""

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Dict[str, Any] = {},
    ) -> None:
        """
        Connect to the CRFM endpoint.

        connection_str is passed as default CRFM_API_KEY if variable not set.

        Args:
            connection_str: connection string.
            client_args: client arguments.
        """
        self.service = RemoteService("https://crfm-models.stanford.edu")
        api_key = os.environ.get("CRFM_API_KEY", connection_str)
        if api_key is None:
            raise ValueError(
                "CRFM API key not set. Set CRFM_API_KEY environment "
                "variable or pass through `connection_str`."
            )
        self.auth = Authentication(api_key=api_key)
        self.engine = client_args.pop("engine", "ai21/j1-large")
        if self.engine not in CRFM_ENGINES:
            raise ValueError(f"Invalid engine {self.engine}. Must be {CRFM_ENGINES}.")
        self.temperature = client_args.pop("temperature", 0.0)
        self.max_tokens = client_args.pop("max_tokens", 10)
        self.top_k_per_token = client_args.pop("top_k_per_token", 1)
        self.num_completions = client_args.pop("num_completions", 1)
        self.stop_sequences = client_args.pop("stop_sequences", [])
        self.top_p = client_args.pop("top_p", 1.0)
        self.presence_penalty = client_args.pop("presence_penalty", 1.0)
        self.frequency_penalty = client_args.pop("frequency_penalty", 1.0)

    def close(self) -> None:
        """Close the client."""
        pass

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params.
        """
        return {"model_name": "crfm", "engine": self.engine}

    def format_response(self, response: RequestResult) -> Dict[str, Any]:
        """
        Format RequestResult to dict.

        Args:
            response: RequestResult

        Return:
            response as dict
        """
        return {
            "object": "text_completion",
            "model": self.engine,
            "choices": [
                {
                    "text": text.text,
                    # TODO: Add in more metadata for HF models
                    # "logprobs": {
                    #     "tokens": result["tokens"],
                    #     "token_logprobs": result["token_scores"],
                    #     "text_offset": result["text_offset"],
                    #     "top_logprobs": result["top_logprobs"],
                    #     "finish_reason": "length",
                    # },
                }
                for text in response.completions
            ],
        }

    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
        """
        Get request string function.

        Args:
            query: query string.

        Returns:
            request function that takes no input.
            request parameters as dict.
        """
        request_params = {
            "model": kwargs.get("engine", self.engine),
            "prompt": query,
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "top_k_per_token": kwargs.get("top_k_per_token", self.top_k_per_token),
            "num_completions": kwargs.get("num_completions", self.num_completions),
            "stop_sequences": kwargs.get("stop_sequences", self.stop_sequences),
            "top_p": kwargs.get("top_p", self.top_p),
            "presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
            "frequency_penalty": kwargs.get(
                "frequency_penalty", self.frequency_penalty
            ),
        }

        def _run_completion() -> Dict:
            request = Request(**request_params)
            request_result = self.service.make_request(self.auth, request)
            return self.format_response(request_result)

        return _run_completion, request_params
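Both new clients follow the same contract: get_request returns a zero-argument callable plus the request parameters, and the caller fires the callable only on a cache miss. A rough, hypothetical sketch of that flow (the helper name run_with_cache is illustrative, not from the codebase):

# Hypothetical illustration of how a client's get_request pair is consumed.
def run_with_cache(client, cache, prompt_str: str):
    possible_request, request_params = client.get_request(prompt_str)
    # Model params are folded into the request so cache keys are unique per model;
    # possible_request only executes when the cache misses.
    request_params.update(client.get_model_params())
    return cache.get(request_params, overwrite_cache=False, compute=possible_request)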
@@ -30,6 +30,18 @@ class DummyClient(Client):
         """Close the client."""
         pass

+    def get_model_params(self) -> Dict:
+        """
+        Get model params.
+
+        By getting model params from the server, we can add to request
+        and make sure cache keys are unique to model.
+
+        Returns:
+            model params.
+        """
+        return {"engine": "dummy"}
+
     def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
         """
         Get request string function.
@@ -49,6 +49,8 @@ class OpenAIClient(Client):
         self.temperature = client_args.pop("temperature", 0.0)
         self.max_tokens = client_args.pop("max_tokens", 10)
         self.top_p = client_args.pop("top_p", 1.0)
+        self.logprobs = client_args.pop("logprobs", None)
+        self.best_of = client_args.pop("best_of", 1)
         self.frequency_penalty = client_args.pop("frequency_penalty", 0.0)
         self.presence_penalty = client_args.pop("presence_penalty", 0.0)
         self.n = client_args.pop("n", 1)
@@ -57,6 +59,18 @@ class OpenAIClient(Client):
         """Close the client."""
         pass

+    def get_model_params(self) -> Dict:
+        """
+        Get model params.
+
+        By getting model params from the server, we can add to request
+        and make sure cache keys are unique to model.
+
+        Returns:
+            model params.
+        """
+        return {"model_name": "openai", "engine": self.engine}
+
     def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
         """
         Get request string function.
@@ -77,6 +91,8 @@ class OpenAIClient(Client):
             "frequency_penalty": kwargs.get(
                 "frequency_penalty", self.frequency_penalty
             ),
+            "logprobs": kwargs.get("logprobs", self.logprobs),
+            "best_of": kwargs.get("best_of", self.best_of),
             "presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
             "n": kwargs.get("n", self.n),
         }
@@ -34,6 +34,18 @@ class OPTClient(Client):
         """Close the client."""
         pass

+    def get_model_params(self) -> Dict:
+        """
+        Get model params.
+
+        By getting model params from the server, we can add to request
+        and make sure cache keys are unique to model.
+
+        Returns:
+            model params.
+        """
+        return {"model_name": "opt"}
+
     def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
         """
         Get request string function.
@@ -4,8 +4,10 @@ from typing import Any, Iterable, List, Optional, Union

 from tqdm.auto import tqdm

+from manifest.caches.noop import NoopCache
 from manifest.caches.redis import RedisCache
 from manifest.caches.sqlite import SQLiteCache
+from manifest.clients.ai21 import AI21Client
 from manifest.clients.dummy import DummyClient
 from manifest.clients.huggingface import HuggingFaceClient
 from manifest.clients.openai import OpenAIClient
@@ -18,6 +20,7 @@ logger = logging.getLogger(__name__)

 CLIENT_CONSTRUCTORS = {
     "openai": OpenAIClient,
+    "ai21": AI21Client,
     "huggingface": HuggingFaceClient,
     "opt": OPTClient,
     "dummy": DummyClient,
@@ -26,8 +29,17 @@ CLIENT_CONSTRUCTORS = {
 CACHE_CONSTRUCTORS = {
     "redis": RedisCache,
     "sqlite": SQLiteCache,
+    "noop": NoopCache,
 }

+try:
+    from manifest.clients.crfm import CRFMClient
+
+    CLIENT_CONSTRUCTORS["crfm"] = CRFMClient
+except ImportError:
+    # TODO: remove this when CRFM is public
+    pass
+

 class Manifest:
     """Manifest session object."""
@@ -36,8 +48,8 @@ class Manifest:
         self,
         client_name: str = "openai",
         client_connection: Optional[str] = None,
-        cache_name: str = "redis",
-        cache_connection: str = "localhost:6379",
+        cache_name: str = "noop",
+        cache_connection: Optional[str] = None,
         stop_token: str = "",
         **kwargs: Any,
     ):
@@ -65,11 +77,13 @@ class Manifest:
                 f"Choices are {list(CACHE_CONSTRUCTORS.keys())}"
             )
         self.client_name = client_name
-        # Must pass kwargs as dict to client "pop" methods removed used arguments
+        # Must pass kwargs as dict for client "pop" methods removed used arguments
         self.client = CLIENT_CONSTRUCTORS[client_name](  # type: ignore
             client_connection, client_args=kwargs
         )
-        self.cache = CACHE_CONSTRUCTORS[cache_name](cache_connection, cache_args=kwargs)
+        self.cache = CACHE_CONSTRUCTORS[cache_name](  # type: ignore
+            cache_connection, cache_args=kwargs
+        )
         if len(kwargs) > 0:
             raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")

@@ -82,7 +96,7 @@ class Manifest:

     def run(
         self,
-        prompt: Prompt,
+        prompt: Union[Prompt, str],
         input: Optional[Any] = None,
         overwrite_cache: bool = False,
         stop_token: Optional[str] = None,
@@ -93,7 +107,7 @@ class Manifest:
         Run the prompt.

         Args:
-            prompt: prompt to run.
+            prompt: prompt to run. If string, will cast to prompt.
             input: input to prompt.
             overwrite_cache: whether to overwrite cache.
             stop_token: stop token for prompt generation.
@@ -103,6 +117,8 @@ class Manifest:
         Returns:
             response from prompt.
         """
+        if isinstance(prompt, str):
+            prompt = Prompt(prompt)
         stop_token = stop_token if stop_token is not None else self.stop_token
         prompt_str = prompt(input)
         possible_request, full_kwargs = self.client.get_request(prompt_str, **kwargs)
@@ -143,6 +159,11 @@ class Manifest:
         Returns:
             batch of responses.
         """
+        if isinstance(prompt, str):
+            raise ValueError(
+                "Prompt must be a Prompt object for batch run on data. "
+                "We only support strings in `manifest.run`."
+            )
         if input is None:
             input = [None]
         return [
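Net effect of the manifest.py changes above, as a small sketch (assuming the dummy client, which needs no API key): a session now defaults to the no-op cache instead of Redis, and run() accepts a plain string that is cast to a Prompt:

# Sketch of the new defaults: no Redis required, string prompts allowed.
from manifest import Manifest

manifest = Manifest(client_name="dummy")  # cache_name now defaults to "noop"
response = manifest.run("hello world")    # str is wrapped in a Prompt internally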
@@ -16,7 +16,8 @@ class Response:
             not isinstance(self._response["choices"], list)
         ):
             raise ValueError(
-                "Response must be serialized to a dict with a list of choices"
+                "Response must be serialized to a dict with a list of choices. "
+                f"Response is {self._response}."
             )
         if len(self._response["choices"]) > 0:
             if "text" not in self._response["choices"][0]:
1862 poetry.lock generated
File diff suppressed because it is too large
@@ -22,8 +22,8 @@ openai = "^0.18.1"
 redis = "^4.3.1"
 dill = "^0.3.5"
 Flask = "^2.1.2"
-transformers = "^4.19.2"
-torch = "^1.11.0"
+#transformers = "^4.19.2"
+#torch = "^1.11.0"
 requests = "^2.27.1"
 tqdm = "^4.64.0"
 types-redis = "^4.2.6"
@@ -3,6 +3,7 @@ import pytest
 from redis import Redis
 from sqlitedict import SqliteDict

+from manifest.caches.noop import NoopCache
 from manifest.caches.redis import RedisCache
 from manifest.caches.sqlite import SQLiteCache

@@ -69,3 +70,33 @@ def test_get(sqlite_cache, redis_cache, cache_type):
     assert response.get_response() == "hello"
     assert not response.is_cached()
     assert response.get_request() == test_request
+
+
+def test_noop_cache():
+    """Test cache that is a no-op cache."""
+    cache = NoopCache(None)
+    cache.set_key("test", "valueA")
+    cache.set_key("testA", "valueB")
+    assert cache.get_key("test") is None
+    assert cache.get_key("testA") is None
+
+    cache.set_key("testA", "valueC")
+    assert cache.get_key("testA") is None
+
+    assert cache.get_key("test", table="prompt") is None
+    cache.set_key("test", "valueA", table="prompt")
+    assert cache.get_key("test", table="prompt") is None
+
+    # Assert always not cached
+    test_request = {"test": "hello", "testA": "world"}
+    compute = lambda: {"choices": [{"text": "hello"}]}
+
+    response = cache.get(test_request, overwrite_cache=False, compute=compute)
+    assert response.get_response() == "hello"
+    assert not response.is_cached()
+    assert response.get_request() == test_request
+
+    response = cache.get(test_request, overwrite_cache=False, compute=compute)
+    assert response.get_response() == "hello"
+    assert not response.is_cached()
+    assert response.get_request() == test_request
@@ -11,9 +11,9 @@ def test_init():
     assert str(exc_info.value) == "Response must be str or dict"
     with pytest.raises(ValueError) as exc_info:
         response = Response({"test": "hello"}, False, {})
-    assert (
-        str(exc_info.value)
-        == "Response must be serialized to a dict with a list of choices"
+    assert str(exc_info.value) == (
+        "Response must be serialized to a dict with a list of choices. "
+        "Response is {'test': 'hello'}."
     )
     with pytest.raises(ValueError) as exc_info:
         response = Response({"choices": [{"blah": "hello"}]}, False, {})