feat: add HELM support

lorr1 2023-01-09 09:58:11 -08:00
parent 56eae406ce
commit 934a0bd5cd
6 changed files with 268 additions and 2 deletions


@@ -6,6 +6,7 @@ Added
* Standard request base model for all language inputs.
* ChatGPT client. Requires CHATGPT_SESSION_KEY to be passed in.
* Diffusion model support
* HELM support
Fixed
^^^^^^^^


@@ -22,17 +22,27 @@ pip install manifest-ml[chatgpt]
```
This installs [pyChatGPT](https://github.com/terry3041/pyChatGPT) and uses the ChatGPT session key to start a session. This key must be set as the `CHATGPT_SESSION_KEY` environment variable or passed in with `client_connection`.
Install with [HELM](https://crfm-helm.readthedocs.io/en/latest/) Support:
```bash
pip install manifest-ml[helm]
```
This requires a HELM API key.
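As a quick usage sketch (assuming the HELM client is registered under the client name `helm`, mirroring the other clients):
```python
import os

from manifest import Manifest

# The key can also be read from the HELM_API_KEY environment variable
# instead of being passed through client_connection.
manifest = Manifest(
    client_name="helm",
    client_connection=os.environ.get("HELM_API_KEY"),
)
print(manifest.run("Why is the sky blue?"))
```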
Install with HuggingFace API Support:
```bash
pip install manifest-ml[api]
```
Dev Install:
Manual Install:
```bash
git clone git@github.com:HazyResearch/manifest.git
cd manifest
make dev
```
or
```
pip install .[all]
```
# Getting Started
Getting started is simple. If using OpenAI, set `export OPENAI_API_KEY=<OPENAIKEY>` (or pass the key in through the variable `client_connection`), then run
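A minimal sketch of such a run, assuming the `openai` client name used by the OpenAI client:
```python
from manifest import Manifest

# Assumes OPENAI_API_KEY is set in the environment.
manifest = Manifest(client_name="openai")
print(manifest.run("Why is the grass green?"))
```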


@@ -0,0 +1,79 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from manifest import Manifest\n",
    "import os\n",
    "\n",
    "# ChatGPT tries hard not to give people programmatic access.\n",
    "# As a warning, this will open a browser window.\n",
    "# You need to install xvfb and chromium for linux headless mode to work\n",
    "# See https://github.com/terry3041/pyChatGPT\n",
    "\n",
    "# The responses are not fast\n",
    "manifest = Manifest(\n",
    "    client_name=\"chatgpt\",\n",
    "    client_connection=os.environ.get(\"CHATGPT_SESSION_KEY\"),\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sure! Pickling is a way to save things, like food or toys, so that they can be used later. Imagine you have a toy that you really like, but you have to go to school and can't play with it. You can put the toy in a special jar and close the lid tight to keep it safe until you get home. That's kind of like pickling. You're taking something that you want to save, and putting it in a special container so it won't go bad or get lost. Just like the toy in the jar, pickled food can last a long time without going bad.\n"
     ]
    }
   ],
   "source": [
    "print(manifest.run(\"Can you explain the pickling process to a four-year old?\"))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bootleg",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:36:06) \n[Clang 11.1.0 ]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "7a3f97ab0465937066e9b79893b779dfc8a12d73c41f9d98a7bf05133c798250"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}


@@ -1,4 +1,4 @@
"""Client class."""
"""ChatGPT client."""
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple

manifest/clients/helm.py Normal file

@@ -0,0 +1,173 @@
"""HELM client."""
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple
from helm.common.authentication import Authentication
from helm.common.request import Request as HELMRequest
from helm.proxy.services.remote_service import RemoteService
from manifest.clients.client import Client
from manifest.request import LMRequest, Request
logger = logging.getLogger(__name__)
HELM_ENGINES = {
    "ai21/j1-jumbo",
    "ai21/j1-grande",
    "ai21/j1-grande-v2-beta",
    "ai21/j1-large",
    "AlephAlpha/luminous-base",
    "AlephAlpha/luminous-extended",
    "AlephAlpha/luminous-supreme",
    "anthropic/stanford-online-all-v4-s3",
    "together/bloom",
    "together/t0pp",
    "cohere/xlarge-20220609",
    "cohere/xlarge-20221108",
    "cohere/large-20220720",
    "cohere/medium-20220720",
    "cohere/medium-20221108",
    "cohere/small-20220720",
    "together/gpt-j-6b",
    "together/gpt-neox-20b",
    "gooseai/gpt-neo-20b",
    "gooseai/gpt-j-6b",
    "huggingface/gpt-j-6b",
    "together/t5-11b",
    "together/ul2",
    "huggingface/gpt2",
    "openai/davinci",
    "openai/curie",
    "openai/babbage",
    "openai/ada",
    "openai/text-davinci-003",
    "openai/text-davinci-002",
    "openai/text-davinci-001",
    "openai/text-curie-001",
    "openai/text-babbage-001",
    "openai/text-ada-001",
    "openai/code-davinci-002",
    "openai/code-davinci-001",
    "openai/code-cushman-001",
    "openai/chat-gpt",
    "openai/text-similarity-davinci-001",
    "openai/text-similarity-curie-001",
    "openai/text-similarity-babbage-001",
    "openai/text-similarity-ada-001",
    "together/opt-175b",
    "together/opt-66b",
    "microsoft/TNLGv2_530B",
    "microsoft/TNLGv2_7B",
    "together/Together-gpt-JT-6B-v1",
    "together/glm",
    "together/yalm",
}

class HELMClient(Client):
    """HELM client."""

    # User param -> (client param, default value)
    PARAMS = {
        "engine": ("model", "openai/text-davinci-002"),
        "temperature": ("temperature", 1.0),
        "max_tokens": ("max_tokens", 10),
        "n": ("num_completions", 1),
        "top_p": ("top_p", 1.0),
        "top_k": ("top_k_per_token", 1),
        "stop_sequences": ("stop_sequences", None),  # HELM doesn't like empty lists
        "presence_penalty": ("presence_penalty", 0.0),
        "frequency_penalty": ("frequency_penalty", 0.0),
        "client_timeout": ("client_timeout", 60),  # seconds
    }
    REQUEST_CLS = LMRequest

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Dict[str, Any] = {},
    ) -> None:
        """
        Create a HELM instance.

        connection_str is used as the HELM_API_KEY if the environment
        variable is not set.

        Args:
            connection_str: connection string.
            client_args: client arguments.
        """
        self.api_key = os.environ.get("HELM_API_KEY", connection_str)
        if self.api_key is None:
            raise ValueError(
                "HELM API key not set. Set HELM_API_KEY environment "
                "variable or pass through `client_connection`."
            )
        self._helm_auth = Authentication(api_key=self.api_key)
        self._helm_api = RemoteService("https://crfm-models.stanford.edu")
        self._helm_account = self._helm_api.get_account(self._helm_auth)
        for key in self.PARAMS:
            setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
        if getattr(self, "engine") not in HELM_ENGINES:
            raise ValueError(
                f"Invalid engine {getattr(self, 'engine')}. Must be {HELM_ENGINES}."
            )

    def close(self) -> None:
        """Close the client."""
        self._helm_api = None

    def get_generation_url(self) -> str:
        """Get generation URL."""
        return ""

    def get_generation_header(self) -> Dict[str, str]:
        """
        Get generation header.

        Returns:
            header.
        """
        return {}

    def supports_batch_inference(self) -> bool:
        """Return whether the client supports batch inference."""
        return False

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params.
        """
        return {"model_name": "HELM", "engine": getattr(self, "engine")}

    def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]:
        """
        Get request string function.

        Args:
            request: request.

        Returns:
            request function that takes no input.
            request parameters as dict.
        """
        if isinstance(request.prompt, list):
            raise ValueError("HELM does not support batch inference.")
        request_params = request.to_dict(self.PARAMS)

        def _run_completion() -> Dict:
            try:
                helm_request = HELMRequest(**request_params)
                request_result = self._helm_api.make_request(
                    self._helm_auth, helm_request
                )
            except Exception as e:
                logger.error(f"HELM error {e}.")
                raise e
            # RequestResult is a dataclass; __dict__ is an attribute, not a callable.
            return self.format_response(request_result.__dict__)

        return _run_completion, request_params
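
For reference, a hedged standalone sketch of the raw HELM calls this client wraps (assumes `HELM_API_KEY` is set and `crfm-helm` is installed; the `completions[0].text` access follows the HELM `RequestResult` dataclass):
```python
import os

from helm.common.authentication import Authentication
from helm.common.request import Request
from helm.proxy.services.remote_service import RemoteService

# Authenticate against the hosted CRFM proxy, the same endpoint HELMClient uses.
auth = Authentication(api_key=os.environ["HELM_API_KEY"])
service = RemoteService("https://crfm-models.stanford.edu")

# Single-prompt request; HELM offers no batch inference, matching the client.
request = Request(
    model="openai/text-davinci-002", prompt="Hello, world.", max_tokens=10
)
result = service.make_request(auth, request)
print(result.completions[0].text)
```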


@@ -46,6 +46,9 @@ EXTRAS = {
"chatgpt": [
"pyChatGPT>=0.4.3",
],
"helm": [
"crfm-helm>=0.1.0",
],
"dev": [
"autopep8>=1.6.0",
"black>=22.3.0",