mirror of
https://github.com/HazyResearch/manifest
synced 2024-10-31 15:20:26 +00:00
feat: chatgpt client added (#47)
This commit is contained in:
parent
defc63bf36
commit
56eae406ce
1
.gitignore
vendored
1
.gitignore
vendored
@ -35,7 +35,6 @@ wheels/
|
|||||||
*.egg-info/
|
*.egg-info/
|
||||||
.installed.cfg
|
.installed.cfg
|
||||||
*.egg
|
*.egg
|
||||||
MANIFEST
|
|
||||||
|
|
||||||
# PyInstaller
|
# PyInstaller
|
||||||
# Usually these files are written by a python script from a template
|
# Usually these files are written by a python script from a template
|
||||||
|
@ -4,6 +4,8 @@ Added
|
|||||||
^^^^^
|
^^^^^
|
||||||
* Batched inference support in `manifest.run`. No more separate `manifest.run_batch` method.
|
* Batched inference support in `manifest.run`. No more separate `manifest.run_batch` method.
|
||||||
* Standard request base model for all language inputs.
|
* Standard request base model for all language inputs.
|
||||||
|
* ChatGPT client. Requires CHATGPT_SESSION_KEY to be passed in.
|
||||||
|
* Diffusion model support
|
||||||
|
|
||||||
Fixed
|
Fixed
|
||||||
^^^^^^^^
|
^^^^^^^^
|
||||||
|
@ -16,6 +16,12 @@ Install:
|
|||||||
pip install manifest-ml
|
pip install manifest-ml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Install with ChatGPT Support:
|
||||||
|
```bash
|
||||||
|
pip install manifest-ml[chatgpt]
|
||||||
|
```
|
||||||
|
This installs [pyChatGPT](https://github.com/terry3041/pyChatGPT) and uses the ChatGPT session key to start a session. This key must be set as the `CHATGPT_SESSION_KEY` environment variable or passed in with `client_connection`.
|
||||||
|
|
||||||
Install with HuggingFace API Support:
|
Install with HuggingFace API Support:
|
||||||
```bash
|
```bash
|
||||||
pip install manifest-ml[api]
|
pip install manifest-ml[api]
|
||||||
|
@ -428,7 +428,7 @@
|
|||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "manifest",
|
"display_name": "bootleg",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
@ -442,11 +442,11 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.8.15"
|
"version": "3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:36:06) \n[Clang 11.1.0 ]"
|
||||||
},
|
},
|
||||||
"vscode": {
|
"vscode": {
|
||||||
"interpreter": {
|
"interpreter": {
|
||||||
"hash": "0d67557b0e03f6eb64c46b70fb42ce7c6498a7305f9f3922c351822f3fc8e363"
|
"hash": "7a3f97ab0465937066e9b79893b779dfc8a12d73c41f9d98a7bf05133c798250"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
79
examples/manifest_chatgpt.ipynb
Normal file
79
examples/manifest_chatgpt.ipynb
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%load_ext autoreload\n",
|
||||||
|
"%autoreload 2"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from manifest import Manifest\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"# ChatGPT tries hard not to give people programmatic access.\n",
|
||||||
|
"# As a warning, this will open a browser window.\n",
|
||||||
|
"# You need to install xvfb and chromium for linux headless mode to work\n",
|
||||||
|
"# See https://github.com/terry3041/pyChatGPT\n",
|
||||||
|
"\n",
|
||||||
|
"# The responses are not fast\n",
|
||||||
|
"manifest = Manifest(\n",
|
||||||
|
" client_name=\"chatgpt\",\n",
|
||||||
|
" client_connection=os.environ.get(\"CHATGPT_SESSION_KEY\"),\n",
|
||||||
|
")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Sure! Pickling is a way to save things, like food or toys, so that they can be used later. Imagine you have a toy that you really like, but you have to go to school and can't play with it. You can put the toy in a special jar and close the lid tight to keep it safe until you get home. That's kind of like pickling. You're taking something that you want to save, and putting it in a special container so it won't go bad or get lost. Just like the toy in the jar, pickled food can last a long time without going bad.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(manifest.run(\"Can you explain the pickling process to a four-year old?\"))\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "bootleg",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.8.12"
|
||||||
|
},
|
||||||
|
"orig_nbformat": 4,
|
||||||
|
"vscode": {
|
||||||
|
"interpreter": {
|
||||||
|
"hash": "7a3f97ab0465937066e9b79893b779dfc8a12d73c41f9d98a7bf05133c798250"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
130
manifest/clients/chatgpt.py
Normal file
130
manifest/clients/chatgpt.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
"""Client class."""
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from pyChatGPT import ChatGPT
|
||||||
|
|
||||||
|
from manifest.clients.client import Client
|
||||||
|
from manifest.request import LMRequest, Request
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGPTClient(Client):
|
||||||
|
"""ChatGPT Client class."""
|
||||||
|
|
||||||
|
# No params for ChatGPT
|
||||||
|
PARAMS = {}
|
||||||
|
REQUEST_CLS = LMRequest
|
||||||
|
|
||||||
|
def connect(
|
||||||
|
self, connection_str: Optional[str], client_args: Dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Connect to ChatGPT.
|
||||||
|
|
||||||
|
We use https://github.com/terry3041/pyChatGPT.
|
||||||
|
|
||||||
|
Arsg:
|
||||||
|
connection_str: connection string.
|
||||||
|
client_args: client arguments.
|
||||||
|
"""
|
||||||
|
self.session_key = os.environ.get("CHATGPT_SESSION_KEY", connection_str)
|
||||||
|
if self.session_key is None:
|
||||||
|
raise ValueError(
|
||||||
|
"ChatGPT session key not set. Set CHATGPT_SESSION_KEY environment "
|
||||||
|
"variable or pass through `client_connection`. "
|
||||||
|
"For details, see https://github.com/terry3041/pyChatGPT "
|
||||||
|
"and go through instructions for getting a session key."
|
||||||
|
)
|
||||||
|
self.host = None
|
||||||
|
for key in self.PARAMS:
|
||||||
|
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
|
||||||
|
self._chat_session = ChatGPT(self.session_key, verbose=False)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the client."""
|
||||||
|
self._chat_session = None
|
||||||
|
|
||||||
|
def clear_conversations(self) -> None:
|
||||||
|
"""Clear conversations.
|
||||||
|
|
||||||
|
Only works for ChatGPT.
|
||||||
|
"""
|
||||||
|
self._chat_session.clear_conversations()
|
||||||
|
|
||||||
|
def get_generation_url(self) -> str:
|
||||||
|
"""Get generation URL."""
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def get_generation_header(self) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Get generation header.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
header.
|
||||||
|
"""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def supports_batch_inference(self) -> bool:
|
||||||
|
"""Return whether the client supports batch inference."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_model_params(self) -> Dict:
|
||||||
|
"""
|
||||||
|
Get model params.
|
||||||
|
|
||||||
|
By getting model params from the server, we can add to request
|
||||||
|
and make sure cache keys are unique to model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
model params.
|
||||||
|
"""
|
||||||
|
return {"model_name": "chatgpt", "engine": "chatgpt"}
|
||||||
|
|
||||||
|
def format_response(self, response: Dict) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Format response to dict.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: response
|
||||||
|
|
||||||
|
Return:
|
||||||
|
response as dict
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"model": "chatgpt",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"text": response["message"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_request(self, request: Request) -> Tuple[Callable[[], Dict], Dict]:
|
||||||
|
"""
|
||||||
|
Get request string function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
request function that takes no input.
|
||||||
|
request parameters as dict.
|
||||||
|
"""
|
||||||
|
if isinstance(request.prompt, list):
|
||||||
|
raise ValueError("ChatGPT does not support batch inference.")
|
||||||
|
|
||||||
|
prompt = str(request.prompt)
|
||||||
|
request_params = request.to_dict(self.PARAMS)
|
||||||
|
|
||||||
|
def _run_completion() -> Dict:
|
||||||
|
try:
|
||||||
|
res = self._chat_session.send_message(prompt)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"ChatGPT error {e}.")
|
||||||
|
raise e
|
||||||
|
return self.format_response(res)
|
||||||
|
|
||||||
|
return _run_completion, request_params
|
@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
|
|||||||
# Engines are dynamically instantiated from API
|
# Engines are dynamically instantiated from API
|
||||||
# but a few example engines are listed below.
|
# but a few example engines are listed below.
|
||||||
TOMA_ENGINES = {
|
TOMA_ENGINES = {
|
||||||
"StableDiffusion",
|
# "StableDiffusion",
|
||||||
"Together-gpt-JT-6B-v1",
|
"Together-gpt-JT-6B-v1",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ from manifest.caches.noop import NoopCache
|
|||||||
from manifest.caches.redis import RedisCache
|
from manifest.caches.redis import RedisCache
|
||||||
from manifest.caches.sqlite import SQLiteCache
|
from manifest.caches.sqlite import SQLiteCache
|
||||||
from manifest.clients.ai21 import AI21Client
|
from manifest.clients.ai21 import AI21Client
|
||||||
|
from manifest.clients.chatgpt import ChatGPTClient
|
||||||
from manifest.clients.cohere import CohereClient
|
from manifest.clients.cohere import CohereClient
|
||||||
from manifest.clients.diffuser import DiffuserClient
|
from manifest.clients.diffuser import DiffuserClient
|
||||||
from manifest.clients.dummy import DummyClient
|
from manifest.clients.dummy import DummyClient
|
||||||
@ -22,6 +23,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
CLIENT_CONSTRUCTORS = {
|
CLIENT_CONSTRUCTORS = {
|
||||||
"openai": OpenAIClient,
|
"openai": OpenAIClient,
|
||||||
|
"chatgpt": ChatGPTClient,
|
||||||
"cohere": CohereClient,
|
"cohere": CohereClient,
|
||||||
"ai21": AI21Client,
|
"ai21": AI21Client,
|
||||||
"huggingface": HuggingFaceClient,
|
"huggingface": HuggingFaceClient,
|
||||||
|
@ -15,7 +15,8 @@ module = [
|
|||||||
"accelerate.utils.modeling",
|
"accelerate.utils.modeling",
|
||||||
"transformers",
|
"transformers",
|
||||||
"flask",
|
"flask",
|
||||||
"torch"
|
"torch",
|
||||||
|
"pyChatGPT",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
|
9
setup.py
9
setup.py
@ -34,13 +34,18 @@ EXTRAS = {
|
|||||||
"api": [
|
"api": [
|
||||||
"diffusers>=0.6.0",
|
"diffusers>=0.6.0",
|
||||||
"Flask>=2.1.2",
|
"Flask>=2.1.2",
|
||||||
"fastapi>=0.70.0",
|
|
||||||
"uvicorn>=0.18.0",
|
|
||||||
"accelerate>=0.10.0",
|
"accelerate>=0.10.0",
|
||||||
"transformers>=4.20.0",
|
"transformers>=4.20.0",
|
||||||
"torch>=1.8.0",
|
"torch>=1.8.0",
|
||||||
"numpy>=1.20.0",
|
"numpy>=1.20.0",
|
||||||
],
|
],
|
||||||
|
"app": [
|
||||||
|
"fastapi>=0.70.0",
|
||||||
|
"uvicorn>=0.18.0",
|
||||||
|
],
|
||||||
|
"chatgpt": [
|
||||||
|
"pyChatGPT>=0.4.3",
|
||||||
|
],
|
||||||
"dev": [
|
"dev": [
|
||||||
"autopep8>=1.6.0",
|
"autopep8>=1.6.0",
|
||||||
"black>=22.3.0",
|
"black>=22.3.0",
|
||||||
|
Loading…
Reference in New Issue
Block a user