You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
manifest/manifest/manifest.py

759 lines
28 KiB
Python

"""Manifest class."""
import asyncio
import copy
import logging
from typing import (
Any,
Dict,
Generator,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import numpy as np
from manifest.caches.noop import NoopCache
from manifest.caches.postgres import PostgresCache
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.client import Client
from manifest.clients.huggingface import HuggingFaceClient
from manifest.connections.client_pool import (
CLIENT_CONSTRUCTORS,
ClientConnection,
ClientConnectionPool,
)
from manifest.request import LMChatRequest, LMScoreRequest, Request
from manifest.response import ModelChoices, Response, Usage, Usages
logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
CACHE_CONSTRUCTORS = {
"redis": RedisCache,
"sqlite": SQLiteCache,
"noop": NoopCache,
"postgres": PostgresCache,
}
class Manifest:
"""Manifest session object."""
def __init__(
self,
client_name: Optional[str] = None,
client_connection: Optional[str] = None,
client_pool: Optional[List[ClientConnection]] = None,
client_pool_schedule: str = "round_robin",
cache_name: str = "noop",
cache_connection: Optional[str] = None,
stop_token: str = "",
**kwargs: Any,
):
"""
Initialize manifest.
Args:
client_name: name of client.
client_connection: connection string for client.
client_pool: list of client connections for multi-client.
client_pool_schedule: schedule for client pool.
cache_name: name of cache.
cache_connection: connection string for cache.
stop_token: stop token prompt generation.
Can be overridden in run
Remaining kwargs sent to client and cache.
"""
if not client_name and not client_pool:
raise ValueError(
"Must specify client_name or client_pool. "
f"Choices are {list(CLIENT_CONSTRUCTORS.keys())}"
)
if client_name and client_pool:
raise ValueError("Cannot specify both client_name and client_pool")
if client_name:
client_pool = [
ClientConnection(
client_name=client_name,
client_connection=client_connection,
# Remove engine from kwargs
engine=kwargs.pop("engine", None),
)
]
self.client_pool = ClientConnectionPool(
client_pool, client_pool_schedule, client_args=kwargs
)
if cache_name not in CACHE_CONSTRUCTORS:
raise ValueError(
f"Unknown cache name: {cache_name}. "
f"Choices are {list(CACHE_CONSTRUCTORS.keys())}"
)
# Must pass kwargs as dict for client "pop" methods removed used arguments
self.cache = CACHE_CONSTRUCTORS[cache_name]( # type: ignore
cache_connection, self.client_pool.request_type, cache_args=kwargs
)
if len(kwargs) > 0:
raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")
self.stop_token = stop_token
def close(self) -> None:
"""Close the client and cache."""
self.client_pool.close()
self.cache.close()
def _validate_kwargs(self, kwargs: Dict, request_params: Request) -> None:
"""Validate kwargs.
Args:
kwargs: kwargs to validate.
request_params: request object to validate against.
"""
# Check for invalid kwargs
non_request_kwargs = [
(k, v) for k, v in kwargs.items() if k not in request_params.__dict__
]
if len(non_request_kwargs) > 0:
raise ValueError(
f"{list(non_request_kwargs)} arguments are not recognized."
)
# Warn for valid but unused kwargs
request_unused_kwargs = [
(k, v) for k, v in kwargs.items() if k not in non_request_kwargs
]
if len(request_unused_kwargs) > 0:
logger.warning(f"{list(request_unused_kwargs)} arguments are unused.")
return
def _split_cached_requests(
self,
request: Request,
client: Client,
overwrite_cache: bool,
) -> Tuple[Dict[int, Response], Request]:
"""Split a request into cached responses and Requests to run.
Args:
request: request object.
overwrite_cache: whether to overwrite cache.
Returns:
cached_idx_to_response: dict of cached responses.
new_request: request object with only prompts to run.
"""
cached_idx_to_response: Dict[int, Response] = {}
new_request = copy.deepcopy(request)
if not overwrite_cache:
if isinstance(new_request.prompt, list) and not isinstance(
request, LMChatRequest
):
new_request.prompt = []
for idx, prompt_str in enumerate(request.prompt):
single_request = copy.deepcopy(request)
single_request.prompt = prompt_str
possible_response = self.cache.get(
client.get_cache_key(single_request)
)
if possible_response:
cached_idx_to_response[idx] = possible_response
else:
new_request.prompt.append(prompt_str)
# Chat or single string requests are not broken down into
# subprompts for caching.
elif (isinstance(new_request.prompt, str)) or (
isinstance(new_request.prompt, list)
and isinstance(request, LMChatRequest)
):
possible_response = self.cache.get(client.get_cache_key(new_request))
if possible_response:
cached_idx_to_response[0] = possible_response
new_request.prompt = None
else:
raise ValueError(
f"Invalid prompt type: {type(new_request.prompt)}"
f" with request type: {type(request)}"
)
return cached_idx_to_response, new_request
def _stitch_responses_and_cache(
self,
request: Request,
client: Client,
response: Union[Response, None],
cached_idx_to_response: Dict[int, Response],
) -> Response:
"""Stich together the cached and uncached responses."""
# We stitch the responses (the choices) here from both the new request the
# cached entries.
all_model_choices = []
all_usages = []
all_input_prompts: List[Union[str, List[str], List[Dict]]] = []
response_idx = 0
number_prompts = len(cached_idx_to_response)
single_completion_output = False
if response:
if isinstance(response.get_request_obj().prompt, str):
single_completion_output = True
number_prompts += 1
elif isinstance(response.get_request_obj().prompt, list) and not isinstance(
request, LMChatRequest
):
number_prompts += len(response.get_request_obj().prompt)
elif isinstance(response.get_request_obj().prompt, list) and isinstance(
request, LMChatRequest
):
assert len(cached_idx_to_response) <= 1
number_prompts += 1
else:
raise ValueError(
f"Invalid prompt type: {type(response.get_request_obj().prompt)}"
f" with request type: {type(request)}"
)
response_type = None
request_type: Type[Request] = None
for idx in range(number_prompts):
if idx in cached_idx_to_response:
cached_res = cached_idx_to_response[idx]
response_type = cached_res._response_type
request_type = cached_res._request_type
all_input_prompts.append(cached_res.get_request_obj().prompt)
if request.n == 1:
assert (
len(cached_res.get_response_obj().choices) == 1
), "cached response should have only one choice"
all_model_choices.extend(cached_res.get_response_obj().choices)
if cached_res.get_usage_obj().usages:
all_usages.extend(cached_res.get_usage_obj().usages)
else:
assert response is not None, "response should not be None"
response = cast(Response, response)
response_type = response._response_type
request_type = response._request_type
# the choices list in the response is a flat one.
# length is request.n * num_prompts
current_choices = response.get_response_obj().choices[
response_idx * request.n : (response_idx + 1) * request.n
]
all_model_choices.extend(current_choices)
if isinstance(
response.get_request_obj().prompt, list
) and not isinstance(request, LMChatRequest):
prompt: Union[
str, List[str], List[Dict]
] = response.get_request_obj().prompt[response_idx]
# Chat request
elif isinstance(response.get_request_obj().prompt, list) and isinstance(
request, LMChatRequest
):
# We will only have response_idx == 0 here as we can only
# support single chat requests.
assert request.n == 1
assert number_prompts <= 1
prompt = response.get_request_obj().prompt
else:
prompt = str(response.get_request_obj().prompt)
usages: Optional[List[Usage]] = None
if response.get_usage_obj().usages:
usages = response.get_usage_obj().usages[
response_idx * request.n : (response_idx + 1) * request.n
]
all_usages.extend(usages)
all_input_prompts.append(prompt)
# set cache
new_request = copy.deepcopy(request)
new_request.prompt = prompt # type: ignore
cache_key = client.get_cache_key(new_request)
new_response = copy.deepcopy(response)
new_response._response.choices = current_choices
new_response._usages = Usages(usages=(usages or []))
self.cache.set(cache_key, new_response.to_dict(drop_request=True))
response_idx += 1
new_request = copy.deepcopy(request)
new_request.prompt = (
all_input_prompts # type: ignore
if len(all_input_prompts) > 1 or not single_completion_output
else all_input_prompts[0]
)
response_obj = Response(
response=ModelChoices(choices=all_model_choices),
cached=len(cached_idx_to_response) > 0,
request=new_request,
usages=Usages(usages=all_usages),
response_type=response_type,
request_type=request_type,
)
return response_obj
def run(
self,
prompt: Union[str, List[str], List[Dict[str, str]]],
overwrite_cache: bool = False,
stop_token: Optional[str] = None,
return_response: bool = False,
stream: bool = False,
**kwargs: Any,
) -> Union[
str,
List[str],
np.ndarray,
List[np.ndarray],
Response,
Iterator[str],
Iterator[Response],
]:
"""
Run the prompt.
Orchestrates between the standard run and chat run and batch run.
Args:
prompt: prompt(s) to run.
overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
return_response: whether to return Response object.
stream: whether to stream the prompt. Only supported
for single string prompts and LMs.
Returns:
response from prompt.
"""
if not isinstance(prompt, list) and not isinstance(prompt, str):
raise ValueError(
f"Invalid prompt type: {type(prompt)}. "
"Prompt must be a string or list of strings "
"or list of dicts."
)
if isinstance(prompt, list) and not prompt:
raise ValueError("Prompt cannot be empty list")
# Get the client to run
client = self.client_pool.get_next_client()
if stream:
if not client.supports_streaming_inference():
raise ValueError(
f"Client {client} does not support streaming inference."
)
if not isinstance(prompt, str):
raise ValueError(
"Stream is only supported for single string prompts. "
"It will soon be supported for chat dictionary prompts, too."
)
return self._run_stream(
prompt=cast(str, prompt),
client=client,
overwrite_cache=overwrite_cache,
stop_token=stop_token,
return_response=return_response,
**kwargs,
)
if isinstance(prompt, list) and isinstance(prompt[0], dict):
if not client.IS_CHAT:
raise ValueError(
f"Client {client} does not support dict chat prompt. "
"Please use a chat model."
)
if stop_token:
logger.warning(
"stop_token is not supported for chat prompt. "
"Ignoring stop_token."
)
return self._run_chat(
prompt=cast(List[Dict[str, str]], prompt),
client=client,
overwrite_cache=overwrite_cache,
return_response=return_response,
**kwargs,
)
return self._run(
prompt=cast(Union[str, List[str]], prompt),
client=client,
overwrite_cache=overwrite_cache,
stop_token=stop_token,
return_response=return_response,
**kwargs,
)
def _run(
self,
prompt: Union[str, List[str]],
client: Client,
overwrite_cache: bool = False,
stop_token: Optional[str] = None,
return_response: bool = False,
**kwargs: Any,
) -> Union[str, List[str], np.ndarray, List[np.ndarray], Response]:
"""
Run the prompt.
Args:
prompt: prompt(s) to run.
client: client to run.
overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
return_response: whether to return Response object.
Returns:
response from prompt.
"""
is_batch = isinstance(prompt, list)
stop_token = stop_token if stop_token is not None else self.stop_token
# Must pass kwargs as dict for client "pop" methods removed used arguments
request_params = client.get_request(prompt, kwargs)
# Avoid nested list of results - enforce n = 1 for batch
if is_batch and request_params.n > 1:
raise ValueError("Batch mode does not support n > 1.")
self._validate_kwargs(kwargs, request_params)
cached_idx_to_response, request_params = self._split_cached_requests(
request_params, client, overwrite_cache
)
# If not None value or empty list - run new request
if request_params.prompt:
# Start timing metrics
self.client_pool.start_timer()
response = client.run_request(request_params)
self.client_pool.end_timer()
else:
# Nothing to run
response = None
final_response = self._stitch_responses_and_cache(
request=request_params,
client=client,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
# Extract text results
if return_response:
return final_response
else:
return final_response.get_response(stop_token, is_batch)
def _run_chat(
self,
prompt: List[Dict[str, str]],
client: Client,
overwrite_cache: bool = False,
return_response: bool = False,
**kwargs: Any,
) -> Union[str, Response]:
"""
Run the prompt.
Args:
prompt: prompt dictionary to run.
client: client to run.
overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
return_response: whether to return Response object.
Returns:
response from prompt.
"""
is_batch = False
# Get a request for an empty prompt to handle all kwargs
request_params = client.get_request("", kwargs)
# Add prompt and cast as chat request
request_params_dict = request_params.to_dict()
request_params_dict["prompt"] = prompt
request_params_as_chat = LMChatRequest(**request_params_dict)
# Avoid nested list of results - enforce n = 1 for batch
if request_params_as_chat.n > 1:
raise ValueError("Chat mode does not support n > 1.")
self._validate_kwargs(kwargs, request_params_as_chat)
cached_idx_to_response, request_params_as_chat = self._split_cached_requests( # type: ignore # noqa: E501
request_params_as_chat, client, overwrite_cache
)
# If not None value or empty list - run new request
if request_params_as_chat.prompt:
# Start timing metrics
self.client_pool.start_timer()
response = client.run_chat_request(request_params_as_chat)
self.client_pool.end_timer()
else:
# Nothing to run
response = None
final_response = self._stitch_responses_and_cache(
request=request_params_as_chat,
client=client,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
# Extract text results
if return_response:
return final_response
else:
return cast(str, final_response.get_response("", is_batch))
def _run_stream(
self,
prompt: str,
client: Client,
overwrite_cache: bool = False,
stop_token: Optional[str] = None,
return_response: bool = False,
**kwargs: Any,
) -> Union[Generator[str, None, None], Generator[Response, None, None]]:
"""
Run the prompt in a stream.
Args:
prompt: prompt(s) to run.
client: client to run.
overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
return_response: whether to return Response object.
Returns:
response from prompt.
"""
is_batch = False
stop_token = stop_token if stop_token is not None else self.stop_token
# Must pass kwargs as dict for client "pop" methods removed used arguments
request_params = client.get_request(prompt, kwargs)
# Avoid nested list of results - enforce n = 1 for batch
if request_params.n > 1:
raise ValueError("Stream mode does not support n > 1.")
self._validate_kwargs(kwargs, request_params)
cached_idx_to_response, request_params = self._split_cached_requests(
request_params, client, overwrite_cache
)
if request_params.prompt:
# Because we are streaming, we should have either a cached response
# a prompt to run
assert len(cached_idx_to_response) == 0
response_iter = client.run_streaming_request(request_params)
is_cached = False
else:
assert len(cached_idx_to_response) == 1
response_iter = cached_idx_to_response[0].as_iter()
is_cached = True
saved_responses = []
# Start timing metrics
self.client_pool.start_timer()
for response_token in response_iter:
saved_responses.append(response_token)
if return_response:
yield response_token
else:
yield cast(
Union[str, Response], response_token.get_response("", is_batch)
)
self.client_pool.end_timer()
if not is_cached:
final_response = Response.union_all(
saved_responses, as_single_lmchoice=True
)
self._stitch_responses_and_cache(
request=request_params,
client=client,
response=final_response,
cached_idx_to_response=cached_idx_to_response,
)
async def arun_batch(
self,
prompts: List[str],
overwrite_cache: bool = False,
stop_token: Optional[str] = None,
return_response: bool = False,
chunk_size: int = -1,
verbose: bool = False,
**kwargs: Any,
) -> Union[List[str], List[np.ndarray], Response]:
"""
Run a batch of prompts with async.
If the client pool is a single client, all prompts will be sent
to one client and batch_size (which is passed it as kwargs) will
determine how the prompts are split.
If the client pool is a pool of clients, the prompts will be split
into chunks and sent to the clients. Each client will split the
chunk into batch_size prompts to send to the model.
Args:
prompts: prompts to run.
overwrite_cache: whether to overwrite cache.
stop_token: stop token for prompt generation.
Default is self.stop_token.
"" for no stop token.
return_response: whether to return Response object.
chunk_size: number of prompts to send to a client in chunks.
For each chunk, the client will split the chunk into
batch_sized prompts to send to the model.
For a single manifest client, there is no impact to
setting chunk_size. For a client pool, chunk_size
can be used to distribute the load across the clients.
verbose: whether to print progress of async tasks.
Returns:
response from prompt.
"""
if not isinstance(prompts, list):
raise ValueError("Prompts must be a list of strings.")
if not prompts:
raise ValueError("Prompts must not be empty.")
if not isinstance(prompts[0], str):
raise ValueError("Prompts must be a list of strings.")
# Split the prompts into chunks for connection pool
prompt_chunks: List[Tuple[Client, List[str]]] = []
if chunk_size > 0:
for i in range(0, len(prompts), chunk_size):
prompt_chunks.append(
(self.client_pool.get_next_client(), prompts[i : i + chunk_size])
)
else:
prompt_chunks = [(self.client_pool.get_next_client(), prompts)]
# Run the chunks
tasks = []
for client, chunk in prompt_chunks:
tasks.append(
asyncio.create_task(
self._arun_batch_client(
prompts=chunk,
client=client,
overwrite_cache=overwrite_cache,
verbose=verbose,
**kwargs,
)
)
)
logger.info(f"Running {len(tasks)} tasks across all clients.")
responses = await asyncio.gather(*tasks)
final_response = Response.union_all(responses)
stop_token = stop_token if stop_token is not None else self.stop_token
# Extract text results
if return_response:
return final_response
else:
return cast(
Union[List[str], List[np.ndarray]],
final_response.get_response(stop_token, True),
)
async def _arun_batch_client(
self,
prompts: List[str],
client: Client,
overwrite_cache: bool = False,
verbose: bool = False,
**kwargs: Any,
) -> Response:
"""
Run a batch of prompts with async for single client.
Args:
prompts: prompts to run.
client: client to run.
overwrite_cache: whether to overwrite cache.
verbose: whether to print progress of async tasks.
Returns:
response from prompt.
"""
# Must pass kwargs as dict for client "pop" methods removed used arguments
request_params = client.get_request(prompts, kwargs)
# Avoid nested list of results - enforce n = 1 for batch
if request_params.n > 1:
raise ValueError("Batch mode does not support n > 1.")
self._validate_kwargs(kwargs, request_params)
cached_idx_to_response, request_params = self._split_cached_requests(
request_params, client, overwrite_cache
)
# If not None value or empty list - run new request
if request_params.prompt:
self.client_pool.start_timer()
response = await client.arun_batch_request(request_params, verbose=verbose)
self.client_pool.end_timer()
else:
# Nothing to run
response = None
final_response = self._stitch_responses_and_cache(
request=request_params,
client=client,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
return final_response
def score_prompt(
self,
prompt: Union[str, List[str]],
overwrite_cache: bool = False,
**kwargs: Any,
) -> Dict:
"""
Score the prompt via forward pass of the model - no sampling or generation.
Returns the response object with logits of the prompt.
Args:
prompt: prompt(s) to run.
overwrite_cache: whether to overwrite cache.
Returns:
response from prompt.
"""
client = self.client_pool.get_next_client()
# Must pass kwargs as dict for client "pop" methods removed used arguments
request_params = client.get_request(prompt, kwargs)
request_params_as_score = LMScoreRequest(**request_params.to_dict())
if request_params_as_score.n > 1:
raise ValueError("Sequence scoring does not support n > 1.")
self._validate_kwargs(kwargs, request_params_as_score)
cached_idx_to_response, request_params_as_score = self._split_cached_requests( # type: ignore # noqa: E501
request_params_as_score, client, overwrite_cache
)
# If not None value or empty list - run new request
if request_params_as_score.prompt:
try:
response = cast(HuggingFaceClient, client).run_score_prompt_request(
request_params_as_score
)
except AttributeError:
raise ValueError("`score_prompt` only supported for HF models.")
else:
# Nothing to run
response = None
final_response = self._stitch_responses_and_cache(
request=request_params_as_score,
client=client,
response=response,
cached_idx_to_response=cached_idx_to_response,
)
return final_response.to_dict()