mirror of https://github.com/HazyResearch/manifest
Zoo (#23)
* Zoo Model * Remove optional import zoo * Read model name from zoo model * Logprobs passed through raw response for gold choices Co-authored-by: Simran <emailsimran@gmail.com> Co-authored-by: Dan Fu <danfu@cs.stanford.edu>laurel/helm
parent
e0a76d1f93
commit
5428afdc58
@ -0,0 +1,94 @@
|
|||||||
|
"""Zoo model."""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
from manifest.api.models.model import Model
|
||||||
|
|
||||||
|
ZOO_PATH = os.environ.get("ZOO_PATH", None)
|
||||||
|
if not ZOO_PATH:
|
||||||
|
raise ImportError("ZOO_PATH environment variable not set.")
|
||||||
|
sys.path.append(ZOO_PATH)
|
||||||
|
|
||||||
|
from src.models.s4_seq import S4LMManifest # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class ZooModel(Model):
    """Zoo model.

    Wraps an S4 language model loaded from the external zoo checkout
    (located via the ``ZOO_PATH`` environment variable) behind the
    Manifest ``Model`` interface.
    """

    def __init__(
        self,
        model_name_or_path: str,
        model_config: str,
        cache_dir: str,
        device: int,
        use_accelerate: bool,
        use_parallelize: bool,
        perc_max_gpu_mem_red: float,
        use_fp16: bool,
    ):
        """
        Initialize model.

        All arguments will be passed in the request from Manifest.
        Only ``model_name_or_path`` and ``model_config`` are used here;
        the remaining arguments exist to satisfy the shared ``Model``
        constructor interface.

        Args:
            model_name_or_path: model name string.
            model_config: model config path.
            cache_dir: cache directory for model.
            device: device to use for model.
            use_accelerate: whether to use accelerate for multi-gpu inference.
            use_parallelize: use HF default parallelize.
            perc_max_gpu_mem_red: percent max memory reduction in accelerate.
            use_fp16: use fp16 for model weights.

        Raises:
            ValueError: if no model config is provided.
        """
        # Check if providing path
        self.model_path = model_name_or_path
        self.model_config = model_config
        if not self.model_config:
            raise ValueError("Must provide model config.")
        self.model = S4LMManifest(
            config_path=self.model_config,
            weights_path=self.model_path,
        )
        # Can only load this after the model has been initialized
        self.model_name = self.model.get_model_name()

    def get_init_params(self) -> Dict:
        """Return init params to determine what model is being used."""
        return {
            "model_name": self.model_name,
            "model_path": self.model_path,
            "model_config": self.model_config,
        }

    def generate(self, prompt: str, **kwargs: Any) -> List[str]:
        """
        Generate the prompt from model.

        Outputs must be generated text, not including prompt.

        Args:
            prompt: prompt to generate from.

        Returns:
            list of generated text (list of length 1 for 1 generation).
        """
        # Generation is delegated entirely to the underlying zoo model.
        # (A leftover debug ``print(prompt)`` was removed here.)
        return self.model.generate(prompt, **kwargs)

    def logits_scoring(
        self, prompt: str, gold_choices: List[str], **kwargs: Any
    ) -> Tuple[str, float]:
        """
        Given the prompt and gold choices, choose the best choice with max logits.

        Args:
            prompt: prompt to generate from.
            gold_choices: list of choices to choose from.

        Returns:
            the returned gold choice and the score.

        Raises:
            NotImplementedError: always; zoo models do not support
                logit scoring.
        """
        raise NotImplementedError()
|
@ -0,0 +1,102 @@
|
|||||||
|
"""Zoo client."""
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from manifest.clients.client import Client
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# User param -> (client param, default value)
|
||||||
|
ZOO_PARAMS: Dict[str, Tuple[str, str]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class ZooClient(Client):
    """Zoo client.

    Talks to a running zoo model server over HTTP. The server host is
    supplied via the connection string in :meth:`connect`.
    """

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Connect to the model.

        Args:
            connection_str: connection string (base URL of the zoo server).
            client_args: client arguments.

        Raises:
            ValueError: if no connection string is provided.
        """
        # The original code called ``rstrip`` on a possibly-None default;
        # raise a clear error instead of an AttributeError.
        if not connection_str:
            raise ValueError("Must provide connection string for zoo client.")
        self.host = connection_str.rstrip("/")
        # Avoid the shared mutable-default pitfall: ``pop`` below mutates
        # the dict, which must never be a module-level default object.
        if client_args is None:
            client_args = {}
        for key in ZOO_PARAMS:
            setattr(self, key, client_args.pop(key, ZOO_PARAMS[key][1]))
        self.model_params = self.get_model_params()

    def close(self) -> None:
        """Close the client."""
        pass

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params.
        """
        res = requests.post(self.host + "/params")
        return res.json()

    def get_model_inputs(self) -> List:
        """
        Get allowable model inputs.

        Returns:
            model inputs.
        """
        return list(ZOO_PARAMS.keys())

    def get_request(
        self, query: str, request_args: Optional[Dict[str, Any]] = None
    ) -> Tuple[Callable[[], Dict], Dict]:
        """
        Get request string function.

        Args:
            query: query string.
            request_args: extra parameters forwarded verbatim to the
                server; consumed (popped) from the caller's dict.

        Returns:
            request function that takes no input.
            request parameters as dict.
        """
        if request_args is None:
            request_args = {}
        request_params = {"prompt": query}
        # Zoo is greedy and takes all params
        # TODO: Once zoo is finalized, fix this
        for key in list(request_args.keys()):
            request_params[key] = request_args.pop(key, None)
        request_params.update(self.model_params)

        def _run_completion() -> Dict:
            post_str = self.host + "/completions"
            res = requests.post(post_str, json=request_params)
            return res.json()

        return _run_completion, request_params

    def get_choice_logit_request(
        self,
        query: str,
        gold_choices: List[str],
        request_args: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Callable[[], Dict], Dict]:
        """
        Get request string function for choosing max choices.

        Args:
            query: query string.
            gold_choices: choices for model to choose from via max logits.
            request_args: additional request parameters (unused).

        Returns:
            request function that takes no input.
            request parameters as dict.

        Raises:
            NotImplementedError: always; zoo does not support this request.
        """
        raise NotImplementedError("Zoo does not support choice logit request.")
|
Loading…
Reference in New Issue