Mirror of https://github.com/HazyResearch/manifest
Commit 7766d3d6c2
@@ -0,0 +1,47 @@
"""Noop cache."""
from typing import Any, Dict, Union

from manifest.caches import Cache


class NoopCache(Cache):
    """A no-op cache that caches nothing for request/response pairs."""

    def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
        """
        Connect to client.

        Args:
            connection_str: connection string.
            cache_args: cache arguments.
        """
        pass

    def close(self) -> None:
        """Close the client."""
        pass

    def get_key(self, key: str, table: str = "default") -> Union[str, None]:
        """
        Return None, as no key is ever in the cache.

        Args:
            key: key for cache.
            table: table to get key in.
        """
        return None

    def set_key(self, key: str, value: str, table: str = "default") -> None:
        """
        Do not set anything, as there is no cache.

        Args:
            key: key for cache.
            value: new value for key.
            table: table to set key in.
        """
        pass

    def commit(self) -> None:
        """Commit any results."""
        pass
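Because get_key always returns None, a NoopCache forces every request to miss and hit the underlying client. A minimal usage sketch follows; the module path and the way the Cache base class forwards the connection string are assumptions, not confirmed by this diff.

# Hypothetical sketch; assumes manifest.caches.noop exposes NoopCache and that
# Cache.__init__ forwards (connection_str, cache_args) to connect().
from manifest.caches.noop import NoopCache

cache = NoopCache("noop", cache_args={})
cache.set_key("my prompt", '{"text": "my response"}')  # silently dropped
assert cache.get_key("my prompt") is None               # always a cache miss
cache.commit()                                          # nothing to persist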
@@ -0,0 +1,103 @@
"""AI21 client."""
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple

import requests

from manifest.clients.client import Client

logger = logging.getLogger(__name__)

AI21_ENGINES = {
    "j1-jumbo",
    "j1-grande",
    "j1-large",
}


class AI21Client(Client):
    """Client for the AI21 completion API."""

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Dict[str, Any] = {},
    ) -> None:
        """
        Connect to the AI21 server.

        connection_str is used as the AI21_API_KEY if the environment
        variable is not set.

        Args:
            connection_str: connection string.
            client_args: client arguments.
        """
        # Taken from https://studio.ai21.com/docs/api/
        self.host = "https://api.ai21.com/studio/v1"
        self.api_key = os.environ.get("AI21_API_KEY", connection_str)
        if self.api_key is None:
            raise ValueError(
                "AI21 API key not set. Set AI21_API_KEY environment "
                "variable or pass through `connection_str`."
            )
        self.engine = client_args.pop("engine", "j1-large")
        if self.engine not in AI21_ENGINES:
            raise ValueError(f"Invalid engine {self.engine}. Must be {AI21_ENGINES}.")
        self.temperature = client_args.pop("temperature", 0.0)
        self.max_tokens = client_args.pop("max_tokens", 10)
        self.top_k_return = client_args.pop("topKReturn", 1.0)
        self.num_results = client_args.pop("numResults", 1)
        self.top_p = client_args.pop("topP", 1.0)

    def close(self) -> None:
        """Close the client."""
        pass

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add them to the request
        and make sure cache keys are unique to the model.

        Returns:
            model params.
        """
        return {"model_name": "ai21", "engine": self.engine}

    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
        """
        Get request string function.

        Args:
            query: query string.

        Returns:
            request function that takes no input.
            request parameters as dict.
        """
        request_params = {
            "engine": kwargs.get("engine", self.engine),
            "prompt": query,
            "temperature": kwargs.get("temperature", self.temperature),
            "maxTokens": kwargs.get("maxTokens", self.max_tokens),
            "topKReturn": kwargs.get("topKReturn", self.top_k_return),
            "numResults": kwargs.get("numResults", self.num_results),
            "topP": kwargs.get("topP", self.top_p),
        }

        def _run_completion() -> Dict:
            post_str = self.host + "/" + self.engine + "/complete"
            # Log the endpoint for debugging rather than printing it, and never
            # echo the API key.
            logger.debug("POST %s", post_str)
            res = requests.post(
                post_str,
                headers={"Authorization": f"Bearer {self.api_key}"},
                json=request_params,
            )
            return res.json()

        return _run_completion, request_params
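get_request returns a zero-argument callable plus the exact parameter dict that can key the cache, so the HTTP POST only happens when the caller decides to execute it. A hedged usage sketch follows; "my-ai21-key" is a placeholder, and it assumes the Client base class forwards its constructor arguments to connect(), which this diff does not show.

# Hypothetical sketch; constructor behavior of the Client base class is assumed.
client = AI21Client("my-ai21-key", {"engine": "j1-large", "max_tokens": 16})

run, params = client.get_request("Once upon a time", temperature=0.7)
print(params["maxTokens"], params["temperature"])  # the dict used for cache keys
response = run()                                   # POST to .../j1-large/complete
print(response)
client.close()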
@@ -0,0 +1,138 @@
"""CRFM client."""
import logging
import os
import sys
from typing import Any, Callable, Dict, Optional, Tuple

from manifest.clients.client import Client

crfm_code_dir = os.environ.get("CRFM_CODE_DIR", "/home/code/benchmarking")
sys.path.append(crfm_code_dir)

from src.common.authentication import Authentication  # type: ignore
from src.common.request import Request, RequestResult  # type: ignore
from src.proxy.remote_service import RemoteService  # type: ignore

logger = logging.getLogger(__name__)

CRFM_ENGINES = {
    "ai21/j1-jumbo",
    "ai21/j1-grande",
    "ai21/j1-large",
}


class CRFMClient(Client):
    """Client for the Stanford CRFM model API."""

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Dict[str, Any] = {},
    ) -> None:
        """
        Connect to the CRFM endpoint.

        connection_str is used as the CRFM_API_KEY if the environment
        variable is not set.

        Args:
            connection_str: connection string.
            client_args: client arguments.
        """
        self.service = RemoteService("https://crfm-models.stanford.edu")
        api_key = os.environ.get("CRFM_API_KEY", connection_str)
        if api_key is None:
            raise ValueError(
                "CRFM API key not set. Set CRFM_API_KEY environment "
                "variable or pass through `connection_str`."
            )
        self.auth = Authentication(api_key=api_key)
        self.engine = client_args.pop("engine", "ai21/j1-large")
        if self.engine not in CRFM_ENGINES:
            raise ValueError(f"Invalid engine {self.engine}. Must be {CRFM_ENGINES}.")
        self.temperature = client_args.pop("temperature", 0.0)
        self.max_tokens = client_args.pop("max_tokens", 10)
        self.top_k_per_token = client_args.pop("top_k_per_token", 1)
        self.num_completions = client_args.pop("num_completions", 1)
        self.stop_sequences = client_args.pop("stop_sequences", [])
        self.top_p = client_args.pop("top_p", 1.0)
        self.presence_penalty = client_args.pop("presence_penalty", 1.0)
        self.frequency_penalty = client_args.pop("frequency_penalty", 1.0)

    def close(self) -> None:
        """Close the client."""
        pass

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add them to the request
        and make sure cache keys are unique to the model.

        Returns:
            model params.
        """
        return {"model_name": "crfm", "engine": self.engine}

    def format_response(self, response: RequestResult) -> Dict[str, Any]:
        """
        Format RequestResult to dict.

        Args:
            response: RequestResult

        Returns:
            response as dict
        """
        return {
            "object": "text_completion",
            "model": self.engine,
            "choices": [
                {
                    "text": text.text,
                    # TODO: Add in more metadata for HF models
                    # "logprobs": {
                    #     "tokens": result["tokens"],
                    #     "token_logprobs": result["token_scores"],
                    #     "text_offset": result["text_offset"],
                    #     "top_logprobs": result["top_logprobs"],
                    #     "finish_reason": "length",
                    # },
                }
                for text in response.completions
            ],
        }

    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
        """
        Get request string function.

        Args:
            query: query string.

        Returns:
            request function that takes no input.
            request parameters as dict.
        """
        request_params = {
            "model": kwargs.get("engine", self.engine),
            "prompt": query,
            "temperature": kwargs.get("temperature", self.temperature),
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "top_k_per_token": kwargs.get("top_k_per_token", self.top_k_per_token),
            "num_completions": kwargs.get("num_completions", self.num_completions),
            "stop_sequences": kwargs.get("stop_sequences", self.stop_sequences),
            "top_p": kwargs.get("top_p", self.top_p),
            "presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
            "frequency_penalty": kwargs.get(
                "frequency_penalty", self.frequency_penalty
            ),
        }

        def _run_completion() -> Dict:
            request = Request(**request_params)
            request_result = self.service.make_request(self.auth, request)
            return self.format_response(request_result)

        return _run_completion, request_params
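format_response reshapes the external RequestResult into an OpenAI-style completion dict, so downstream code can treat CRFM responses like any other backend. The sketch below uses stand-in objects, since the real RequestResult and completion types come from the external benchmarking package that is not shown in this diff (and must be importable for the module itself to load).

# Hypothetical sketch: SimpleNamespace objects mimic RequestResult just enough
# to show the OpenAI-style dict that format_response produces.
from types import SimpleNamespace

fake_result = SimpleNamespace(
    completions=[
        SimpleNamespace(text=" once upon a time"),
        SimpleNamespace(text=" in a land far away"),
    ]
)

client = CRFMClient.__new__(CRFMClient)  # skip connect(); only engine is needed here
client.engine = "ai21/j1-large"
print(client.format_response(fake_result))
# {"object": "text_completion", "model": "ai21/j1-large",
#  "choices": [{"text": " once upon a time"}, {"text": " in a land far away"}]}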
File diff suppressed because it is too large