Laurel/diffusion (#40)

* Sketch of diffusers added

* [WIP] Array caching implemented with end2end diffusion working

* [WIP] Make initial pass on CLIP model

* [WIP] Get endpoint running for CLIP

* Add support for clip images

* [chore] merge main

* chore: fix xxhash install

Co-authored-by: Sabri Eyuboglu <eyuboglu@stanford.edu>
laurel/helm
Laurel Orr 2 years ago committed by GitHub
parent 26e440b6a6
commit 6f5b64f0df
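
For orientation, a hedged end-to-end sketch of what this PR enables from the user side. The `Manifest` constructor arguments and the cache file name below are inferred from the changes in this diff (a `diffuser` client, an SQLite cache, and a local Flask API started via `manifest.api.app --model_type diffuser`), not confirmed by it:

```python
from manifest import Manifest

# Assumes a local diffuser API server was started first, e.g.:
#   python3 -m manifest.api.app --model_type diffuser \
#       --model_name_or_path <a Stable Diffusion checkpoint>
manifest = Manifest(
    client_name="diffuser",                     # new client added in this PR
    client_connection="http://127.0.0.1:5000",  # default local Flask port
    cache_name="sqlite",
    cache_connection="diffusion.cache.sqlite",  # placeholder cache file
)

# With the diffuser client, run() returns image arrays instead of text;
# the arrays themselves land in the new memmap-backed ArrayCache.
image = manifest.run("a watercolor painting of a fox", n=1)
print(type(image), getattr(image, "shape", None))
```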

@ -193,7 +193,7 @@ python3 -m manifest.api.app \
# Development
Before submitting a PR, run
```bash
export REDIS_PORT="6380" # or whatever PORT local redis is running for those tests
export REDIS_PORT="6379" # or whatever PORT local redis is running for those tests
cd <REDIS_PATH>
docker run -d -p 127.0.0.1:${REDIS_PORT}:6379 -v `pwd`:`pwd` -w `pwd` --name manifest_redis_test redis
make test

@ -1,5 +1,6 @@
"""Flask app."""
import argparse
import io
import json
import logging
import os
@ -9,7 +10,8 @@ from typing import Dict
import pkg_resources
from flask import Flask, Response, request
from manifest.api.models.huggingface import HuggingFaceModel
from manifest.api.models.diffuser import DiffuserModel
from manifest.api.models.huggingface import CrossModalEncoderModel, TextGenerationModel
from manifest.api.response import ModelResponse
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@ -18,9 +20,12 @@ logger = logging.getLogger(__name__)
app = Flask(__name__) # define app using Flask
# Will be global
model = None
model_type = None
PORT = int(os.environ.get("FLASK_PORT", 5000))
MODEL_CONSTRUCTORS = {
"huggingface": HuggingFaceModel,
"huggingface": TextGenerationModel,
"huggingface_crossmodal": CrossModalEncoderModel,
"diffuser": DiffuserModel,
}
@ -33,7 +38,7 @@ def parse_args() -> argparse.Namespace:
type=str,
required=True,
help="Model type used for finding constructor.",
choices=["huggingface", "zoo"],
choices=MODEL_CONSTRUCTORS.keys(),
)
parser.add_argument(
"--model_name_or_path",
@ -97,7 +102,7 @@ def main() -> None:
kwargs = parse_args()
if is_port_in_use(PORT):
raise ValueError(f"Port {PORT} is already in use.")
global model_type
model_type = kwargs.model_type
model_name_or_path = kwargs.model_name_or_path
if not model_name_or_path:
@ -150,15 +155,19 @@ def completions() -> Response:
if not isinstance(prompt, (str, list)):
raise ValueError("Prompt must be a str or list of str")
try:
results_text = []
result_gens = []
for generations in model.generate(prompt, **generation_args):
results_text.append(generations)
results = [{"text": r[0], "text_logprob": r[1]} for r in results_text]
result_gens.append(generations)
if model_type == "diffuser":
# Assign None logprob as it's not supported in diffusers
results = [{"array": r[0], "logprob": None} for r in result_gens]
res_type = "image_generation"
else:
results = [{"text": r[0], "logprob": r[1]} for r in result_gens]
res_type = "text_completion"
# transform the result into the openai format
return Response(
json.dumps(
ModelResponse(results, response_type="text_completion").__dict__()
),
json.dumps(ModelResponse(results, response_type=res_type).__dict__()),
status=200,
)
except Exception as e:
@ -169,6 +178,34 @@ def completions() -> Response:
)
@app.route("/embed", methods=["POST"])
def embed() -> Dict:
"""Get embed for generation."""
modality = request.json["modality"]
if modality == "text":
prompts = request.json["prompts"]
elif modality == "image":
import base64
from PIL import Image
prompts = [
Image.open(io.BytesIO(base64.b64decode(data)))
for data in request.json["prompts"]
]
else:
raise ValueError("modality must be text or image")
results = []
embeddings = model.embed(prompts)
for embedding in embeddings:
results.append(embedding.tolist())
# transform the result into the openai format
# return Response(results, response_type="text_completion").__dict__()
return {"result": results}
@app.route("/choice_logits", methods=["POST"])
def choice_logits() -> Response:
"""Get maximal likely choice via max logits after generation."""

@ -0,0 +1,106 @@
"""Huggingface model."""
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
import torch
from diffusers import StableDiffusionPipeline
from manifest.api.models.model import Model
class DiffuserModel(Model):
"""Diffuser model."""
def __init__(
self,
model_name_or_path: str,
model_config: str = None,
cache_dir: str = None,
device: int = 0,
use_accelerate: bool = False,
use_parallelize: bool = False,
use_bitsandbytes: bool = False,
perc_max_gpu_mem_red: float = 1.0,
use_fp16: bool = True,
):
"""
Initialize model.
All arguments will be passed in the request from Manifest.
Args:
model_name_or_path: model name string.
model_config: model config string.
cache_dir: cache directory for model.
device: device to use for model.
use_accelerate: whether to use accelerate for multi-gpu inference.
use_parallelize: use HF default parallelize
use_bitsandbytes: use HF bits and bytes
perc_max_gpu_mem_red: percent max memory reduction in accelerate
use_fp16: use fp16 for model weights.
"""
if use_accelerate or use_parallelize or use_bitsandbytes:
raise ValueError(
"Cannot use accelerate or parallelize or bitsandbytes with diffusers"
)
# Check if providing path
self.model_path = model_name_or_path
if Path(self.model_path).exists() and Path(self.model_path).is_dir():
model_name_or_path = Path(self.model_path).name
self.model_name = model_name_or_path
print("Model Name:", self.model_name, "Model Path:", self.model_path)
dtype = torch.float16 if use_fp16 else None
torch_device = (
torch.device("cpu")
if (device == -1 or not torch.cuda.is_available())
else torch.device(f"cuda:{device}")
)
self.pipeline = StableDiffusionPipeline.from_pretrained(
self.model_path,
torch_dtype=dtype,
revision="fp16" if str(dtype) == "float16" else None,
)
self.pipeline.safety_checker = None
self.pipeline.to(torch_device)
def get_init_params(self) -> Dict:
"""Return init params to determine what model is being used."""
return {"model_name": self.model_name, "model_path": self.model_path}
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float]]:
"""
Generate the prompt from model.
Outputs are the generated image arrays and a score (always None for diffusers), not including the prompt.
Args:
prompt: prompt to generate from.
Returns:
list of generated image arrays (list of length 1 for 1 generation).
"""
# TODO: Is this correct for getting arguments in?
if isinstance(prompt, str):
prompt = [prompt]
result = self.pipeline(prompt, output_type="np.array", **kwargs)
# Return None for logprobs
return [(im, None) for im in result["images"]]
@torch.no_grad()
def logits_scoring(
self, prompt: Union[str, List[str]], gold_choices: List[str], **kwargs: Any
) -> List[Tuple[Any, float]]:
"""
Given the prompt and gold choices, choose the best choice with max logits.
Args:
prompt: prompt to generate from.
gold_choices: list of choices to choose from.
Returns:
the returned gold choice
"""
raise NotImplementedError("Logits scoring not supported for diffusers")

@ -3,6 +3,8 @@ import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
import numpy as np
import PIL
import torch
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils.modeling import get_max_memory as acc_get_max_memory
@ -11,6 +13,8 @@ from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
BloomForCausalLM,
CLIPModel,
CLIPProcessor,
GPT2LMHeadModel,
GPTJForCausalLM,
GPTNeoForCausalLM,
@ -29,21 +33,17 @@ MODEL_REGISTRY = {
"EleutherAI/gpt-neo-2.7B": GPTNeoForCausalLM,
"EleutherAI/gpt-j-6B": GPTJForCausalLM,
"EleutherAI/gpt-neox-20b": GPTNeoXForCausalLM,
"Salesforce/codegen-2B-mono": AutoModelForCausalLM,
"Salesforce/codegen-6B-mono": AutoModelForCausalLM,
"facebook/opt-125m": OPTForCausalLM,
"facebook/opt-350m": OPTForCausalLM,
"Salesforce/codegen-2B-mono": AutoModelForCausalLM,
"Salesforce/codegen-6B-mono": AutoModelForCausalLM,
"facebook/opt-1.3b": OPTForCausalLM,
"facebook/opt-2.7b": OPTForCausalLM,
"facebook/opt-6.7b": OPTForCausalLM,
"facebook/opt-13b": OPTForCausalLM,
"facebook/opt-30b": OPTForCausalLM,
"facebook/galactica-125m": OPTForCausalLM,
"facebook/galactica-1.3b": OPTForCausalLM,
"facebook/galactica-6.7b": OPTForCausalLM,
"facebook/galactica-30b": OPTForCausalLM,
"facebook/galactica-120b": OPTForCausalLM,
"gpt2": GPT2LMHeadModel,
"openai/clip-vit-base-patch32": CLIPModel,
"bigscience/bloom-560m": BloomForCausalLM,
"bigscience/bloom-1b7": BloomForCausalLM,
"bigscience/bloom-3b": BloomForCausalLM,
@ -75,7 +75,7 @@ def get_max_memory(gpu_reduction: float) -> Dict[int, str]:
return max_mem_dict
class Pipeline:
class GenerationPipeline:
"""
Custom Pipeline.
@ -113,7 +113,9 @@ class Pipeline:
print(f"Usings max_length: {self.max_length}")
self.tokenizer = tokenizer
# self.device = device
# With bits and bytes, do not want to place inputs on any device
# if self.device:
self.device = (
torch.device("cpu")
if (device == -1 or not torch.cuda.is_available())
@ -147,8 +149,7 @@ class Pipeline:
)
encoded_prompt = encoded_prompt.to(self.device)
output_dict = self.model.generate( # type: ignore
input_ids=encoded_prompt.input_ids,
attention_mask=encoded_prompt.attention_mask,
**encoded_prompt,
max_new_tokens=kwargs.get("max_new_tokens"),
temperature=kwargs.get("temperature", None),
top_k=kwargs.get("top_k", None),
@ -181,11 +182,12 @@ class Pipeline:
class HuggingFaceModel(Model):
"""Huggingface model."""
"""HuggingFace Model."""
def __init__(
self,
model_name_or_path: str,
model_config: Optional[str] = None,
cache_dir: Optional[str] = None,
device: int = 0,
use_accelerate: bool = False,
@ -202,6 +204,7 @@ class HuggingFaceModel(Model):
Args:
model_name_or_path: model name string.
model_config: model config string.
cache_dir: cache directory for model.
device: device to use for model.
use_accelerate: whether to use accelerate for multi-gpu inference.
@ -225,79 +228,6 @@ class HuggingFaceModel(Model):
model_name_or_path = config["_name_or_path"]
self.model_name = model_name_or_path
print("Model Name:", self.model_name, "Model Path:", self.model_path)
try:
tokenizer = AutoTokenizer.from_pretrained(
self.model_name, truncation_side="left", padding_side="left"
)
except ValueError:
tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
truncation_side="left",
padding_side="left",
use_fast=False,
)
dtype = torch.float16 if use_fp16 else "auto"
if use_bitsandbytes:
print("WARNING!!! Cannot use sampling with bitsandbytes.")
max_memory = get_max_memory(perc_max_gpu_mem_red)
model = MODEL_REGISTRY[self.model_name].from_pretrained( # type: ignore
self.model_path,
cache_dir=cache_dir,
load_in_8bit=True,
device_map="auto",
max_memory=max_memory,
)
else:
try:
# Try to explicitly find an fp16 copy (gpt-j-6B for example)
model = MODEL_REGISTRY[self.model_name].from_pretrained( # type: ignore
self.model_path,
cache_dir=cache_dir,
revision="float16",
torch_dtype=torch.float16,
)
except Exception:
model = MODEL_REGISTRY[self.model_name].from_pretrained( # type: ignore
self.model_path, cache_dir=cache_dir, torch_dtype=dtype
)
model.eval()
print(f"Loaded Model DType {model.dtype}")
self.is_encdec = model.config.is_encoder_decoder
# Set pad tokens for galactic
if self.model_name.startswith("facebook/galactic"):
# https://github.com/paperswithcode/galai/blob/main/galai/model.py
tokenizer.pad_token = "[PAD]"
# tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.pad_token_id = 1
elif not self.is_encdec:
tokenizer.pad_token = tokenizer.eos_token
if not use_bitsandbytes:
if use_accelerate:
self._dispatch_accelerate_model(model, perc_max_gpu_mem_red)
device = 0
elif use_parallelize:
model.parallelize()
device = 0
elif use_deepspeed:
self._dispatch_deepspeed_model(model)
device = 0
else:
if device > -1:
torch_device = (
torch.device("cpu")
if (device == -1 or not torch.cuda.is_available())
else torch.device(f"cuda:{device}")
)
model = model.to(torch_device) # type: ignore
self.pipeline = Pipeline( # type: ignore
model=model,
tokenizer=tokenizer,
device=device,
bitsandbytes=use_bitsandbytes,
is_encdec=self.is_encdec,
)
def get_init_params(self) -> Dict:
"""Return init params to determine what model is being used."""
@ -379,10 +309,222 @@ class HuggingFaceModel(Model):
dispatch_model(model, device_map=device_map)
return
class CrossModalEncoderModel(HuggingFaceModel):
"""CrossModalEncoderModel."""
def __init__(
self,
model_name_or_path: str,
model_config: Optional[str] = None,
cache_dir: Optional[str] = None,
device: int = 0,
use_accelerate: bool = False,
use_parallelize: bool = False,
use_bitsandbytes: bool = False,
use_deepspeed: bool = False,
perc_max_gpu_mem_red: float = 1.0,
use_fp16: bool = False,
):
"""
Initialize model.
All arguments will be passed in the request from Manifest.
Args:
model_name_or_path: model name string.
model_config: model config string.
cache_dir: cache directory for model.
device: device to use for model.
use_accelerate: whether to use accelerate for multi-gpu inference.
use_parallelize: use HF default parallelize
use_bitsandbytes: use HF bits and bytes
use_deepspeed: use deepspeed
perc_max_gpu_mem_red: percent max memory reduction in accelerate
use_fp16: use fp16 for model weights.
"""
super().__init__(
model_name_or_path,
model_config,
cache_dir,
device,
use_accelerate,
use_parallelize,
use_bitsandbytes,
use_deepspeed,
perc_max_gpu_mem_red,
use_fp16,
)
# TODO: make this generalizable
self.processor = CLIPProcessor.from_pretrained(self.model_path)
model = MODEL_REGISTRY[self.model_name].from_pretrained(
self.model_path,
cache_dir=cache_dir,
)
model.eval()
torch_device = (
torch.device("cpu")
if (device == -1 or not torch.cuda.is_available())
else torch.device(f"cuda:{device}")
)
print("T", torch_device)
self.model = model.to(torch_device) # type: ignore
@torch.no_grad()
def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
"""
Compute embedding for prompts.
Args:
prompt: prompt to embed.
Returns:
embedding
"""
if isinstance(prompt, str):
inputs = self.processor(text=prompt, return_tensors="pt", padding=True)
elif isinstance(prompt, PIL.Image.Image):
inputs = self.processor(images=prompt, return_tensors="pt", padding=True)
else:
raise ValueError("Prompt must be a string or an image")
outputs = self.model(**inputs)
return outputs
class TextGenerationModel(HuggingFaceModel):
"""Huggingface model."""
def __init__(
self,
model_name_or_path: str,
model_config: Optional[str] = None,
cache_dir: Optional[str] = None,
device: int = 0,
use_accelerate: bool = False,
use_parallelize: bool = False,
use_bitsandbytes: bool = False,
use_deepspeed: bool = False,
perc_max_gpu_mem_red: float = 1.0,
use_fp16: bool = False,
):
"""
Initialize model.
All arguments will be passed in the request from Manifest.
Args:
model_name_or_path: model name string.
model_config: model config string.
cache_dir: cache directory for model.
device: device to use for model.
use_accelerate: whether to use accelerate for multi-gpu inference.
use_parallelize: use HF default parallelize
use_bitsandbytes: use HF bits and bytes
use_deepspeed: use deepspeed
perc_max_gpu_mem_red: percent max memory reduction in accelerate
use_fp16: use fp16 for model weights.
"""
super().__init__(
model_name_or_path,
model_config,
cache_dir,
device,
use_accelerate,
use_parallelize,
use_bitsandbytes,
use_deepspeed,
perc_max_gpu_mem_red,
use_fp16,
)
try:
tokenizer = AutoTokenizer.from_pretrained(
self.model_name, truncation_side="left", padding_side="left"
)
except ValueError:
tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
truncation_side="left",
padding_side="left",
use_fast=False,
)
dtype = torch.float16 if use_fp16 else "auto"
if use_bitsandbytes:
print("WARNING!!! Cannot use sampling with bitsandbytes.")
max_memory = get_max_memory(perc_max_gpu_mem_red)
model = MODEL_REGISTRY[self.model_name].from_pretrained( # type: ignore
self.model_path,
cache_dir=cache_dir,
load_in_8bit=True,
device_map="auto",
max_memory=max_memory,
)
else:
try:
# Try to explicitly find an fp16 copy (gpt-j-6B for example)
model = MODEL_REGISTRY[self.model_name].from_pretrained( # type: ignore
self.model_path,
cache_dir=cache_dir,
revision="float16",
torch_dtype=torch.float16,
)
except Exception:
model = MODEL_REGISTRY[self.model_name].from_pretrained( # type: ignore
self.model_path, cache_dir=cache_dir, torch_dtype=dtype
)
model.eval()
print(f"Loaded Model DType {model.dtype}")
self.is_encdec = model.config.is_encoder_decoder
if not self.is_encdec:
tokenizer.pad_token = tokenizer.eos_token
if not use_bitsandbytes:
if use_accelerate:
self._dispatch_accelerate_model(model, perc_max_gpu_mem_red)
device = 0
elif use_parallelize:
model.parallelize()
device = 0
elif use_deepspeed:
self._dispatch_deepspeed_model(model)
device = 0
else:
if device > -1:
torch_device = (
torch.device("cpu")
if (device == -1 or not torch.cuda.is_available())
else torch.device(f"cuda:{device}")
)
model = model.to(torch_device) # type: ignore
self.pipeline = GenerationPipeline( # type: ignore
model=model,
tokenizer=tokenizer,
device=device,
bitsandbytes=use_bitsandbytes,
is_encdec=self.is_encdec,
)
@torch.no_grad()
def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
"""
Compute embedding for prompts.
Args:
prompt: prompt to embed.
Returns:
embedding
"""
pass
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[str, float]]:
) -> List[Tuple[Any, float]]:
"""
Generate the prompt from model.
@ -416,7 +558,7 @@ class HuggingFaceModel(Model):
@torch.no_grad()
def logits_scoring(
self, prompt: Union[str, List[str]], gold_choices: List[str], **kwargs: Any
) -> List[Tuple[str, float]]:
) -> List[Tuple[Any, float]]:
"""
Given the prompt and gold choices, choose the best choice with max logits.

@ -1,6 +1,8 @@
"""Model class."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Union
import numpy as np
class Model(ABC):
@ -42,8 +44,9 @@ class Model(ABC):
"""Return init params to determine what model is being used."""
raise NotImplementedError()
@abstractmethod
def generate(self, prompt: str, **kwargs: Any) -> List[Tuple[str, float]]:
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float]]:
"""
Generate the prompt from model.
@ -58,9 +61,21 @@ class Model(ABC):
raise NotImplementedError()
@abstractmethod
def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
"""
Compute embedding for prompts.
Args:
prompt: prompt to embed.
Returns:
embedding
"""
raise NotImplementedError()
def logits_scoring(
self, prompt: str, gold_choices: List[str], **kwargs: Any
) -> List[Tuple[str, float]]:
self, prompt: Union[str, List[str]], gold_choices: List[str], **kwargs: Any
) -> List[Tuple[Any, float]]:
"""
Given the prompt and gold choices, choose the best choice with max logits.

@ -12,7 +12,11 @@ class ModelResponse:
"""Initialize response."""
self.results = results
self.response_type = response_type
if self.response_type not in {"text_completion", "choice_selection"}:
if self.response_type not in {
"text_completion",
"choice_selection",
"image_generation",
}:
raise ValueError(
f"Invalid response type: {self.response_type}. "
"Must be one of: text_completion, choice_selection."
@ -22,6 +26,7 @@ class ModelResponse:
def __dict__(self) -> Dict[str, Any]: # type: ignore
"""Return dictionary representation of response."""
key = "text" if self.response_type != "image_generation" else "array"
return {
"id": self.response_id,
"object": self.response_type,
@ -29,16 +34,13 @@ class ModelResponse:
"model": "flask_model",
"choices": [
{
"text": result["text"],
"text_logprob": result["text_logprob"],
# TODO: Add in more metadata for HF models
# "logprobs": {
# "tokens": result["tokens"],
# "token_logprobs": result["token_scores"],
# "text_offset": result["text_offset"],
# "top_logprobs": result["top_logprobs"],
# "finish_reason": "length",
# },
key: result[key],
"logprob": result["logprob"],
}
if key == "text"
else {
key: result[key].tolist(),
"logprob": result["logprob"],
}
for result in self.results
],
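
For reference, a hedged illustration of the two payload shapes the updated `__dict__` produces (values below are made up, and metadata fields outside this hunk are omitted):

```python
# Text completion (key == "text"): logprob is a float, text stays a string.
text_response = {
    "id": "<generated id>",
    "object": "text_completion",
    "model": "flask_model",
    "choices": [{"text": "hello world", "logprob": -1.23}],
}

# Image generation (key == "array"): the np.ndarray is converted with
# .tolist() so the payload stays JSON-serializable; logprob is None.
image_response = {
    "id": "<generated id>",
    "object": "image_generation",
    "model": "flask_model",
    "choices": [{"array": [[0.1, 0.2], [0.3, 0.4]], "logprob": None}],
}
```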

@ -0,0 +1,115 @@
"""Array cache."""
from pathlib import Path
from typing import Union
import numpy as np
from sqlitedict import SqliteDict
def open_mmap_arr(file: Union[Path, str], size: float) -> np.memmap:
"""Open memmap."""
if not Path(file).exists():
mode = "w+"
else:
mode = "r+"
arr = np.memmap( # type: ignore
str(file),
dtype=np.float32,  # This means we only support float32
mode=mode,
shape=size,
)
return arr
class ArrayCache:
"""Array cache."""
def __init__(self, folder: Union[str, Path]) -> None:
"""
Initialize the array writer.
Args:
folder: folder to write to.
"""
self.folder = Path(folder)
self.folder.mkdir(exist_ok=True, parents=True)
self.hash2arrloc = SqliteDict(
self.folder / "hash2arrloc.sqlite", autocommit=True
)
# Max float32 elements per memmap file (20480000 * 4 bytes, roughly 80 MB)
self.max_memmap_size = 20480000
self.cur_file_idx = 0
# Get the last file idx used
for key in self.hash2arrloc:
file_data = self.hash2arrloc[key]
if file_data["file_idx"] > self.cur_file_idx:
self.cur_file_idx = file_data["file_idx"]
self.cur_memmap = open_mmap_arr(
self.folder / f"{self.cur_file_idx}.npy",
self.max_memmap_size,
)
# Make sure there is space left in the memmap
non_zero = np.nonzero(self.cur_memmap)[0]
if len(non_zero) > 0:
self.cur_offset = int(np.max(non_zero) + 1)
else:
self.cur_offset = 0
# If no space, make a new memmap
if self.cur_offset == self.max_memmap_size:
self.cur_file_idx += 1
self.cur_memmap = open_mmap_arr(
self.folder / f"{self.cur_file_idx}.npy",
self.max_memmap_size,
)
self.cur_offset = 0
def contains_key(self, key: str) -> bool:
"""
Check if the key is in the cache.
Args:
key: key to check.
Returns:
True if the key is in the cache.
"""
return key in self.hash2arrloc
def put(self, key: str, arr: np.ndarray) -> None:
"""Save array in store and associate location with key."""
# Check if there is space in the memmap
arr_shape = arr.shape
arr = arr.flatten()
if len(arr) > self.max_memmap_size:
raise ValueError(
f"Array is too large to be cached. Max is {self.max_memmap_size}"
)
if self.cur_offset + len(arr) > self.max_memmap_size:
self.cur_file_idx += 1
self.cur_memmap = open_mmap_arr(
self.folder / f"{self.cur_file_idx}.npy",
self.max_memmap_size,
)
self.cur_offset = 0
self.cur_memmap[self.cur_offset : self.cur_offset + len(arr)] = arr
self.cur_memmap.flush()
self.hash2arrloc[key] = {
"file_idx": self.cur_file_idx,
"offset": self.cur_offset,
"flatten_size": len(arr),
"shape": arr_shape,
}
self.cur_offset += len(arr)
return
def get(self, key: str) -> np.ndarray:
"""Get array associated with location from key."""
file_data = self.hash2arrloc[key]
memmap = open_mmap_arr(
self.folder / f"{file_data['file_idx']}.npy",
self.max_memmap_size,
)
arr = memmap[
file_data["offset"] : file_data["offset"] + file_data["flatten_size"]
]
return arr.reshape(file_data["shape"])
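
A short, hedged usage sketch of the `ArrayCache` above; the folder name is arbitrary, arrays are stored flattened as float32, and anything larger than `max_memmap_size` elements is rejected:

```python
import numpy as np

from manifest.caches.array_cache import ArrayCache

cache = ArrayCache("tmp_array_cache")  # creates the folder and hash2arrloc.sqlite

arr = np.random.rand(64, 64).astype(np.float32)
if not cache.contains_key("my-hash"):
    cache.put("my-hash", arr)          # flattened into the current memmap file

restored = cache.get("my-hash")        # reshaped back to (64, 64)
assert np.allclose(restored, arr)
```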

@ -1,70 +1,38 @@
"""Cache for queries and responses."""
import json
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union
from manifest.caches.serializers import ArraySerializer, Serializer
from manifest.response import Response
RESPONSE_CONSTRUCTORS = {
"diffuser": {
"generation_key": "choices",
"logits_key": "logprobs",
"item_key": "array",
},
}
def request_to_key(request: Dict) -> str:
"""
Normalize a request into a key.
Args:
request: request to normalize.
Returns:
normalized key.
"""
return json.dumps(request, sort_keys=True)
def key_to_request(key: str) -> Dict:
"""
Convert the normalized version to the request.
Args:
key: normalized key to convert.
Returns:
unnormalized request dict.
"""
return json.loads(key)
def response_to_key(response: Dict) -> str:
"""
Normalize a response into a key.
Args:
response: response to normalize.
Returns:
normalized key.
"""
return json.dumps(response, sort_keys=True)
def key_to_response(key: str) -> Dict:
"""
Convert the normalized version to the response.
Args:
key: normalized key to convert.
Returns:
unnormalized response dict.
"""
return json.loads(key)
CACHE_CONSTRUCTOR = {"diffuser": ArraySerializer}
class Cache(ABC):
"""A cache for request/response pairs."""
def __init__(self, connection_str: str, cache_args: Dict[str, Any] = {}):
def __init__(
self,
connection_str: str,
client_name: str = "None",
cache_args: Dict[str, Any] = {},
):
"""
Initialize client.
Args:
connection_str: connection string.
client_name: name of client.
cache_args: arguments for cache.
cache_args are passed to client as default parameters.
For clients like OpenAI that do not require a connection,
@ -74,7 +42,9 @@ class Cache(ABC):
connection_str: connection string for client.
cache_args: cache arguments.
"""
self.client_name = client_name
self.connect(connection_str, cache_args)
self.serializer = CACHE_CONSTRUCTOR.get(client_name, Serializer)()
@abstractmethod
def close(self) -> None:
@ -127,14 +97,16 @@ class Cache(ABC):
self, request: Dict, overwrite_cache: bool, compute: Callable[[], Dict]
) -> Response:
"""Get the result of request (by calling compute as needed)."""
key = request_to_key(request)
key = self.serializer.request_to_key(request)
cached_response = self.get_key(key)
if cached_response and not overwrite_cache:
cached = True
response = key_to_response(cached_response)
response = self.serializer.key_to_response(cached_response)
else:
# Type Response
response = compute()
self.set_key(key, response_to_key(response))
self.set_key(key, self.serializer.response_to_key(response))
cached = False
return Response(response, cached, request)
return Response(
response, cached, request, **RESPONSE_CONSTRUCTORS.get(self.client_name, {})
)

@ -13,7 +13,7 @@ class NoopCache(Cache):
Args:
connection_str: connection string.
cache_args: cache arguments.
cache_args: arguments for cache.
"""
pass

@ -15,7 +15,7 @@ class RedisCache(Cache):
Args:
connection_str: connection string.
cache_args: cache arguments.
cache_args: arguments for cache.
"""
host, port = connection_str.split(":")
self.redis = redis.Redis(host=host, port=int(port), db=0)

@ -0,0 +1,140 @@
"""Serializer."""
import json
import os
from pathlib import Path
from typing import Dict
import xxhash
from manifest.caches.array_cache import ArrayCache
class Serializer:
"""Serializer."""
def request_to_key(self, request: Dict) -> str:
"""
Normalize a request into a key.
Args:
request: request to normalize.
Returns:
normalized key.
"""
return json.dumps(request, sort_keys=True)
def key_to_request(self, key: str) -> Dict:
"""
Convert the normalized version to the request.
Args:
key: normalized key to convert.
Returns:
unnormalized request dict.
"""
return json.loads(key)
def response_to_key(self, response: Dict) -> str:
"""
Normalize a response into a key.
Args:
response: response to normalize.
Returns:
normalized key.
"""
return json.dumps(response, sort_keys=True)
def key_to_response(self, key: str) -> Dict:
"""
Convert the normalized version to the response.
Args:
key: normalized key to convert.
Returns:
unnormalized response dict.
"""
return json.loads(key)
class ArraySerializer(Serializer):
"""Serializer for array."""
def __init__(self) -> None:
"""
Initialize array serializer.
We don't want to cache the array. We hash the value and
store the array in a memmap file. Store filename/offsets
in sqlitedict to keep track of hash -> array.
"""
super().__init__()
self.hash = xxhash.xxh64()
manifest_home = Path(os.environ.get("MANIFEST_HOME", Path.home()))
cache_folder = manifest_home / ".manifest" / "array_cache"
self.writer = ArrayCache(cache_folder)
def response_to_key(self, response: Dict) -> str:
"""
Normalize a response into a key.
Convert arrays to hash string for cache key.
Args:
response: response to normalize.
Returns:
normalized key.
"""
# Assume response is a dict with keys "choices" -> List dicts
# with keys "array".
choices = response["choices"]
# We don't want to modify the response in place
# but we want to avoid calling deepcopy on an array
del response["choices"]
response_copy = response.copy()
response["choices"] = choices
response_copy["choices"] = []
for choice in choices:
if "array" not in choice:
raise ValueError(
f"Choice with keys {choice.keys()} does not have array key."
)
arr = choice["array"]
# Avoid copying an array
del choice["array"]
new_choice = choice.copy()
choice["array"] = arr
self.hash.update(arr)
hash_str = self.hash.hexdigest()
self.hash.reset()
new_choice["array"] = hash_str
response_copy["choices"].append(new_choice)
if not self.writer.contains_key(hash_str):
self.writer.put(hash_str, arr)
return json.dumps(response_copy, sort_keys=True)
def key_to_response(self, key: str) -> Dict:
"""
Convert the normalized version to the response.
Convert the hash string keys to the arrays.
Args:
key: normalized key to convert.
Returns:
unnormalized response dict.
"""
response = json.loads(key)
for choice in response["choices"]:
hash_str = choice["array"]
choice["array"] = self.writer.get(hash_str)
return response

@ -18,7 +18,7 @@ class SQLiteCache(Cache):
Args:
connection_str: connection string.
cache_args: cache arguments.
cache_args: arguments for cache.
"""
self.cache_file = connection_str
if not self.cache_file:

@ -4,6 +4,7 @@ import os
from typing import Any, Dict, Optional
from manifest.clients.client import Client
from manifest.request import LMRequest
logger = logging.getLogger(__name__)
@ -28,6 +29,7 @@ class AI21Client(Client):
"stop_sequences": ("stopSequences", []),
"client_timeout": ("client_timeout", 60), # seconds
}
REQUEST_CLS = LMRequest
def connect(
self,

@ -15,6 +15,7 @@ class Client(ABC):
# Must be overridden by child class
PARAMS: Dict[str, Tuple[str, Any]] = {}
REQUEST_CLS = Request
def __init__(
self, connection_str: Optional[str] = None, client_args: Dict[str, Any] = {}
@ -108,7 +109,7 @@ class Client(ABC):
params = {"prompt": prompt}
for key in self.PARAMS:
params[key] = request_args.pop(key, getattr(self, key))
return Request(**params)
return self.REQUEST_CLS(**params)
def format_response(self, response: Dict) -> Dict[str, Any]:
"""

@ -5,6 +5,7 @@ import os
from typing import Any, Dict, Optional
from manifest.clients.client import Client
from manifest.request import LMRequest
logger = logging.getLogger(__name__)
@ -27,6 +28,7 @@ class CohereClient(Client):
"stop_sequences": ("stop_sequences", None),
"client_timeout": ("client_timeout", 60), # seconds
}
REQUEST_CLS = LMRequest
def connect(
self,

@ -0,0 +1,92 @@
"""Hugging Face client."""
import logging
from typing import Any, Dict, Optional
import numpy as np
import requests
from manifest.clients.client import Client
from manifest.request import DiffusionRequest
logger = logging.getLogger(__name__)
class DiffuserClient(Client):
"""Diffuser client."""
# User param -> (client param, default value)
PARAMS = {
"num_inference_steps": ("num_inference_steps", 50),
"height": ("height", 512),
"width": ("width", 512),
"n": ("num_images_per_prompt", 1),
"guidance_scale": ("guidance_scale", 7.5),
"eta": ("eta", 0.0),
}
REQUEST_CLS = DiffusionRequest
def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the Diffuser url.
Args:
connection_str: connection string.
client_args: client arguments.
"""
self.host = connection_str.rstrip("/")
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
self.model_params = self.get_model_params()
def close(self) -> None:
"""Close the client."""
pass
def get_generation_url(self) -> str:
"""Get generation URL."""
return self.host + "/completions"
def get_generation_header(self) -> Dict[str, str]:
"""
Get generation header.
Returns:
header.
"""
return {}
def supports_batch_inference(self) -> bool:
"""Return whether the client supports batch inference."""
return True
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add them to the request
and make sure cache keys are unique to the model.
Returns:
model params.
"""
res = requests.post(self.host + "/params")
return res.json()
def format_response(self, response: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
Return:
response as dict
"""
# Convert array to np.array
for choice in response["choices"]:
choice["array"] = np.array(choice["array"])
return response

@ -3,7 +3,7 @@ import logging
from typing import Any, Callable, Dict, List, Optional, Tuple
from manifest.clients.client import Client
from manifest.request import Request
from manifest.request import LMRequest, Request
logger = logging.getLogger(__name__)
@ -15,6 +15,7 @@ class DummyClient(Client):
PARAMS = {
"n": ("num_results", 1),
}
REQUEST_CLS = LMRequest
def connect(
self,

@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
import requests
from manifest.clients.client import Client
from manifest.request import Request
from manifest.request import LMRequest, Request
logger = logging.getLogger(__name__)
@ -24,6 +24,7 @@ class HuggingFaceClient(Client):
"do_sample": ("do_sample", True),
"client_timeout": ("client_timeout", 120), # seconds
}
REQUEST_CLS = LMRequest
def connect(
self,

@ -4,6 +4,7 @@ import os
from typing import Any, Dict, Optional
from manifest.clients.client import Client
from manifest.request import LMRequest
logger = logging.getLogger(__name__)
@ -39,6 +40,7 @@ class OpenAIClient(Client):
"frequency_penalty": ("frequency_penalty", 0.0),
"client_timeout": ("client_timeout", 60), # seconds
}
REQUEST_CLS = LMRequest
def connect(
self,

@ -7,6 +7,7 @@ from typing import Any, Dict, Optional
import requests
from manifest.clients.client import Client
from manifest.request import LMRequest
logger = logging.getLogger(__name__)
@ -33,6 +34,7 @@ class TOMAClient(Client):
"stop_sequences": ("stop", []),
"client_timeout": ("client_timeout", 120), # seconds
}
REQUEST_CLS = LMRequest
def connect(
self,

@ -2,11 +2,14 @@
import logging
from typing import Any, List, Optional, Tuple, Union, cast
import numpy as np
from manifest.caches.noop import NoopCache
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.ai21 import AI21Client
from manifest.clients.cohere import CohereClient
from manifest.clients.diffuser import DiffuserClient
from manifest.clients.dummy import DummyClient
from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.openai import OpenAIClient
@ -22,6 +25,7 @@ CLIENT_CONSTRUCTORS = {
"cohere": CohereClient,
"ai21": AI21Client,
"huggingface": HuggingFaceClient,
"diffuser": DiffuserClient,
"dummy": DummyClient,
"toma": TOMAClient,
}
@ -75,12 +79,16 @@ class Manifest:
self.client_name = client_name
# Must pass kwargs as a dict so the client "pop" methods remove used arguments
self.cache = CACHE_CONSTRUCTORS[cache_name]( # type: ignore
cache_connection, cache_args=kwargs
cache_connection, self.client_name, cache_args=kwargs
)
self.client = CLIENT_CONSTRUCTORS[client_name]( # type: ignore
self.client = CLIENT_CONSTRUCTORS[self.client_name]( # type: ignore
client_connection, client_args=kwargs
)
if session_id:
if session_id is not None:
if self.client_name == "diffuser":
raise NotImplementedError(
"Session logging not implemented for Diffuser client."
)
if session_id == "_default":
# Set session_id to None for Session random id
session_id = None
@ -106,7 +114,7 @@ class Manifest:
stop_token: Optional[str] = None,
return_response: bool = False,
**kwargs: Any,
) -> Union[str, List[str], Response]:
) -> Union[str, List[str], np.ndarray, List[np.ndarray], Response]:
"""
Run the prompt.

@ -13,15 +13,44 @@ class Request(BaseModel):
# Engine
engine: str = "text-ada-001"
# Number completions
n: int = 1
# Timeout
client_timeout: int = 60
def to_dict(
self, allowable_keys: Dict[str, Tuple[str, Any]] = None, add_prompt: bool = True
) -> Dict[str, Any]:
"""
Convert request to a dictionary.
Add prompt ensures the prompt is always in the output dictionary.
"""
if allowable_keys:
include_keys = set(allowable_keys.keys())
if add_prompt and "prompt":
include_keys.add("prompt")
else:
allowable_keys = {}
include_keys = None
request_dict = {
allowable_keys.get(k, (k, None))[0]: v
for k, v in self.dict(include=include_keys).items()
if v is not None
}
return request_dict
class LMRequest(Request):
"""Language Model Request object."""
# Temperature for generation
temperature: float = 0.7
# Max tokens for generation
max_tokens: int = 100
# Number completions
n: int = 1
# Nucleus sampling taking top_p probability mass tokens
top_p: float = 1.0
@ -49,27 +78,21 @@ class Request(BaseModel):
# Penalize frequency
frequency_penalty: float = 0
# Timeout
client_timeout: int = 60
def to_dict(
self, allowable_keys: Dict[str, Tuple[str, Any]] = None, add_prompt: bool = True
) -> Dict[str, Any]:
"""
Convert request to a dictionary.
class DiffusionRequest(Request):
"""Diffusion Model Request object."""
Add prompt ensures the prompt is always in the output dictionary.
"""
if allowable_keys:
include_keys = set(allowable_keys.keys())
if add_prompt and "prompt":
include_keys.add("prompt")
else:
allowable_keys = {}
include_keys = None
request_dict = {
allowable_keys.get(k, (k, None))[0]: v
for k, v in self.dict(include=include_keys).items()
if v is not None
}
return request_dict
# Number of steps
num_inference_steps: int = 50
# Height of image
height: int = 512
# Width of image
width: int = 512
# Guidance scale
guidance_scale: float = 7.5
# Eta
eta: float = 0.0
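
To make the `to_dict` remapping concrete, a hedged sketch pairing `DiffusionRequest` with a `DiffuserClient.PARAMS`-style mapping (the mapping literal copies the client table from this PR; only keys in `allowable_keys`, plus the prompt, survive, renamed to their client-side names):

```python
from manifest.request import DiffusionRequest

# User-facing name -> (client-side name, default); mirrors DiffuserClient.PARAMS above.
PARAMS = {
    "num_inference_steps": ("num_inference_steps", 50),
    "height": ("height", 512),
    "width": ("width", 512),
    "n": ("num_images_per_prompt", 1),
    "guidance_scale": ("guidance_scale", 7.5),
    "eta": ("eta", 0.0),
}

request = DiffusionRequest(prompt="a pixel-art spaceship", n=2, height=256)
print(request.to_dict(allowable_keys=PARAMS))
# Roughly (key order may vary):
# {'prompt': 'a pixel-art spaceship', 'num_inference_steps': 50, 'height': 256,
#  'width': 512, 'num_images_per_prompt': 2, 'guidance_scale': 7.5, 'eta': 0.0}
```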

@ -1,30 +1,82 @@
"""Client response."""
import json
from typing import Dict, List, Union
from typing import Any, Dict, List, Union
import numpy as np
class NumpyArrayEncoder(json.JSONEncoder):
"""Numpy array encoder."""
def default(self, obj: Any) -> str:
"""Encode numpy array."""
if isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)
class Response:
"""Response class."""
def __init__(self, response: Dict, cached: bool, request_params: Dict):
"""Initialize response."""
def __init__(
self,
response: Dict,
cached: bool,
request_params: Dict,
generation_key: str = "choices",
logits_key: str = "logprobs",
item_key: str = "text",
):
"""
Initialize response.
Args:
response: response dict.
cached: whether response is cached.
request_params: request parameters.
generation_key: key for generation results.
logits_key: key for logits.
item_key: key for item in the generations.
"""
self.generation_key = generation_key
self.logits_key = logits_key
self.item_key = item_key
self.item_dtype = None
if isinstance(response, dict):
self._response = response
else:
raise ValueError(f"Response must be str or dict. Response is\n{response}.")
if ("choices" not in self._response) or (
not isinstance(self._response["choices"], list)
raise ValueError(f"Response must be dict. Response is\n{response}.")
if (
(self.generation_key not in self._response)
or (not isinstance(self._response[self.generation_key], list))
or (len(self._response[self.generation_key]) <= 0)
):
raise ValueError(
"Response must be serialized to a dict with a list of choices. "
f"Response is\n{self._response}."
"Response must be serialized to a dict with a nonempty"
f" list of choices. Response is\n{self._response}."
)
if len(self._response["choices"]) > 0:
if "text" not in self._response["choices"][0]:
if self.item_key not in self._response[self.generation_key][0]:
raise ValueError(
"Response must be serialized to a dict with a "
f"list of choices with {self.item_key} field"
)
if (
self.logits_key in self._response[self.generation_key][0]
and self._response[self.generation_key][0][self.logits_key]
):
if not isinstance(
self._response[self.generation_key][0][self.logits_key], list
):
raise ValueError(
"Response must be serialized to a dict with a "
"list of choices with text field"
"list of choices with logprobs field"
)
if isinstance(
self._response[self.generation_key][0][self.item_key], np.ndarray
):
self.item_dtype = str(
self._response[self.generation_key][0][self.item_key].dtype
)
self._cached = cached
self._request_params = request_params
@ -42,9 +94,9 @@ class Response:
def get_response(
self, stop_token: str = "", is_batch: bool = False
) -> Union[str, List[str], None]:
) -> Union[str, List[str], np.ndarray, List[np.ndarray]]:
"""
Get all text results from response.
Get all results from response.
Args:
stop_token: stop token for string generation
@ -53,14 +105,19 @@ class Response:
process_result = (
lambda x: x.strip().split(stop_token)[0] if stop_token else x.strip()
)
if len(self._response["choices"]) == 0:
return None
results = [
process_result(choice["text"]) for choice in self._response["choices"]
extracted_items = [
choice[self.item_key] for choice in self._response[self.generation_key]
]
if len(results) == 1 and not is_batch:
return results[0]
return results
if len(extracted_items) == 0:
return None
if isinstance(extracted_items[0], str):
processed_results = list(map(process_result, extracted_items))
else:
processed_results = extracted_items
if len(processed_results) == 1 and not is_batch:
return processed_results[0]
else:
return processed_results
def serialize(self) -> str:
"""
@ -69,7 +126,7 @@ class Response:
Returns:
serialized response.
"""
return json.dumps(self.to_dict(), sort_keys=True)
return json.dumps(self.to_dict(), sort_keys=True, cls=NumpyArrayEncoder)
@classmethod
def deserialize(cls, value: str) -> "Response":
@ -83,10 +140,19 @@ class Response:
serialized response.
"""
deserialized = json.loads(value)
item_dtype = deserialized["item_dtype"]
if item_dtype:
for choice in deserialized["response"][deserialized["generation_key"]]:
choice[deserialized["item_key"]] = np.array(
choice[deserialized["item_key"]]
).astype(item_dtype)
return cls(
deserialized["response"],
deserialized["cached"],
deserialized["request_params"],
generation_key=deserialized["generation_key"],
logits_key=deserialized["logits_key"],
item_key=deserialized["item_key"],
)
def to_dict(self) -> Dict:
@ -97,6 +163,10 @@ class Response:
dictionary representation of response.
"""
return {
"generation_key": self.generation_key,
"logits_key": self.logits_key,
"item_key": self.item_key,
"item_dtype": self.item_dtype,
"response": self._response,
"cached": self._cached,
"request_params": self._request_params,
@ -117,6 +187,9 @@ class Response:
response["response"],
response["cached"],
response["request_params"],
generation_key=response["generation_key"],
logits_key=response["logits_key"],
item_key=response["item_key"],
)
def __str__(self) -> str:
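
A hedged round-trip sketch of the array handling added above: `serialize` converts arrays to lists via `NumpyArrayEncoder` and records `item_dtype`, and `deserialize` rebuilds the arrays with that dtype:

```python
import numpy as np

from manifest import Response

arr = np.random.randint(20, size=(4, 4))
response = Response(
    {"choices": [{"array": arr}]},
    cached=False,
    request_params={"prompt": "a drawing of a dog"},
    item_key="array",
)

serialized = response.serialize()            # arrays dumped as nested lists, item_dtype recorded
restored = Response.deserialize(serialized)  # arrays rebuilt with the recorded dtype
assert np.array_equal(restored.get_response(), arr)
```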

@ -6,12 +6,7 @@ import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from manifest.caches.cache import (
key_to_request,
key_to_response,
request_to_key,
response_to_key,
)
from manifest.caches.serializers import Serializer
logging.getLogger("sqlitedict").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
@ -30,10 +25,11 @@ class Session:
session_id: session id.
"""
manifest_home = Path(os.environ.get("MANIFEST_SESSION_HOME", Path.home()))
manifest_home = Path(os.environ.get("MANIFEST_HOME", Path.home()))
self.db_file = manifest_home / ".manifest" / "session.db"
self.db_file.parent.mkdir(parents=True, exist_ok=True)
self.conn = sqlite3.connect(str(self.db_file))
self.serializer = Serializer()
self._create_table()
if not session_id:
self.session_id = str(uuid.uuid4())
@ -124,8 +120,8 @@ class Session:
query,
self.query_id,
self.session_id,
request_to_key(query_key),
response_to_key(response_key),
self.serializer.request_to_key(query_key),
self.serializer.response_to_key(response_key),
)
self.query_id += 1
return
@ -151,7 +147,10 @@ class Session:
ORDER BY query_id;"""
res = self._execute_query(query, self.session_id, first_query)
parsed_res = [
(key_to_request(pair[0]), key_to_response(pair[1]))
(
self.serializer.key_to_request(pair[0]),
self.serializer.key_to_response(pair[1]),
)
for pair in res.fetchall()
]
return parsed_res

@ -8,8 +8,7 @@ ignore_missing_imports = true
module = [
"deepspeed",
"numpy",
"tqdm",
"tqdm.auto",
"diffusers",
"sqlitedict",
"dill",
"accelerate",

@ -27,20 +27,17 @@ REQUIRES_PYTHON = ">=3.8.0"
VERSION = main_ns["__version__"]
# What packages are required for this module to be executed?
REQUIRED = [
"dill>=0.3.5",
"redis>=4.3.1",
"requests>=2.27.1",
"sqlitedict>=2.0.0",
]
REQUIRED = ["redis>=4.3.1", "requests>=2.27.1", "sqlitedict>=2.0.0", "xxhash>=3.0.0"]
# What packages are optional?
EXTRAS = {
"api": [
"diffusers>=0.6.0",
"Flask>=2.1.2",
"accelerate>=0.10.0",
"transformers>=4.20.0",
"torch>=1.8.0",
"numpy>=1.20.0",
],
"dev": [
"autopep8>=1.6.0",
@ -64,6 +61,8 @@ EXTRAS = {
"types-protobuf>=3.19.21",
"types-python-dateutil>=2.8.16",
"types-setuptools>=57.4.17",
"types-pillow>=9.0.0",
"types-xxhash>=3.0.0",
"sphinx-autobuild",
"twine",
],

@ -33,5 +33,5 @@ def redis_cache():
@pytest.fixture
def session_cache(tmpdir):
"""Session cache dir."""
os.environ["MANIFEST_SESSION_HOME"] = str(tmpdir)
os.environ["MANIFEST_HOME"] = str(tmpdir)
yield Path(tmpdir)

@ -0,0 +1,76 @@
"""Array cache test."""
from pathlib import Path
import numpy as np
import pytest
from manifest.caches.array_cache import ArrayCache
def test_init(tmpdir):
"""Test cache initialization."""
tmpdir = Path(tmpdir)
cache = ArrayCache(tmpdir)
assert (tmpdir / "hash2arrloc.sqlite").exists()
assert cache.cur_file_idx == 0
assert cache.cur_offset == 0
def test_put_get(tmpdir):
"""Test putting and getting."""
cache = ArrayCache(tmpdir)
cache.max_memmap_size = 5
arr = np.random.rand(10, 10)
with pytest.raises(ValueError) as exc_info:
cache.put("key", arr)
assert str(exc_info.value) == ("Array is too large to be cached. Max is 5")
cache.max_memmap_size = 120
cache.put("key", arr)
assert np.allclose(cache.get("key"), arr)
assert cache.cur_file_idx == 0
assert cache.cur_offset == 100
assert cache.hash2arrloc["key"] == {
"file_idx": 0,
"offset": 0,
"flatten_size": 100,
"shape": (10, 10),
}
arr2 = np.random.rand(10, 10)
cache.put("key2", arr2)
assert np.allclose(cache.get("key2"), arr2)
assert cache.cur_file_idx == 1
assert cache.cur_offset == 100
assert cache.hash2arrloc["key2"] == {
"file_idx": 1,
"offset": 0,
"flatten_size": 100,
"shape": (10, 10),
}
cache = ArrayCache(tmpdir)
assert cache.hash2arrloc["key"] == {
"file_idx": 0,
"offset": 0,
"flatten_size": 100,
"shape": (10, 10),
}
assert cache.hash2arrloc["key2"] == {
"file_idx": 1,
"offset": 0,
"flatten_size": 100,
"shape": (10, 10),
}
assert np.allclose(cache.get("key"), arr)
assert np.allclose(cache.get("key2"), arr2)
def test_contains_key(tmpdir):
"""Test contains key."""
cache = ArrayCache(tmpdir)
assert not cache.contains_key("key")
arr = np.random.rand(10, 10)
cache.put("key", arr)
assert cache.contains_key("key")

@ -1,4 +1,5 @@
"""Cache test."""
import numpy as np
import pytest
from redis import Redis
from sqlitedict import SqliteDict
@ -56,18 +57,32 @@ def test_get(sqlite_cache, redis_cache, cache_type):
test_request = {"test": "hello", "testA": "world"}
compute = lambda: {"choices": [{"text": "hello"}]}
response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert response.get_response() == "hello"
assert not response.is_cached()
assert response.get_request() == test_request
# response = cache.get(test_request, overwrite_cache=False, compute=compute)
# assert response.get_response() == "hello"
# assert not response.is_cached()
# assert response.get_request() == test_request
response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert response.get_response() == "hello"
assert response.is_cached()
assert response.get_request() == test_request
# response = cache.get(test_request, overwrite_cache=False, compute=compute)
# assert response.get_response() == "hello"
# assert response.is_cached()
# assert response.get_request() == test_request
response = cache.get(test_request, overwrite_cache=True, compute=compute)
assert response.get_response() == "hello"
# response = cache.get(test_request, overwrite_cache=True, compute=compute)
# assert response.get_response() == "hello"
# assert not response.is_cached()
# assert response.get_request() == test_request
arr = np.random.rand(4, 4)
test_request = {"test": "hello", "testA": "world of images"}
compute = lambda: {"choices": [{"array": arr}]}
# Test array
if cache_type == "sqlite":
cache = SQLiteCache(sqlite_cache, client_name="diffuser")
else:
cache = RedisCache(redis_cache, client_name="diffuser")
response = cache.get(test_request, overwrite_cache=False, compute=compute)
assert np.allclose(response.get_response(), arr)
assert not response.is_cached()
assert response.get_request() == test_request

@ -6,7 +6,7 @@ from subprocess import PIPE, Popen
import pytest
from manifest.api.models.huggingface import HuggingFaceModel
from manifest.api.models.huggingface import TextGenerationModel
NOCUDA = 0
try:
@ -37,7 +37,7 @@ if NOCUDA == 0:
def test_gpt_generate():
"""Test pipeline generation from a gpt model."""
model = HuggingFaceModel(
model = TextGenerationModel(
model_name_or_path="gpt2",
use_accelerate=False,
use_parallelize=False,
@ -82,7 +82,7 @@ def test_gpt_generate():
def test_encdec_generate():
"""Test pipeline generation from a gpt model."""
model = HuggingFaceModel(
model = TextGenerationModel(
model_name_or_path="google/t5-small-lm-adapt",
use_accelerate=False,
use_parallelize=False,
@ -127,7 +127,7 @@ def test_encdec_generate():
def test_gpt_score():
"""Test pipeline generation from a gpt model."""
model = HuggingFaceModel(
model = TextGenerationModel(
model_name_or_path="gpt2",
use_accelerate=False,
use_parallelize=False,
@ -146,7 +146,7 @@ def test_gpt_score():
def test_batch_gpt_generate():
"""Test pipeline generation from a gpt model."""
model = HuggingFaceModel(
model = TextGenerationModel(
model_name_or_path="gpt2",
use_accelerate=False,
use_parallelize=False,
@ -195,7 +195,7 @@ def test_batch_gpt_generate():
def test_batch_encdec_generate():
"""Test pipeline generation from a gpt model."""
model = HuggingFaceModel(
model = TextGenerationModel(
model_name_or_path="google/t5-small-lm-adapt",
use_accelerate=False,
use_parallelize=False,
@ -249,7 +249,7 @@ def test_batch_encdec_generate():
)
def test_gpt_deepspeed_generate():
"""Test deepspeed generation from a gpt model."""
model = HuggingFaceModel(
model = TextGenerationModel(
model_name_or_path="gpt2",
use_accelerate=False,
use_parallelize=False,

@ -1,8 +1,9 @@
"""Manifest test."""
import json
import pytest
from manifest import Manifest, Response
from manifest.caches.cache import request_to_key
from manifest.caches.noop import NoopCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.dummy import DummyClient
@ -81,12 +82,13 @@ def test_run(sqlite_cache, session_cache, n, return_response):
res = result
assert (
manifest.cache.get_key(
request_to_key(
json.dumps(
{
"prompt": "This is a prompt",
"engine": "dummy",
"num_results": n,
}
},
sort_keys=True,
)
)
is not None
@ -105,13 +107,14 @@ def test_run(sqlite_cache, session_cache, n, return_response):
res = result
assert (
manifest.cache.get_key(
request_to_key(
json.dumps(
{
"prompt": "This is a prompt",
"engine": "dummy",
"num_results": n,
"run_id": "34",
}
},
sort_keys=True,
)
)
is not None
@ -130,12 +133,13 @@ def test_run(sqlite_cache, session_cache, n, return_response):
res = result
assert (
manifest.cache.get_key(
request_to_key(
json.dumps(
{
"prompt": "Hello is a prompt",
"engine": "dummy",
"num_results": n,
}
},
sort_keys=True,
)
)
is not None
@ -154,12 +158,13 @@ def test_run(sqlite_cache, session_cache, n, return_response):
res = result
assert (
manifest.cache.get_key(
request_to_key(
json.dumps(
{
"prompt": "Hello is a prompt",
"engine": "dummy",
"num_results": n,
}
},
sort_keys=True,
)
)
is not None
@ -234,12 +239,13 @@ def test_choices_run(sqlite_cache, session_cache, return_response):
res = result
assert (
manifest.cache.get_key(
request_to_key(
json.dumps(
{
"prompt": "This is a prompt",
"gold_choices": ["cat", "dog"],
"engine": "dummy",
}
},
sort_keys=True,
)
)
is not None
@ -256,12 +262,13 @@ def test_choices_run(sqlite_cache, session_cache, return_response):
res = result
assert (
manifest.cache.get_key(
request_to_key(
json.dumps(
{
"prompt": "Hello is a prompt",
"gold_choices": ["cat", "dog"],
"engine": "dummy",
}
},
sort_keys=True,
)
)
is not None
@ -283,12 +290,13 @@ def test_choices_run(sqlite_cache, session_cache, return_response):
res = result
assert (
manifest.cache.get_key(
request_to_key(
json.dumps(
{
"prompt": "Hello is a prompt",
"gold_choices": ["cat", "dog"],
"engine": "dummy",
}
},
sort_keys=True,
)
)
is not None
@ -310,12 +318,13 @@ def test_choices_run(sqlite_cache, session_cache, return_response):
res = result
assert (
manifest.cache.get_key(
request_to_key(
json.dumps(
{
"prompt": ["Hello is a prompt", "Hello is a prompt"],
"gold_choices": ["callt", "dog"],
"engine": "dummy",
}
},
sort_keys=True,
)
)
is not None
@ -338,6 +347,10 @@ def test_log_query(session_cache):
"cached": False,
"request_params": query_key,
"response": {"choices": [{"text": "hello"}]},
"generation_key": "choices",
"item_dtype": None,
"item_key": "text",
"logits_key": "logprobs",
}
assert manifest.get_last_queries(1) == [("This is a prompt", "hello")]
assert manifest.get_last_queries(1, return_raw_values=True) == [
@ -357,6 +370,10 @@ def test_log_query(session_cache):
}
response_key = {
"cached": False,
"generation_key": "choices",
"item_dtype": None,
"item_key": "text",
"logits_key": "logprobs",
"request_params": query_key,
"response": {"choices": [{"text": "hello"}, {"text": "hello"}]},
}

@ -1,33 +1,63 @@
"""Request test."""
from manifest import Request
from manifest.request import DiffusionRequest, LMRequest
def test_init():
def test_llm_init():
"""Test request initialization."""
request = Request()
request = LMRequest()
assert request.temperature == 0.7
request = Request(temperature=0.5)
request = LMRequest(temperature=0.5)
assert request.temperature == 0.5
request = Request(**{"temperature": 0.5})
request = LMRequest(**{"temperature": 0.5})
assert request.temperature == 0.5
request = Request(**{"temperature": 0.5, "prompt": "test"})
request = LMRequest(**{"temperature": 0.5, "prompt": "test"})
assert request.temperature == 0.5
assert request.prompt == "test"
def test_diff_init():
"""Test request initialization."""
request = DiffusionRequest()
assert request.height == 512
request = DiffusionRequest(height=128)
assert request.height == 128
request = DiffusionRequest(**{"height": 128})
assert request.height == 128
request = DiffusionRequest(**{"height": 128, "prompt": "test"})
assert request.height == 128
assert request.prompt == "test"
def test_to_dict():
"""Test request to dict."""
request = Request()
request = LMRequest()
dct = request.to_dict()
assert dct == {k: v for k, v in request.dict().items() if v is not None}
keys = {"temperature": ("temp", 0.5)}
# Note the second value is a placeholder for the default value
# It's unused in to_dict
keys = {"temperature": ("temp", 0.7)}
dct = request.to_dict(allowable_keys=keys)
assert dct == {"temp": 0.7, "prompt": ""}
dct = request.to_dict(allowable_keys=keys, add_prompt=False)
assert dct == {"temp": 0.7}
request = DiffusionRequest()
dct = request.to_dict()
assert dct == {k: v for k, v in request.dict().items() if v is not None}
keys = {"height": ("hgt", 512)}
dct = request.to_dict(allowable_keys=keys)
assert dct == {"hgt": 512, "prompt": ""}
dct = request.to_dict(allowable_keys=keys, add_prompt=False)
assert dct == {"hgt": 512}

@ -1,4 +1,5 @@
"""Response test."""
import numpy as np
import pytest
from manifest import Response
@ -8,11 +9,11 @@ def test_init():
"""Test response initialization."""
with pytest.raises(ValueError) as exc_info:
response = Response(4, False, {})
assert str(exc_info.value) == "Response must be str or dict. Response is\n4."
assert str(exc_info.value) == "Response must be dict. Response is\n4."
with pytest.raises(ValueError) as exc_info:
response = Response({"test": "hello"}, False, {})
assert str(exc_info.value) == (
"Response must be serialized to a dict with a list of choices. "
"Response must be serialized to a dict with a nonempty list of choices. "
"Response is\n{'test': 'hello'}."
)
with pytest.raises(ValueError) as exc_info:
@ -21,16 +22,46 @@ def test_init():
"Response must be serialized to a dict "
"with a list of choices with text field"
)
with pytest.raises(ValueError) as exc_info:
response = Response({"choices": []}, False, {})
assert str(exc_info.value) == (
"Response must be serialized to a dict with a nonempty list of choices. "
"Response is\n{'choices': []}."
)
response = Response({"choices": [{"text": "hello"}]}, False, {})
assert response._response == {"choices": [{"text": "hello"}]}
assert response._cached is False
assert response._request_params == {}
assert response.item_dtype is None
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
assert response._response == {"choices": [{"text": "hello"}]}
assert response._cached is True
assert response._request_params == {"request": "yoyo"}
assert response.item_dtype is None
response = Response(
{"generations": [{"txt": "hello"}], "logits": []},
False,
{},
generation_key="generations",
logits_key="logits",
item_key="txt",
)
assert response._response == {"generations": [{"txt": "hello"}], "logits": []}
assert response._cached is False
assert response._request_params == {}
assert response.item_dtype is None
int_arr = np.random.randint(20, size=(4, 4))
response = Response(
{"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
)
assert response._response == {"choices": [{"array": int_arr}]}
assert response._cached is True
assert response._request_params == {"request": "yoyo"}
assert response.item_dtype == "int64"
def test_getters():
@ -45,6 +76,14 @@ def test_getters():
assert response.is_cached() is True
assert response.get_request() == {"request": "yoyo"}
int_arr = np.random.randint(20, size=(4, 4))
response = Response(
{"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
)
assert response.get_json_response() == {"choices": [{"array": int_arr}]}
assert response.is_cached() is True
assert response.get_request() == {"request": "yoyo"}
def test_serialize():
"""Test response serialization."""
@ -54,14 +93,31 @@ def test_serialize():
assert deserialized_response.is_cached() is True
assert deserialized_response._request_params == {"request": "yoyo"}
int_arr = np.random.randint(20, size=(4, 4))
response = Response(
{"choices": [{"array": int_arr}]}, True, {"request": "yoyo"}, item_key="array"
)
deserialized_response = Response.deserialize(response.serialize())
assert np.array_equal(
deserialized_response._response["choices"][0]["array"], int_arr
)
assert deserialized_response.is_cached() is True
assert deserialized_response._request_params == {"request": "yoyo"}
float_arr = np.random.randn(4, 4)
response = Response(
{"choices": [{"array": float_arr}]}, True, {"request": "yoyo"}, item_key="array"
)
deserialized_response = Response.deserialize(response.serialize())
assert np.array_equal(
deserialized_response._response["choices"][0]["array"], float_arr
)
assert deserialized_response.is_cached() is True
assert deserialized_response._request_params == {"request": "yoyo"}
def test_get_results():
"""Test response get results."""
response = Response({"choices": []}, True, {"request": "yoyo"})
assert response.get_response() is None
assert response.get_response(stop_token="ll") is None
assert response.get_response(stop_token="ll", is_batch=True) is None
response = Response({"choices": [{"text": "hello"}]}, True, {"request": "yoyo"})
assert response.get_response() == "hello"
assert response.get_response(stop_token="ll") == "he"
@ -75,3 +131,13 @@ def test_get_results():
assert response.get_response() == ["hello", "my", "name"]
assert response.get_response(stop_token="m") == ["hello", "", "na"]
assert response.get_response(stop_token="m", is_batch=True) == ["hello", "", "na"]
float_arr = np.random.randn(4, 4)
response = Response(
{"choices": [{"array": float_arr}, {"array": float_arr}]},
True,
{"request": "yoyo"},
item_key="array",
)
assert response.get_response() == [float_arr, float_arr]
assert response.get_response(stop_token="m") == [float_arr, float_arr]

@ -0,0 +1,19 @@
"""Cache test."""
import json
import numpy as np
from manifest.caches.serializers import ArraySerializer
def test_response_to_key(session_cache):
"""Test array serializer initialization."""
serializer = ArraySerializer()
arr = np.random.rand(4, 4)
res = {"choices": [{"array": arr}]}
key = serializer.response_to_key(res)
key_dct = json.loads(key)
assert isinstance(key_dct["choices"][0]["array"], str)
res2 = serializer.key_to_response(key)
assert np.allclose(arr, res2["choices"][0]["array"])