mirror of
https://github.com/HazyResearch/manifest
synced 2024-10-31 15:20:26 +00:00
feat: helm support added
This commit is contained in:
parent
56eae406ce
commit
934a0bd5cd
@ -6,6 +6,7 @@ Added
|
||||
* Standard request base model for all language inputs.
|
||||
* ChatGPT client. Requires CHATGPT_SESSION_KEY to be passed in.
|
||||
* Diffusion model support
|
||||
* HELM support
|
||||
|
||||
Fixed
|
||||
^^^^^^^^
|
||||
|
12
README.md
12
README.md
@ -22,17 +22,27 @@ pip install manifest-ml[chatgpt]
|
||||
```
|
||||
This installs [pyChatGPT](https://github.com/terry3041/pyChatGPT) and uses the ChatGPT session key to start a session. This key must be set as the `CHATGPT_SESSION_KEY` environment variable or passed in with `client_connection`.
|
||||
|
||||
Install with [HELM](https://crfm-helm.readthedocs.io/en/latest/) Support:
|
||||
```bash
|
||||
pip install manifest-ml[helm]
|
||||
```
|
||||
This requires a HELM api key.
|
||||
|
||||
Install with HuggingFace API Support:
|
||||
```bash
|
||||
pip install manifest-ml[api]
|
||||
```
|
||||
|
||||
Dev Install:
|
||||
Manual Install:
|
||||
```bash
|
||||
git clone git@github.com:HazyResearch/manifest.git
|
||||
cd manifest
|
||||
make dev
|
||||
```
|
||||
or
|
||||
```
|
||||
pip install .[all]
|
||||
```
|
||||
|
||||
# Getting Started
|
||||
Running is simple to get started. If using OpenAI, set `export OPENAI_API_KEY=<OPENAIKEY>` (or pass key in through variable `client_connection`) then run
|
||||
|
79
examples/manifest_helm.ipynb
Normal file
79
examples/manifest_helm.ipynb
Normal file
@ -0,0 +1,79 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from manifest import Manifest\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# ChatGPT tries hard not to give people programmatic access.\n",
|
||||
"# As a warning, this will open a browser window.\n",
|
||||
"# You need to install xvfb and chromium for linux headless mode to work\n",
|
||||
"# See https://github.com/terry3041/pyChatGPT\n",
|
||||
"\n",
|
||||
"# The responses are not fast\n",
|
||||
"manifest = Manifest(\n",
|
||||
" client_name=\"chatgpt\",\n",
|
||||
" client_connection=os.environ.get(\"CHATGPT_SESSION_KEY\"),\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sure! Pickling is a way to save things, like food or toys, so that they can be used later. Imagine you have a toy that you really like, but you have to go to school and can't play with it. You can put the toy in a special jar and close the lid tight to keep it safe until you get home. That's kind of like pickling. You're taking something that you want to save, and putting it in a special container so it won't go bad or get lost. Just like the toy in the jar, pickled food can last a long time without going bad.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(manifest.run(\"Can you explain the pickling process to a four-year old?\"))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "bootleg",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:36:06) \n[Clang 11.1.0 ]"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "7a3f97ab0465937066e9b79893b779dfc8a12d73c41f9d98a7bf05133c798250"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
"""Client class."""
|
||||
"""ChatGPT client."""
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
173
manifest/clients/helm.py
Normal file
173
manifest/clients/helm.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""HELM client."""
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
from helm.common.authentication import Authentication
|
||||
from helm.common.request import Request as HELMRequest
|
||||
from helm.proxy.services.remote_service import RemoteService
|
||||
|
||||
from manifest.clients.client import Client
|
||||
from manifest.request import LMRequest, Request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HELM_ENGINES = {
|
||||
"ai21/j1-jumbo" "ai21/j1-grande",
|
||||
"ai21/j1-grande-v2-beta",
|
||||
"ai21/j1-large",
|
||||
"AlephAlpha/luminous-base",
|
||||
"AlephAlpha/luminous-extended",
|
||||
"AlephAlpha/luminous-supreme",
|
||||
"anthropic/stanford-online-all-v4-s3",
|
||||
"together/bloom",
|
||||
"together/t0pp",
|
||||
"cohere/xlarge-20220609",
|
||||
"cohere/xlarge-20221108",
|
||||
"cohere/large-20220720",
|
||||
"cohere/medium-20220720",
|
||||
"cohere/medium-20221108",
|
||||
"cohere/small-20220720",
|
||||
"together/gpt-j-6b",
|
||||
"together/gpt-neox-20b",
|
||||
"gooseai/gpt-neo-20b",
|
||||
"gooseai/gpt-j-6b",
|
||||
"huggingface/gpt-j-6b",
|
||||
"together/t5-11b",
|
||||
"together/ul2",
|
||||
"huggingface/gpt2",
|
||||
"openai/davinci",
|
||||
"openai/curie",
|
||||
"openai/babbage",
|
||||
"openai/ada",
|
||||
"openai/text-davinci-003",
|
||||
"openai/text-davinci-002",
|
||||
"openai/text-davinci-001",
|
||||
"openai/text-curie-001",
|
||||
"openai/text-babbage-001",
|
||||
"openai/text-ada-001",
|
||||
"openai/code-davinci-002",
|
||||
"openai/code-davinci-001",
|
||||
"openai/code-cushman-001",
|
||||
"openai/chat-gpt",
|
||||
"openai/text-similarity-davinci-001",
|
||||
"openai/text-similarity-curie-001",
|
||||
"openai/text-similarity-babbage-001",
|
||||
"openai/text-similarity-ada-001",
|
||||
"together/opt-175b",
|
||||
"together/opt-66b",
|
||||
"microsoft/TNLGv2_530B",
|
||||
"microsoft/TNLGv2_7B",
|
||||
"together/Together-gpt-JT-6B-v1",
|
||||
"together/glm",
|
||||
"together/yalm",
|
||||
}
|
||||
|
||||
|
||||
class HELMClient(Client):
|
||||
"""HELM client."""
|
||||
|
||||
# User param -> (client param, default value)
|
||||
PARAMS = {
|
||||
"engine": ("model", "openai/text-davinci-002"),
|
||||
"temperature": ("temperature", 1.0),
|
||||
"max_tokens": ("max_tokens", 10),
|
||||
"n": ("num_completions", 1),
|
||||
"top_p": ("top_p", 1.0),
|
||||
"top_k": ("top_k_per_token", 1),
|
||||
"stop_sequences": ("stop_sequences", None), # HELM doesn't like empty lists
|
||||
"presence_penalty": ("presence_penalty", 0.0),
|
||||
"frequency_penalty": ("frequency_penalty", 0.0),
|
||||
"client_timeout": ("client_timeout", 60), # seconds
|
||||
}
|
||||
REQUEST_CLS = LMRequest
|
||||
|
||||
def connect(
|
||||
self,
|
||||
connection_str: Optional[str] = None,
|
||||
client_args: Dict[str, Any] = {},
|
||||
) -> None:
|
||||
"""
|
||||
Create a HELM instance.
|
||||
|
||||
connection_str is passed as default HELM_API_KEY if variable not set.
|
||||
|
||||
Args:
|
||||
connection_str: connection string.
|
||||
client_args: client arguments.
|
||||
"""
|
||||
self.api_key = os.environ.get("HELM_API_KEY", connection_str)
|
||||
if self.api_key is None:
|
||||
raise ValueError(
|
||||
"HELM API key not set. Set HELM_API_KEY environment "
|
||||
"variable or pass through `client_connection`."
|
||||
)
|
||||
self._helm_auth = Authentication(api_key=self.api_key)
|
||||
self._help_api = RemoteService("https://crfm-models.stanford.edu")
|
||||
self._helm_account = self._help_api.get_account(self._helm_auth)
|
||||
for key in self.PARAMS:
|
||||
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
||||
if getattr(self, "engine") not in HELM_ENGINES:
|
||||
raise ValueError(
|
||||
f"Invalid engine {getattr(self, 'engine')}. Must be {HELM_ENGINES}."
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the client."""
|
||||
self._help_api = None
|
||||
|
||||
def get_generation_url(self) -> str:
|
||||
"""Get generation URL."""
|
||||
return ""
|
||||
|
||||
def get_generation_header(self) -> Dict[str, str]:
|
||||
"""
|
||||
Get generation header.
|
||||
|
||||
Returns:
|
||||
header.
|
||||
"""
|
||||
return ""
|
||||
|
||||
def supports_batch_inference(self) -> bool:
|
||||
"""Return whether the client supports batch inference."""
|
||||
return False
|
||||
|
||||
def get_model_params(self) -> Dict:
|
||||
"""
|
||||
Get model params.
|
||||
|
||||
By getting model params from the server, we can add to request
|
||||
and make sure cache keys are unique to model.
|
||||
|
||||
Returns:
|
||||
model params.
|
||||
"""
|
||||
return {"model_name": "HELM", "engine": getattr(self, "engine")}
|
||||
|
||||
def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]:
|
||||
"""
|
||||
Get request string function.
|
||||
|
||||
Args:
|
||||
request: request.
|
||||
|
||||
Returns:
|
||||
request function that takes no input.
|
||||
request parameters as dict.
|
||||
"""
|
||||
if isinstance(request.prompt, list):
|
||||
raise ValueError("HELM does not support batch inference.")
|
||||
|
||||
request_params = request.to_dict(self.PARAMS)
|
||||
|
||||
def _run_completion() -> Dict:
|
||||
try:
|
||||
request = HELMRequest(**request_params)
|
||||
request_result = self._help_api.make_request(self._helm_auth, request)
|
||||
except Exception as e:
|
||||
logger.error(f"HELM error {e}.")
|
||||
raise e
|
||||
return self.format_response(request_result.__dict__())
|
||||
|
||||
return _run_completion, request_params
|
Loading…
Reference in New Issue
Block a user