Laurel/more models (#98)

* fix: google models

* fix: azure models and refactor
pull/99/head
Laurel Orr 1 year ago committed by GitHub
parent 4903c7e7e8
commit b52a4d9a4b

@ -1,5 +1,14 @@
0.1.8 - Unreleased
---------------------
Added
^^^^^
* Azure model support (completion and chat)
* Google Vertex API model support (completion and chat)
Fixed
^^^^^
* `run` with batches now behaves the same as async run, except it is not async. Requests are batched into appropriately sized batches.
* Refactored the client so that preprocessing of requests and postprocessing of responses are unified, to better support model variants in request/response formats.
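
A minimal sketch of the new batched `run` behavior (illustrative only: the key and prompts are placeholders, and the client is configured the same way as in the notebooks below):

from manifest import Manifest
from manifest.connections.client_pool import ClientConnection

# Placeholder key; any supported client works the same way.
openai = ClientConnection(
    client_name="openai",
    client_connection="sk-XXX",
    engine="text-davinci-003",
)
manifest = Manifest(client_pool=[openai])

# Passing a list of prompts batches the requests into appropriately sized
# batches, mirroring the async run path but executed synchronously.
prompts = ["What is 1+1?", "What is 2+2?"]
print(manifest.run(prompts, max_tokens=10))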
0.1.7 - 2023-05-17
---------------------

@ -0,0 +1,149 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"AZURE_KEY = \"API_KEY::URL\"\n",
"OPENAI_KEY = \"sk-XXX\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use Azure and OpenAI models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from manifest import Manifest\n",
"from manifest.connections.client_pool import ClientConnection\n",
"from pathlib import Path\n",
"\n",
"cache_path = Path(\"manifest.db\")\n",
"if cache_path.exists():\n",
" cache_path.unlink()\n",
"\n",
"\n",
"azure = ClientConnection(\n",
" client_name=\"azureopenai\",\n",
" client_connection=AZURE_KEY,\n",
" engine=\"text-davinci-003\",\n",
")\n",
"\n",
"manifest = Manifest(client_pool=[azure], \n",
" cache_name=\"sqlite\",\n",
" cache_connection=\"manifest.db\"\n",
")\n",
"\n",
"\n",
"openai = ClientConnection(\n",
" client_name=\"openai\",\n",
" client_connection=OPENAI_KEY,\n",
" engine=\"text-davinci-003\",\n",
")\n",
"\n",
"manifest_openai_nocache = Manifest(client_pool=[openai])\n",
"\n",
"manifest_openai = Manifest(client_pool=[openai], \n",
" cache_name=\"sqlite\",\n",
" cache_connection=\"manifest.db\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Show caches are the same\n",
"text = \"What is the meaning of life?\"\n",
"res = manifest.run(text, max_tokens=100, temperature=0.7, return_response=True)\n",
"print(res.get_response())\n",
"print(res.is_cached())\n",
"res2 = manifest_openai.run(text, max_tokens=100, temperature=0.7, return_response=True)\n",
"print(res2.is_cached())\n",
"\n",
"assert res2.get_response() == res.get_response()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"azure_chat = ClientConnection(\n",
" client_name=\"azureopenaichat\",\n",
" client_connection=AZURE_KEY,\n",
" engine=\"gpt-3.5-turbo\",\n",
")\n",
"\n",
"manifest = Manifest(client_pool=[azure_chat])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(manifest.run(\"What do you think is the best food?\", max_tokens=100))\n",
"\n",
"chat_dict = [\n",
" {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
" {\"role\": \"assistant\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
" {\"role\": \"user\", \"content\": \"Where was it played?\"}\n",
"]\n",
"print(manifest.run(chat_dict, max_tokens=100))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "manifest",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -62,6 +62,13 @@
"]\n",
"print(manifest.run(chat_dict, max_tokens=100))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

@ -0,0 +1,117 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"GOOGLE_KEY = \"KEY::PROJECT_ID\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use GoogleVertexAPI"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from manifest import Manifest\n",
"from manifest.connections.client_pool import ClientConnection\n",
"\n",
"google_bison = ClientConnection(\n",
" client_name=\"google\",\n",
" client_connection=GOOGLE_KEY\n",
")\n",
"\n",
"manifest = Manifest(client_pool=[google_bison])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Simple question\n",
"print(manifest.run(\"What is your name\", max_tokens=40))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from manifest import Manifest\n",
"from manifest.connections.client_pool import ClientConnection\n",
"\n",
"google_bison = ClientConnection(\n",
" client_name=\"googlechat\",\n",
" client_connection=GOOGLE_KEY\n",
")\n",
"\n",
"manifest = Manifest(client_pool=[google_bison])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chat_dict = [\n",
" # {\"author\": \"bot\", \"content\": \"You are a helpful assistant.\"},\n",
" {\"author\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
" {\"author\": \"bot\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
" {\"author\": \"user\", \"content\": \"Where was it played?\"}\n",
"]\n",
"print(manifest.run(chat_dict, max_tokens=8))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "manifest",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -47,7 +47,7 @@ class AI21Client(Client):
"""
# Taken from https://studio.ai21.com/docs/api/
self.host = "https://api.ai21.com/studio/v1"
self.api_key = os.environ.get("AI21_API_KEY", connection_str)
self.api_key = connection_str or os.environ.get("AI21_API_KEY")
if self.api_key is None:
raise ValueError(
"AI21 API key not set. Set AI21_API_KEY environment "
@ -94,7 +94,7 @@ class AI21Client(Client):
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.

@ -0,0 +1,111 @@
"""OpenAI client."""
import logging
import os
from typing import Any, Dict, Optional, Type
from manifest.clients.openai import OPENAI_ENGINES, OpenAIClient
from manifest.request import LMRequest, Request
logger = logging.getLogger(__name__)
# Azure deployment name can only use letters and numbers, no spaces. Hyphens ("-") and
# underscores ("_") may be used, except as ending characters. We create this mapping to
# handle the difference between Azure and OpenAI naming
AZURE_DEPLOYMENT_NAME_MAPPING = {
"gpt-3.5-turbo": "gpt-35-turbo",
"gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
}
OPENAI_DEPLOYMENT_NAME_MAPPING = {
"gpt-35-turbo": "gpt-3.5-turbo",
"gpt-35-turbo-0301": "gpt-3.5-turbo-0301",
}
class AzureClient(OpenAIClient):
"""Azure client."""
PARAMS = OpenAIClient.PARAMS
REQUEST_CLS: Type[Request] = LMRequest
NAME = "azureopenai"
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the AzureOpenAI server.
connection_str is used as the AZURE_OPENAI_KEY if the environment variable is not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
connection_parts = connection_str.split("::")
if len(connection_parts) == 1:
self.api_key = connection_parts[0]
self.host = None
elif len(connection_parts) == 2:
self.api_key, self.host = connection_parts
else:
raise ValueError(
"Invalid connection string. "
"Must be either AZURE_OPENAI_KEY or "
"AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.api_key = self.api_key or os.environ.get("AZURE_OPENAI_KEY")
if self.api_key is None:
raise ValueError(
"AzureOpenAI API key not set. Set AZURE_OPENAI_KEY environment "
"variable or pass through `client_connection`."
)
self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
if self.host is None:
raise ValueError(
"Azure Service URL not set "
"(e.g. https://openai-azure-service.openai.azure.com/)."
" Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`"
" as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT."
)
self.host = self.host.rstrip("/")
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in OPENAI_ENGINES:
raise ValueError(
f"Invalid engine {getattr(self, 'engine')}. Must be {OPENAI_ENGINES}."
)
def get_generation_url(self) -> str:
"""Get generation URL."""
engine = getattr(self, "engine")
deployment_name = AZURE_DEPLOYMENT_NAME_MAPPING.get(engine, engine)
return (
self.host
+ "/openai/deployments/"
+ deployment_name
+ "/completions?api-version=2023-05-15"
)
def get_generation_header(self) -> Dict[str, str]:
"""
Get generation header.
Returns:
header.
"""
return {"api-key": f"{self.api_key}"}
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add them to the request
and make sure cache keys are unique to the model.
Returns:
model params.
"""
# IMPORTANT!!!
# Azure models are the same as OpenAI models, so we want to unify their
# caches. Make sure we return the OpenAI name here.
return {"model_name": OpenAIClient.NAME, "engine": getattr(self, "engine")}

@ -0,0 +1,114 @@
"""OpenAI client."""
import logging
import os
from typing import Any, Dict, Optional
from manifest.clients.openai_chat import OPENAICHAT_ENGINES, OpenAIChatClient
from manifest.request import LMRequest
logger = logging.getLogger(__name__)
# Azure deployment name can only use letters and numbers, no spaces. Hyphens ("-") and
# underscores ("_") may be used, except as ending characters. We create this mapping to
# handle the difference between Azure and OpenAI naming
AZURE_DEPLOYMENT_NAME_MAPPING = {
"gpt-3.5-turbo": "gpt-35-turbo",
"gpt-3.5-turbo-0301": "gpt-35-turbo-0301",
}
OPENAI_DEPLOYMENT_NAME_MAPPING = {
"gpt-35-turbo": "gpt-3.5-turbo",
"gpt-35-turbo-0301": "gpt-3.5-turbo-0301",
}
class AzureChatClient(OpenAIChatClient):
"""Azure chat client."""
# User param -> (client param, default value)
PARAMS = OpenAIChatClient.PARAMS
REQUEST_CLS = LMRequest
NAME = "azureopenaichat"
IS_CHAT = True
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the AzureOpenAI server.
connection_str is used as the AZURE_OPENAI_KEY if the environment variable is not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
connection_parts = connection_str.split("::")
if len(connection_parts) == 1:
self.api_key = connection_parts[0]
self.host = None
elif len(connection_parts) == 2:
self.api_key, self.host = connection_parts
else:
raise ValueError(
"Invalid connection string. "
"Must be either AZURE_OPENAI_KEY or "
"AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.api_key = self.api_key or os.environ.get("AZURE_OPENAI_KEY")
if self.api_key is None:
raise ValueError(
"AzureOpenAI API key not set. Set AZURE_OPENAI_KEY environment "
"variable or pass through `client_connection`."
)
self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
if self.host is None:
raise ValueError(
"Azure Service URL not set "
"(e.g. https://openai-azure-service.openai.azure.com/)."
" Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`"
" as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT."
)
self.host = self.host.rstrip("/")
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in OPENAICHAT_ENGINES:
raise ValueError(
f"Invalid engine {getattr(self, 'engine')}. "
f"Must be {OPENAICHAT_ENGINES}."
)
def get_generation_url(self) -> str:
"""Get generation URL."""
engine = getattr(self, "engine")
deployment_name = AZURE_DEPLOYMENT_NAME_MAPPING.get(engine, engine)
return (
self.host
+ "/openai/deployments/"
+ deployment_name
+ "/chat/completions?api-version=2023-05-15"
)
def get_generation_header(self) -> Dict[str, str]:
"""
Get generation header.
Returns:
header.
"""
return {"api-key": f"{self.api_key}"}
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add them to the request
and make sure cache keys are unique to the model.
Returns:
model params.
"""
# IMPORTANT!!!
# Azure models are the same as OpenAI models, so we want to unify their
# caches. Make sure we return the OpenAI name here.
return {"model_name": OpenAIChatClient.NAME, "engine": getattr(self, "engine")}

@ -8,12 +8,14 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast
import aiohttp
import requests
import tqdm.asyncio
from tenacity import RetryCallState, retry, stop_after_attempt, wait_random_exponential
from manifest.request import (
DEFAULT_REQUEST_KEYS,
NOT_CACHE_KEYS,
LMChatRequest,
LMRequest,
LMScoreRequest,
Request,
)
@ -29,6 +31,14 @@ from manifest.response import (
logger = logging.getLogger(__name__)
ATTEMPTS_BEFORE_STOP = 20
ATTEMPTS_TIMEOUT = 120
# http_status is mainly for Azure and e.code is mainly for OpenAI usage
# e.http_status == 408 occurs when Azure times out
# e.code == 429 is a rate limit error
# e.code == 500 or 502 occurs on a server error
API_ERROR_CODE = {408, 429, 500, 502}
def retry_if_ratelimit(retry_base: RetryCallState) -> bool:
"""Return whether to retry if ratelimited."""
@ -38,13 +48,32 @@ def retry_if_ratelimit(retry_base: RetryCallState) -> bool:
requests.exceptions.HTTPError, retry_base.outcome.exception()
)
# 500 is a server error, 429 is a rate limit error
if exception.response.status_code in {429, 500}: # type: ignore
if exception.response.status_code in API_ERROR_CODE: # type: ignore
return True
except Exception:
pass
return False
def return_error_response(retry_state: RetryCallState) -> dict:
"""Return error response if all retries failed."""
request_params = retry_state.args[1]
number_of_prompts = (
len(request_params["prompt"])
if "prompt" in request_params
else len(request_params["messages"])
)
return {
"choices": [],
"usage": {
"total_tokens": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
},
"errors": [str(retry_state.outcome.exception())] * number_of_prompts,
}
class Client(ABC):
"""Client class."""
@ -121,6 +150,15 @@ class Client(ABC):
"""
raise NotImplementedError()
def get_tokenizer(self, model: str) -> Tuple[Any, int]:
"""Get tokenizer for model.
Returns:
tokenizer: tokenizer with encode and decode methods
max_length: max length of model
"""
return None, -1
def get_model_inputs(self) -> List:
"""
Get allowable model inputs.
@ -130,6 +168,51 @@ class Client(ABC):
"""
return list(self.PARAMS.keys())
def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
"""Split usage into list of usages for each prompt."""
# TODO: add this in using default tokenizer
return []
def preprocess_request_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""
Preprocess request params.
Args:
request: request params.
Returns:
request params.
"""
return request
def postprocess_response(
self, response: Dict[str, Any], request: Dict[str, Any]
) -> Dict[str, Any]:
"""
Postprocess and validate response as dict.
Args:
response: response
request: request
Return:
response as dict
"""
if "choices" not in response:
raise ValueError(f"Invalid response: {response}")
if "usage" in response:
# Handle splitting the usages for batch requests
if len(response["choices"]) == 1:
if isinstance(response["usage"], list):
response["usage"] = response["usage"][0]
response["usage"] = [response["usage"]]
else:
# Try to split usage
split_usage = self.split_usage(request, response["choices"])
if split_usage:
response["usage"] = split_usage
return response
def get_request(
self, prompt: Union[str, List[str]], request_args: Dict[str, Any]
) -> Request:
@ -155,7 +238,7 @@ class Client(ABC):
params[key] = request_args.pop(key)
return self.REQUEST_CLS(**params) # type: ignore
def get_request_params(self, request: Request) -> Dict[str, Any]:
def _get_request_params(self, request: Request) -> Dict[str, Any]:
"""Get request params.
Add default keys that we need for requests such as batch_size.
@ -174,7 +257,7 @@ class Client(ABC):
Skip keys that are not cache keys such as batch_size.
"""
request_params = self.get_request_params(request)
request_params = self._get_request_params(request)
for key in NOT_CACHE_KEYS:
request_params.pop(key, None)
# Make sure to add model params and request class
@ -182,49 +265,7 @@ class Client(ABC):
request_params["request_cls"] = request.__class__.__name__
return request_params
def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]:
"""Split usage into list of usages for each prompt."""
return []
def get_model_choices(self, response: Dict) -> ModelChoices:
"""Format response to ModelChoices."""
# Array or text response
response_type = RESPONSE_CONSTRUCTORS[self.REQUEST_CLS]["response_type"]
if response_type == "array":
choices: List[Union[LMModelChoice, ArrayModelChoice]] = [
ArrayModelChoice(**choice) for choice in response["choices"]
]
else:
choices = [LMModelChoice(**choice) for choice in response["choices"]]
return ModelChoices(choices=choices)
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Validate response as dict.
Args:
response: response
request: request
Return:
response as dict
"""
if "choices" not in response:
raise ValueError(f"Invalid response: {response}")
if "usage" in response:
# Handle splitting the usages for batch requests
if len(response["choices"]) == 1:
if isinstance(response["usage"], list):
response["usage"] = response["usage"][0]
response["usage"] = [response["usage"]]
else:
# Try to split usage
split_usage = self.split_usage(request, response["choices"])
if split_usage:
response["usage"] = split_usage
return response
def split_requests(
def _split_requests(
self, request_params: Dict[str, Any], batch_size: int, key: str = "prompt"
) -> List[Dict[str, Any]]:
"""Split request into batch_sized request.
@ -246,11 +287,75 @@ class Client(ABC):
request_params_list.append(params)
return request_params_list
def _get_model_choices(self, response: Dict) -> ModelChoices:
"""Format response to ModelChoices."""
# Array or text response
response_type = RESPONSE_CONSTRUCTORS[self.REQUEST_CLS]["response_type"]
if response_type == "array":
choices: List[Union[LMModelChoice, ArrayModelChoice]] = [
ArrayModelChoice(**choice) for choice in response["choices"]
]
else:
choices = [LMModelChoice(**choice) for choice in response["choices"]]
return ModelChoices(choices=choices)
def _stitch_responses(self, request: Request, responses: List[Dict]) -> Response:
"""Stitch responses together.
Useful for batch requests.
"""
choices = []
usages = []
for res_dict in responses:
choices.extend(res_dict["choices"])
if "usage" in res_dict:
usages.extend(res_dict["usage"])
final_response_dict = {"choices": choices}
final_usages = None
if usages:
final_usages = Usages(usages=[Usage(**usage) for usage in usages])
return Response(
self._get_model_choices(final_response_dict),
cached=False,
request=request,
usages=final_usages,
**RESPONSE_CONSTRUCTORS[self.REQUEST_CLS], # type: ignore
)
def _verify_request_lengths(
self, request: Dict[str, Any], model: str, max_tokens: int
) -> None:
"""Verify that the request length is not too long."""
encoder, max_length = self.get_tokenizer(model)
if not encoder or max_length < 0:
return
if isinstance(request["prompt"], str):
prompts = [request["prompt"]]
else:
prompts = request["prompt"]
for i in range(len(prompts)):
prompt = prompts[i]
encoded_prompt = encoder.encode(prompt)
if len(encoded_prompt) + max_tokens > max_length:
logger.warning(
f"Prompt {prompt} is too long for model {model}. "
"Truncating prompt from left."
)
# -20 to be safe
prompt = encoder.decode(
encoded_prompt[-int(max_length - max_tokens - 20) :]
)
prompts[i] = prompt
if isinstance(request["prompt"], str):
request["prompt"] = prompts[0]
else:
request["prompt"] = prompts
@retry(
reraise=True,
retry=retry_if_ratelimit,
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(10),
wait=wait_random_exponential(min=1, max=ATTEMPTS_TIMEOUT),
stop=stop_after_attempt(ATTEMPTS_BEFORE_STOP),
)
def _run_completion(
self, request_params: Dict[str, Any], retry_timeout: int
@ -264,6 +369,7 @@ class Client(ABC):
Returns:
response as dict.
"""
request_params = self.preprocess_request_params(request_params)
post_str = self.get_generation_url()
res = requests.post(
post_str,
@ -276,27 +382,27 @@ class Client(ABC):
except requests.exceptions.HTTPError:
logger.error(res.json())
raise requests.exceptions.HTTPError(res.json())
return self.validate_response(res.json(), request_params)
return self.postprocess_response(res.json(), request_params)
@retry(
reraise=True,
retry=retry_if_ratelimit,
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(10),
wait=wait_random_exponential(min=1, max=ATTEMPTS_TIMEOUT),
stop=stop_after_attempt(ATTEMPTS_BEFORE_STOP),
)
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Async execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
batch_size: batch size for requests.
Returns:
response as dict.
"""
request_params = self.preprocess_request_params(request_params)
post_str = self.get_generation_url()
async with aiohttp.ClientSession(timeout=retry_timeout) as session:
async with session.post(
@ -307,7 +413,7 @@ class Client(ABC):
) as res:
res.raise_for_status()
res_json = await res.json(content_type=None)
return self.validate_response(res_json, request_params)
return self.postprocess_response(res_json, request_params)
def run_request(self, request: Request) -> Response:
"""
@ -319,36 +425,61 @@ class Client(ABC):
Returns:
response.
"""
if isinstance(request.prompt, list) and not self.supports_batch_inference():
raise ValueError(
# Make everything list for consistency
if isinstance(request.prompt, list):
prompt_list = request.prompt
else:
prompt_list = [request.prompt]
request_params = self._get_request_params(request)
# Set the params as a list. Do not set the request
# object itself as the cache will then store it as a
# list which is inconsistent with the request input.
request_params["prompt"] = prompt_list
# If batch_size is not set, set it to 1
batch_size = request_params.pop("batch_size") or 1
if not self.supports_batch_inference() and batch_size != 1:
logger.warning(
f"{self.__class__.__name__} does not support batch inference."
" Setting batch size to 1"
)
batch_size = 1
request_params = self.get_request_params(request)
# Take the default keys we need and drop the rest as they
# are not part of the model request.
retry_timeout = request_params.pop("client_timeout")
for key in DEFAULT_REQUEST_KEYS:
request_params.pop(key, None)
response_dict = self._run_completion(request_params, retry_timeout)
usages = None
if "usage" in response_dict:
usages = [Usage(**usage) for usage in response_dict["usage"]]
return Response(
response=self.get_model_choices(response_dict),
cached=False,
request=request,
usages=Usages(usages=usages) if usages else None,
**RESPONSE_CONSTRUCTORS[self.REQUEST_CLS], # type: ignore
)
# Make sure requests are within the model's request length.
# If no tokenizer is set or this is not an LM request, this
# will do nothing.
if isinstance(request, LMRequest):
self._verify_request_lengths(
request_params, model=request.engine, max_tokens=request.max_tokens
)
# Batch requests
num_batches = len(prompt_list) // batch_size
if len(prompt_list) % batch_size != 0:
batch_size = int(math.ceil(len(prompt_list) / (num_batches + 1)))
request_batches = self._split_requests(request_params, batch_size)
response_dicts = [
self._run_completion(batch, retry_timeout) for batch in request_batches
]
# Flatten responses
return self._stitch_responses(request, response_dicts)
async def arun_batch_request(self, request: Request) -> Response:
async def arun_batch_request(
self, request: Request, verbose: bool = False
) -> Response:
"""
Run async request.
Args:
request: request.
request: request.
Returns:
response.
@ -361,7 +492,7 @@ class Client(ABC):
"request.prompt must be a list for async batch inference."
)
request_params = self.get_request_params(request)
request_params = self._get_request_params(request)
# Take the default keys we need and drop the rest as they
# are not part of the model request.
retry_timeout = request_params.pop("client_timeout")
@ -370,34 +501,27 @@ class Client(ABC):
for key in DEFAULT_REQUEST_KEYS:
request_params.pop(key, None)
# Make sure requests are within the model's request length.
# If no tokenizer is set or this is not an LM request, this
# will do nothing.
if isinstance(request, LMRequest):
self._verify_request_lengths(
request_params, model=request.engine, max_tokens=request.max_tokens
)
# Batch requests
num_batches = len(request.prompt) // batch_size
if len(request.prompt) % batch_size != 0:
batch_size = int(math.ceil(len(request.prompt) / (num_batches + 1)))
request_batches = self.split_requests(request_params, batch_size)
request_batches = self._split_requests(request_params, batch_size)
all_tasks = [
asyncio.create_task(self._arun_completion(batch, retry_timeout, batch_size))
asyncio.create_task(self._arun_completion(batch, retry_timeout))
for batch in request_batches
]
responses = await asyncio.gather(*all_tasks)
responses = await tqdm.asyncio.tqdm.gather(*all_tasks, disable=not verbose)
# Flatten responses
choices = []
usages = []
for res_dict in responses:
choices.extend(res_dict["choices"])
if "usage" in res_dict:
usages.extend(res_dict["usage"])
final_response_dict = {"choices": choices}
final_usages = None
if usages:
final_usages = Usages(usages=[Usage(**usage) for usage in usages])
return Response(
self.get_model_choices(final_response_dict),
cached=False,
request=request,
usages=final_usages,
**RESPONSE_CONSTRUCTORS[self.REQUEST_CLS], # type: ignore
)
return self._stitch_responses(request, responses)
def run_chat_request(
self,
@ -412,19 +536,27 @@ class Client(ABC):
Returns:
response.
"""
request_params = self.get_request_params(request)
request_params = self._get_request_params(request)
# Take the default keys we need and drop the rest as they
# are not part of the model request.
retry_timeout = request_params.pop("client_timeout")
for key in DEFAULT_REQUEST_KEYS:
request_params.pop(key, None)
# Make sure requests are within the model's request length.
# If no tokenizer is set or this is not an LM request, this
# will do nothing.
self._verify_request_lengths(
request_params, model=request.engine, max_tokens=request.max_tokens
)
response_dict = self._run_completion(request_params, retry_timeout)
usages = None
if "usage" in response_dict:
usages = [Usage(**usage) for usage in response_dict["usage"]]
return Response(
response=self.get_model_choices(response_dict),
response=self._get_model_choices(response_dict),
cached=False,
request=request,
usages=Usages(usages=usages) if usages else None,

@ -44,7 +44,7 @@ class CohereClient(Client):
connection_str: connection string.
client_args: client arguments.
"""
self.api_key = os.environ.get("COHERE_API_KEY", connection_str)
self.api_key = connection_str or os.environ.get("COHERE_API_KEY")
if self.api_key is None:
raise ValueError(
"Cohere API key not set. Set COHERE_API_KEY environment "
@ -93,7 +93,7 @@ class CohereClient(Client):
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.

@ -86,7 +86,7 @@ class DiffuserClient(Client):
res["client_name"] = self.NAME
return res
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.

@ -111,7 +111,9 @@ class DummyClient(Client):
request_type=self.REQUEST_CLS,
)
async def arun_batch_request(self, request: Request) -> Response:
async def arun_batch_request(
self, request: Request, verbose: bool = False
) -> Response:
"""
Get async request string function.

@ -0,0 +1,190 @@
"""OpenAI client."""
import logging
import os
import subprocess
from typing import Any, Dict, Optional, Type
from manifest.clients.client import Client
from manifest.request import LMRequest, Request
logger = logging.getLogger(__name__)
# https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart
GOOGLE_ENGINES = {
"text-bison",
}
def get_project_id() -> Optional[str]:
"""Get project ID.
Run
`gcloud config get-value project`
"""
try:
project_id = subprocess.run(
["gcloud", "config", "get-value", "project"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if project_id.stderr.decode("utf-8").strip():
return None
return project_id.stdout.decode("utf-8").strip()
except Exception:
return None
class GoogleClient(Client):
"""Google client."""
# User param -> (client param, default value)
PARAMS = {
"engine": ("model", "text-bison"),
"temperature": ("temperature", 1.0),
"max_tokens": ("maxOutputTokens", 10),
"top_p": ("topP", 1.0),
"top_k": ("topK", 1),
"batch_size": ("batch_size", 20),
}
REQUEST_CLS: Type[Request] = LMRequest
NAME = "google"
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the GoogleVertex API.
connection_str is used as the GOOGLE_API_KEY if the environment variable is not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
connection_parts = connection_str.split("::")
if len(connection_parts) == 1:
self.api_key = connection_parts[0]
self.project_id = None
elif len(connection_parts) == 2:
self.api_key, self.project_id = connection_parts
else:
raise ValueError(
"Invalid connection string. "
"Must be either API_KEY or API_KEY::PROJECT_ID"
)
self.api_key = self.api_key or os.environ.get("GOOGLE_API_KEY")
if self.api_key is None:
raise ValueError(
"GoogleVertex API key not set. Set GOOGLE_API_KEY environment "
"variable or pass through `client_connection`. This can be "
"found by running `gcloud auth print-access-token`"
)
self.project_id = (
self.project_id or os.environ.get("GOOGLE_PROJECT_ID") or get_project_id()
)
if self.project_id is None:
raise ValueError("GoogleVertex project ID not set. Set GOOGLE_PROJECT_ID")
self.host = f"https://us-central1-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/us-central1/publishers/google/models" # noqa: E501
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in GOOGLE_ENGINES:
raise ValueError(
f"Invalid engine {getattr(self, 'engine')}. Must be {GOOGLE_ENGINES}."
)
def close(self) -> None:
"""Close the client."""
pass
def get_generation_url(self) -> str:
"""Get generation URL."""
model = getattr(self, "engine")
return self.host + f"/{model}:predict"
def get_generation_header(self) -> Dict[str, str]:
"""
Get generation header.
Returns:
header.
"""
return {"Authorization": f"Bearer {self.api_key}"}
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
return True
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add them to the request
and make sure cache keys are unique to the model.
Returns:
model params.
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def preprocess_request_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""
Preprocess request params.
Args:
request: request params.
Returns:
request params.
"""
# Reformat the request params for Google
prompt = request.pop("prompt")
if isinstance(prompt, str):
prompt_list = [prompt]
else:
prompt_list = prompt
google_request = {
"instances": [{"prompt": prompt} for prompt in prompt_list],
"parameters": request,
}
return super().preprocess_request_params(google_request)
def postprocess_response(
self, response: Dict[str, Any], request: Dict[str, Any]
) -> Dict[str, Any]:
"""
Validate response as dict.
Assumes response is dict
{
"predictions": [
{
"safetyAttributes": {
"categories": ["Violent", "Sexual"],
"blocked": false,
"scores": [0.1, 0.1]
},
"content": "SELECT * FROM "WWW";"
}
]
}
Args:
response: response
request: request
Return:
response as dict
"""
google_predictions = response.pop("predictions")
new_response = {
"choices": [
{
"text": prediction["content"],
}
for prediction in google_predictions
]
}
return super().postprocess_response(new_response, request)

@ -0,0 +1,155 @@
"""OpenAI client."""
import copy
import logging
import os
from typing import Any, Dict, Optional, Type
from manifest.clients.google import GoogleClient, get_project_id
from manifest.request import LMRequest, Request
logger = logging.getLogger(__name__)
# https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart
GOOGLE_ENGINES = {
"chat-bison",
}
class GoogleChatClient(GoogleClient):
"""GoogleChat client."""
# User param -> (client param, default value)
PARAMS = {
"engine": ("model", "chat-bison"),
"temperature": ("temperature", 1.0),
"max_tokens": ("maxOutputTokens", 10),
"top_p": ("topP", 1.0),
"top_k": ("topK", 1),
"batch_size": ("batch_size", 20),
}
REQUEST_CLS: Type[Request] = LMRequest
NAME = "googlechat"
IS_CHAT = True
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the GoogleVertex API.
connection_str is used as the GOOGLE_API_KEY if the environment variable is not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
connection_parts = connection_str.split("::")
if len(connection_parts) == 1:
self.api_key = connection_parts[0]
self.project_id = None
elif len(connection_parts) == 2:
self.api_key, self.project_id = connection_parts
else:
raise ValueError(
"Invalid connection string. "
"Must be either API_KEY or API_KEY::PROJECT_ID"
)
self.api_key = self.api_key or os.environ.get("GOOGLE_API_KEY")
if self.api_key is None:
raise ValueError(
"GoogleVertex API key not set. Set GOOGLE_API_KEY environment "
"variable or pass through `client_connection`. This can be "
"found by running `gcloud auth print-access-token`"
)
self.project_id = (
self.project_id or os.environ.get("GOOGLE_PROJECT_ID") or get_project_id()
)
if self.project_id is None:
raise ValueError("GoogleVertex project ID not set. Set GOOGLE_PROJECT_ID")
self.host = f"https://us-central1-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/us-central1/publishers/google/models" # noqa: E501
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in GOOGLE_ENGINES:
raise ValueError(
f"Invalid engine {getattr(self, 'engine')}. Must be {GOOGLE_ENGINES}."
)
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
return False
def preprocess_request_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""
Preprocess request params.
Args:
request: request params.
Returns:
request params.
"""
# Format for chat model
request = copy.deepcopy(request)
prompt = request.pop("prompt")
if isinstance(prompt, str):
messages = [{"author": "user", "content": prompt}]
elif isinstance(prompt, list) and isinstance(prompt[0], str):
prompt_list = prompt
messages = [{"author": "user", "content": prompt} for prompt in prompt_list]
elif isinstance(prompt, list) and isinstance(prompt[0], dict):
for pmt_dict in prompt:
if "author" not in pmt_dict or "content" not in pmt_dict:
raise ValueError(
"Prompt must be list of dicts with 'author' and 'content' "
f"keys. Got {prompt}."
)
messages = prompt
else:
raise ValueError(
"Prompt must be string, list of strings, or list of dicts."
f"Got {prompt}"
)
new_request = {
"instances": [{"messages": messages}],
"parameters": request,
}
return super(GoogleClient, self).preprocess_request_params(new_request)
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Validate response as dict.
Assumes response is dict
{
"candidates": [
{
"safetyAttributes": {
"categories": ["Violent", "Sexual"],
"blocked": false,
"scores": [0.1, 0.1]
},
"author": "1",
"content": "SELECT * FROM "WWW";"
}
]
}
Args:
response: response
request: request
Return:
response as dict
"""
google_predictions = response.pop("predictions")
new_response = {
"choices": [
{
"text": prediction["candidates"][0]["content"],
}
for prediction in google_predictions
]
}
return super(GoogleClient, self).postprocess_response(new_response, request)

@ -94,7 +94,7 @@ class HuggingFaceClient(Client):
request function that takes no input.
request parameters as dict.
"""
request_params = self.get_request_params(request)
request_params = self._get_request_params(request)
retry_timeout = request_params.pop("client_timeout")
for key in DEFAULT_REQUEST_KEYS:
request_params.pop(key, None)

@ -72,7 +72,7 @@ class HuggingFaceEmbeddingClient(Client):
res["client_name"] = self.NAME
return res
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.

@ -60,7 +60,7 @@ class OpenAIClient(Client):
connection_str: connection string.
client_args: client arguments.
"""
self.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
self.api_key = connection_str or os.environ.get("OPENAI_API_KEY")
if self.api_key is None:
raise ValueError(
"OpenAI API key not set. Set OPENAI_API_KEY environment "
@ -107,7 +107,7 @@ class OpenAIClient(Client):
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Validate response as dict.
@ -118,7 +118,7 @@ class OpenAIClient(Client):
Return:
response as dict
"""
validated_response = super().validate_response(response, request)
validated_response = super().postprocess_response(response, request)
# Handle logprobs
for choice in validated_response["choices"]:
if "logprobs" in choice:

@ -26,6 +26,7 @@ class OpenAIChatClient(OpenAIClient):
"stop_sequences": ("stop", None), # OpenAI doesn't like empty lists
"presence_penalty": ("presence_penalty", 0.0),
"frequency_penalty": ("frequency_penalty", 0.0),
"batch_size": ("batch_size", 1),
}
REQUEST_CLS = LMRequest
NAME = "openaichat"
@ -45,7 +46,7 @@ class OpenAIChatClient(OpenAIClient):
connection_str: connection string.
client_args: client arguments.
"""
self.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
self.api_key = connection_str or os.environ.get("OPENAI_API_KEY")
if self.api_key is None:
raise ValueError(
"OpenAI API key not set. Set OPENAI_API_KEY environment "
@ -80,18 +81,19 @@ class OpenAIChatClient(OpenAIClient):
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def _format_request_for_chat(self, request_params: Dict[str, Any]) -> Dict:
"""Format request params for chat.
def preprocess_request_params(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""
Preprocess request params.
Args:
request_params: request params.
request: request params.
Returns:
formatted request params.
request params.
"""
# Format for chat model
request_params = copy.deepcopy(request_params)
prompt = request_params.pop("prompt")
request = copy.deepcopy(request)
prompt = request.pop("prompt")
if isinstance(prompt, str):
messages = [{"role": "user", "content": prompt}]
elif isinstance(prompt, list) and isinstance(prompt[0], str):
@ -110,62 +112,23 @@ class OpenAIChatClient(OpenAIClient):
"Prompt must be string, list of strings, or list of dicts."
f"Got {prompt}"
)
request_params["messages"] = messages
return request_params
request["messages"] = messages
return super().preprocess_request_params(request)
def _format_request_from_chat(self, response_dict: Dict[str, Any]) -> Dict:
"""Format response for standard response from chat.
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Postprocess and validate response as dict.
Args:
response_dict: response.
response: response
request: request
Return:
formatted response.
response as dict
"""
new_choices = []
response_dict = copy.deepcopy(response_dict)
for message in response_dict["choices"]:
response = copy.deepcopy(response)
for message in response["choices"]:
new_choices.append({"text": message["message"]["content"]})
response_dict["choices"] = new_choices
return response_dict
def _run_completion(
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
Returns:
response as dict.
"""
# Format for chat model
request_params = self._format_request_for_chat(request_params)
response_dict = super()._run_completion(request_params, retry_timeout)
# Reformat for text model
response_dict = self._format_request_from_chat(response_dict)
return response_dict
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
) -> Dict:
"""Async execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
batch_size: batch size for requests.
Returns:
response as dict.
"""
# Format for chat model
request_params = self._format_request_for_chat(request_params)
response_dict = await super()._arun_completion(
request_params, retry_timeout, batch_size
)
# Reformat for text model
response_dict = self._format_request_from_chat(response_dict)
return response_dict
response["choices"] = new_choices
return super().postprocess_response(response, request)

@ -41,7 +41,7 @@ class OpenAIEmbeddingClient(OpenAIClient):
connection_str: connection string.
client_args: client arguments.
"""
self.api_key = os.environ.get("OPENAI_API_KEY", connection_str)
self.api_key = connection_str or os.environ.get("OPENAI_API_KEY")
if self.api_key is None:
raise ValueError(
"OpenAI API key not set. Set OPENAI_API_KEY environment "
@ -76,7 +76,7 @@ class OpenAIEmbeddingClient(OpenAIClient):
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.
@ -157,23 +157,20 @@ class OpenAIEmbeddingClient(OpenAIClient):
return response_dict
async def _arun_completion(
self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
self, request_params: Dict[str, Any], retry_timeout: int
) -> Dict:
"""Async execute completion request.
Args:
request_params: request params.
retry_timeout: retry timeout.
batch_size: batch size for requests.
Returns:
response as dict.
"""
# Format for embedding model
request_params = self._format_request_for_embedding(request_params)
response_dict = await super()._arun_completion(
request_params, retry_timeout, batch_size
)
response_dict = await super()._arun_completion(request_params, retry_timeout)
# Reformat for text model
response_dict = self._format_request_from_embedding(response_dict)
return response_dict

@ -143,7 +143,7 @@ class TOMAClient(Client):
}
return heartbeats
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.

@ -46,7 +46,7 @@ class TOMADiffuserClient(TOMAClient):
"""
return {"model_name": self.NAME, "engine": getattr(self, "engine")}
def validate_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]:
"""
Format response to dict.

@ -6,9 +6,13 @@ from typing import Any, Dict, List, Optional, Type
from pydantic import BaseModel, Extra
from manifest.clients.ai21 import AI21Client
from manifest.clients.azureopenai import AzureClient
from manifest.clients.azureopenai_chat import AzureChatClient
from manifest.clients.client import Client
from manifest.clients.cohere import CohereClient
from manifest.clients.dummy import DummyClient
from manifest.clients.google import GoogleClient
from manifest.clients.google_chat import GoogleChatClient
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.huggingface_embedding import HuggingFaceEmbeddingClient
from manifest.clients.openai import OpenAIClient
@ -21,14 +25,18 @@ logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
CLIENT_CONSTRUCTORS = {
OpenAIClient.NAME: OpenAIClient,
OpenAIChatClient.NAME: OpenAIChatClient,
OpenAIEmbeddingClient.NAME: OpenAIEmbeddingClient,
CohereClient.NAME: CohereClient,
AI21Client.NAME: AI21Client,
AzureClient.NAME: AzureClient,
AzureChatClient.NAME: AzureChatClient,
CohereClient.NAME: CohereClient,
DummyClient.NAME: DummyClient,
GoogleClient.NAME: GoogleClient,
GoogleChatClient.NAME: GoogleChatClient,
HuggingFaceClient.NAME: HuggingFaceClient,
HuggingFaceEmbeddingClient.NAME: HuggingFaceEmbeddingClient,
DummyClient.NAME: DummyClient,
OpenAIClient.NAME: OpenAIClient,
OpenAIChatClient.NAME: OpenAIChatClient,
OpenAIEmbeddingClient.NAME: OpenAIEmbeddingClient,
TOMAClient.NAME: TOMAClient,
}

@ -474,6 +474,7 @@ class Manifest:
stop_token: Optional[str] = None,
return_response: bool = False,
chunk_size: int = -1,
verbose: bool = False,
**kwargs: Any,
) -> Union[List[str], List[np.ndarray], Response]:
"""
@ -500,6 +501,7 @@ class Manifest:
For a single manifest client, there is no impact to
setting chunk_size. For a client pool, chunk_size
can be used to distribute the load across the clients.
verbose: whether to print progress of async tasks.
Returns:
response from prompt.
@ -511,7 +513,7 @@ class Manifest:
if not isinstance(prompts[0], str):
raise ValueError("Prompts must be a list of strings.")
# Split the prompts into chunks
# Split the prompts into chunks for the connection pool
prompt_chunks: List[Tuple[Client, List[str]]] = []
if chunk_size > 0:
for i in range(0, len(prompts), chunk_size):
@ -530,11 +532,11 @@ class Manifest:
prompts=chunk,
client=client,
overwrite_cache=overwrite_cache,
verbose=verbose,
**kwargs,
)
)
)
print(f"Running {len(tasks)} tasks across all clients.")
logger.info(f"Running {len(tasks)} tasks across all clients.")
responses = await asyncio.gather(*tasks)
final_response = Response.union_all(responses)
@ -554,6 +556,7 @@ class Manifest:
prompts: List[str],
client: Client,
overwrite_cache: bool = False,
verbose: bool = False,
**kwargs: Any,
) -> Response:
"""
@ -563,6 +566,7 @@ class Manifest:
prompts: prompts to run.
client: client to run.
overwrite_cache: whether to overwrite cache.
verbose: whether to print progress of async tasks.
Returns:
response from prompt.
@ -580,7 +584,7 @@ class Manifest:
# If not None value or empty list - run new request
if request_params.prompt:
self.client_pool.start_timer()
response = await client.arun_batch_request(request_params)
response = await client.arun_batch_request(request_params, verbose=verbose)
self.client_pool.end_timer()
else:
# Nothing to run

@ -6,19 +6,21 @@ strict_optional = false
[[tool.mypy.overrides]]
ignore_missing_imports = true
module = [
"accelerate",
"accelerate.utils.modeling",
"deepspeed",
"numpy",
"diffusers",
"sentence_transformers",
"sqlitedict",
"sqlalchemy",
"dill",
"accelerate",
"accelerate.utils.modeling",
"transformers",
"flask",
"torch",
"numpy",
"pyChatGPT",
"torch",
"transformers",
"tqdm",
"tqdm.asyncio",
"sentence_transformers",
"sqlalchemy",
"sqlitedict",
]
[tool.isort]
