You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
manifest/manifest/caches/cache.py

136 lines
4.3 KiB
Python

"""Cache for queries and responses."""
from abc import ABC, abstractmethod
from typing import Any, Dict, Type, Union
from manifest.caches.serializers import ArraySerializer, NumpyByteSerializer, Serializer
from manifest.request import DiffusionRequest, EmbeddingRequest, LMRequest, Request
from manifest.response import Response
# Request types with non-text return values (embedding or image arrays).
# Cache.__init__ checks membership in this set to pick an array serializer
# (ArraySerializer or NumpyByteSerializer) instead of the default Serializer.
ARRAY_CACHE_TYPES = {EmbeddingRequest, DiffusionRequest}
class Cache(ABC):
    """Abstract cache for request/response pairs.

    Subclasses supply the storage backend by implementing ``connect``,
    ``close``, ``get_key``, ``set_key`` and ``commit``. This base class
    selects a serializer based on the request type and implements the
    request-level ``get``/``set`` on top of those key-level primitives.
    """

    def __init__(
        self,
        connection_str: str,
        request_type: Type[Request] = LMRequest,
        cache_args: Union[Dict[str, Any], None] = None,
    ):
        """
        Initialize cache.

        cache_args are any arguments needed to initialize the cache.

        Further, cache_args can contain `array_serializer` as a string
        for embedding or image return types (e.g. diffusers) with values
        as `local_file` or `byte_string`. `local_file` will save the
        array in a local file and cache a pointer to the file.
        `byte_string` will convert the array to a byte string and cache
        the entire byte string. `byte_string` is default.

        Args:
            connection_str: connection string for cache.
            request_type: request type; decides which serializer is used.
            cache_args: cache arguments. When omitted, a fresh empty dict is
                used for this instance. A dict passed in is forwarded to
                ``connect`` and, for array request types, has its
                `array_serializer` entry removed in place.

        Raises:
            ValueError: if `array_serializer` is given for an array request
                type and is neither `local_file` nor `byte_string`.
        """
        # A literal ``{}`` default would be a single dict shared by every call
        # that omits the argument; build a new one per instance instead.
        if cache_args is None:
            cache_args = {}
        self.request_type = request_type
        self.connect(connection_str, cache_args)
        if self.request_type in ARRAY_CACHE_TYPES:
            # NOTE(review): pop() intentionally mutates the dict the caller
            # passed in; callers may inspect the leftover keys afterwards, so
            # do not switch to a copy without checking them.
            array_serializer = cache_args.pop("array_serializer", "byte_string")
            if array_serializer not in ("local_file", "byte_string"):
                raise ValueError(
                    "array_serializer must be local_file or byte_string,"
                    f" not {array_serializer}"
                )
            self.serializer = (
                ArraySerializer()
                if array_serializer == "local_file"
                else NumpyByteSerializer()
            )
        else:
            # If user has array_serializer type, it will throw an error as
            # it is not recognized for non-array return types.
            self.serializer = Serializer()

    @abstractmethod
    def close(self) -> None:
        """Close the cache."""
        raise NotImplementedError()

    @abstractmethod
    def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
        """
        Connect to cache.

        Args:
            connection_str: connection string.
            cache_args: arguments for initializing the cache.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_key(self, key: str, table: str = "default") -> Union[str, None]:
        """
        Get the value stored for a key.

        Will return None if key is not in cache.

        Args:
            key: key for cache.
            table: table to get key in.
        """
        raise NotImplementedError()

    @abstractmethod
    def set_key(self, key: str, value: str, table: str = "default") -> None:
        """
        Set the value for the key.

        Will override old value.

        Args:
            key: key for cache.
            value: new value for key.
            table: table to set key in.
        """
        raise NotImplementedError()

    @abstractmethod
    def commit(self) -> None:
        """Commit any results."""
        raise NotImplementedError()

    def get(self, request: Dict) -> Union[Response, None]:
        """Look up the cached response for a request.

        Args:
            request: request to get.

        Returns:
            Response object (with its ``cached`` entry set to True) or None
            if not in cache.
        """
        key = self.serializer.request_to_key(request)
        cached_response = self.get_key(key)
        # Truthiness check kept as-is: a missing key (None) and an empty
        # stored value are both treated as a cache miss.
        if cached_response:
            response = self.serializer.key_to_response(cached_response)
            response["cached"] = True
            return Response.from_dict(response, request_dict=request)
        return None

    def set(self, request: Dict, response: Dict) -> None:
        """Store the response for a request, overriding any previous value.

        Args:
            request: request to set.
            response: response to set.
        """
        key = self.serializer.request_to_key(request)
        self.set_key(key, self.serializer.response_to_key(response))