mirror of https://github.com/HazyResearch/manifest
Laurel/diffusion (#40)
* Sketch of diffusers added
* [WIP] Array caching implemented with end2end diffusion working
* [WIP] Make initial pass on CLIP model
* [WIP] Get endpoint running for CLIP
* Add support for clip images
* [chore] merge main
* chore: fix xxhash install

Co-authored-by: Sabri Eyuboglu <eyuboglu@stanford.edu>

laurel/helm
parent 26e440b6a6
commit 6f5b64f0df
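The pieces added in this commit (a server-side DiffuserModel, a DiffuserClient, and an ArraySerializer/ArrayCache pair for caching returned arrays) are meant to work end to end. A hedged sketch of that flow follows; the `--model_type diffuser` flag, the `diffuser` client name, and the checkpoint name are assumptions inferred from this diff, not documented behavior.

# Sketch only: flags, client name, and checkpoint name are assumptions from this diff.
# 1) Serve the diffusion model behind the manifest API app (separate process), e.g.:
#      python3 -m manifest.api.app --model_type diffuser \
#          --model_name_or_path runwayml/stable-diffusion-v1-5 --device 0
# 2) Point a Manifest instance at it with the new diffuser client:
from manifest import Manifest

manifest = Manifest(
    client_name="diffuser",  # assumed registration name for DiffuserClient
    client_connection="http://127.0.0.1:5000",
)
# The response "choices" carry numpy image arrays; ArraySerializer replaces them with
# hash strings in the cache key and stores the raw arrays in the memmap-backed ArrayCache.
image = manifest.run("an oil painting of a lighthouse at dusk", num_inference_steps=30)
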
@ -0,0 +1,106 @@
"""Huggingface model."""
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import torch
from diffusers import StableDiffusionPipeline

from manifest.api.models.model import Model


class DiffuserModel(Model):
    """Diffuser model."""

    def __init__(
        self,
        model_name_or_path: str,
        model_config: str = None,
        cache_dir: str = None,
        device: int = 0,
        use_accelerate: bool = False,
        use_parallelize: bool = False,
        use_bitsandbytes: bool = False,
        perc_max_gpu_mem_red: float = 1.0,
        use_fp16: bool = True,
    ):
        """
        Initialize model.

        All arguments will be passed in the request from Manifest.

        Args:
            model_name_or_path: model name string.
            model_config: model config string.
            cache_dir: cache directory for model.
            device: device to use for model.
            use_accelerate: whether to use accelerate for multi-gpu inference.
            use_parallelize: use HF default parallelize.
            use_bitsandbytes: use HF bits and bytes.
            perc_max_gpu_mem_red: percent max memory reduction in accelerate.
            use_fp16: use fp16 for model weights.
        """
        if use_accelerate or use_parallelize or use_bitsandbytes:
            raise ValueError(
                "Cannot use accelerate or parallelize or bitsandbytes with diffusers"
            )
        # Check if providing path
        self.model_path = model_name_or_path
        if Path(self.model_path).exists() and Path(self.model_path).is_dir():
            model_name_or_path = Path(self.model_path).name
        self.model_name = model_name_or_path
        print("Model Name:", self.model_name, "Model Path:", self.model_path)
        dtype = torch.float16 if use_fp16 else None
        torch_device = (
            torch.device("cpu")
            if (device == -1 or not torch.cuda.is_available())
            else torch.device(f"cuda:{device}")
        )
        self.pipeline = StableDiffusionPipeline.from_pretrained(
            self.model_path,
            torch_dtype=dtype,
            revision="fp16" if use_fp16 else None,
        )
        self.pipeline.safety_checker = None
        self.pipeline.to(torch_device)

    def get_init_params(self) -> Dict:
        """Return init params to determine what model is being used."""
        return {"model_name": self.model_name, "model_path": self.model_path}

    @torch.no_grad()
    def generate(
        self, prompt: Union[str, List[str]], **kwargs: Any
    ) -> List[Tuple[Any, float]]:
        """
        Generate from the prompt with the model.

        Outputs must be the generated output and score, not including the prompt.

        Args:
            prompt: prompt to generate from.

        Returns:
            list of generated output (list of length 1 for 1 generation).
        """
        # TODO: Is this correct for getting arguments in?
        if isinstance(prompt, str):
            prompt = [prompt]
        result = self.pipeline(prompt, output_type="np.array", **kwargs)
        # Return None for logprobs
        return [(im, None) for im in result["images"]]

    @torch.no_grad()
    def logits_scoring(
        self, prompt: Union[str, List[str]], gold_choices: List[str], **kwargs: Any
    ) -> List[Tuple[Any, float]]:
        """
        Given the prompt and gold choices, choose the best choice with max logits.

        Args:
            prompt: prompt to generate from.
            gold_choices: list of choices to choose from.

        Returns:
            the returned gold choice.
        """
        raise NotImplementedError("Logits scoring not supported for diffusers")
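
A rough illustration of how the server-side class above would be exercised directly; this is a sketch, assuming the module lands at manifest.api.models.diffuser and using a placeholder checkpoint name.

import numpy as np

from manifest.api.models.diffuser import DiffuserModel  # assumed module path

model = DiffuserModel(
    "runwayml/stable-diffusion-v1-5",  # placeholder Stable Diffusion checkpoint
    device=0,
    use_fp16=True,
)
# generate() returns a list of (image_array, None) pairs; diffusion has no logprobs.
outputs = model.generate("an astronaut riding a horse", num_inference_steps=25)
image, _ = outputs[0]
assert isinstance(image, np.ndarray)
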
@ -0,0 +1,115 @@
"""Array cache."""
from pathlib import Path
from typing import Union

import numpy as np
from sqlitedict import SqliteDict


def open_mmap_arr(file: Union[Path, str], size: float) -> np.memmap:
    """Open memmap."""
    if not Path(file).exists():
        mode = "w+"
    else:
        mode = "r+"
    arr = np.memmap(  # type: ignore
        str(file),
        dtype=np.float32,  # This means we only support float 32
        mode=mode,
        shape=size,
    )
    return arr


class ArrayCache:
    """Array cache."""

    def __init__(self, folder: Union[str, Path]) -> None:
        """
        Initialize the array writer.

        Args:
            folder: folder to write to.
        """
        self.folder = Path(folder)
        self.folder.mkdir(exist_ok=True, parents=True)
        self.hash2arrloc = SqliteDict(
            self.folder / "hash2arrloc.sqlite", autocommit=True
        )
        # 20,480,000 float32 entries per memmap file (~82 MB)
        self.max_memmap_size = 20480000
        self.cur_file_idx = 0
        # Get the last file idx used
        for key in self.hash2arrloc:
            file_data = self.hash2arrloc[key]
            if file_data["file_idx"] > self.cur_file_idx:
                self.cur_file_idx = file_data["file_idx"]
        self.cur_memmap = open_mmap_arr(
            self.folder / f"{self.cur_file_idx}.npy",
            self.max_memmap_size,
        )
        # Make sure there is space left in the memmap
        non_zero = np.nonzero(self.cur_memmap)[0]
        if len(non_zero) > 0:
            self.cur_offset = int(np.max(non_zero) + 1)
        else:
            self.cur_offset = 0
        # If no space, make a new memmap
        if self.cur_offset == self.max_memmap_size:
            self.cur_file_idx += 1
            self.cur_memmap = open_mmap_arr(
                self.folder / f"{self.cur_file_idx}.npy",
                self.max_memmap_size,
            )
            self.cur_offset = 0

    def contains_key(self, key: str) -> bool:
        """
        Check if the key is in the cache.

        Args:
            key: key to check.

        Returns:
            True if the key is in the cache.
        """
        return key in self.hash2arrloc

    def put(self, key: str, arr: np.ndarray) -> None:
        """Save array in store and associate location with key."""
        # Check if there is space in the memmap
        arr_shape = arr.shape
        arr = arr.flatten()
        if len(arr) > self.max_memmap_size:
            raise ValueError(
                f"Array is too large to be cached. Max is {self.max_memmap_size}"
            )
        if self.cur_offset + len(arr) > self.max_memmap_size:
            self.cur_file_idx += 1
            self.cur_memmap = open_mmap_arr(
                self.folder / f"{self.cur_file_idx}.npy",
                self.max_memmap_size,
            )
            self.cur_offset = 0
        self.cur_memmap[self.cur_offset : self.cur_offset + len(arr)] = arr
        self.cur_memmap.flush()
        self.hash2arrloc[key] = {
            "file_idx": self.cur_file_idx,
            "offset": self.cur_offset,
            "flatten_size": len(arr),
            "shape": arr_shape,
        }
        self.cur_offset += len(arr)
        return

    def get(self, key: str) -> np.ndarray:
        """Get array associated with location from key."""
        file_data = self.hash2arrloc[key]
        memmap = open_mmap_arr(
            self.folder / f"{file_data['file_idx']}.npy",
            self.max_memmap_size,
        )
        arr = memmap[
            file_data["offset"] : file_data["offset"] + file_data["flatten_size"]
        ]
        return arr.reshape(file_data["shape"])
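
A small sketch of the bookkeeping above: arrays are flattened into fixed-size float32 memmap files, and the sqlitedict maps each key to (file_idx, offset, flatten_size, shape). Shrinking max_memmap_size (as the tests later in this diff also do) makes the rollover to a new file easy to see; the directory path here is only an example.

import numpy as np

from manifest.caches.array_cache import ArrayCache

cache = ArrayCache("/tmp/array_cache_demo")  # example directory
cache.max_memmap_size = 16  # shrink so the second array forces a new .npy file

cache.put("a", np.ones((3, 4), dtype=np.float32))  # 12 floats -> file 0, offset 0
cache.put("b", np.ones((2, 4), dtype=np.float32))  # 8 more do not fit -> file 1, offset 0

print(cache.hash2arrloc["a"])  # {'file_idx': 0, 'offset': 0, 'flatten_size': 12, 'shape': (3, 4)}
print(cache.get("b").shape)    # (2, 4)
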
@ -0,0 +1,140 @@
"""Serializer."""

import json
import os
from pathlib import Path
from typing import Dict

import xxhash

from manifest.caches.array_cache import ArrayCache


class Serializer:
    """Serializer."""

    def request_to_key(self, request: Dict) -> str:
        """
        Normalize a request into a key.

        Args:
            request: request to normalize.

        Returns:
            normalized key.
        """
        return json.dumps(request, sort_keys=True)

    def key_to_request(self, key: str) -> Dict:
        """
        Convert the normalized version to the request.

        Args:
            key: normalized key to convert.

        Returns:
            unnormalized request dict.
        """
        return json.loads(key)

    def response_to_key(self, response: Dict) -> str:
        """
        Normalize a response into a key.

        Args:
            response: response to normalize.

        Returns:
            normalized key.
        """
        return json.dumps(response, sort_keys=True)

    def key_to_response(self, key: str) -> Dict:
        """
        Convert the normalized version to the response.

        Args:
            key: normalized key to convert.

        Returns:
            unnormalized response dict.
        """
        return json.loads(key)


class ArraySerializer(Serializer):
    """Serializer for array."""

    def __init__(self) -> None:
        """
        Initialize array serializer.

        We don't want to cache the array. We hash the value and
        store the array in a memmap file. Store filename/offsets
        in sqlitedict to keep track of hash -> array.
        """
        super().__init__()
        self.hash = xxhash.xxh64()
        manifest_home = Path(os.environ.get("MANIFEST_HOME", Path.home()))
        cache_folder = manifest_home / ".manifest" / "array_cache"
        self.writer = ArrayCache(cache_folder)

    def response_to_key(self, response: Dict) -> str:
        """
        Normalize a response into a key.

        Convert arrays to hash string for cache key.

        Args:
            response: response to normalize.

        Returns:
            normalized key.
        """
        # Assume response is a dict with keys "choices" -> List dicts
        # with keys "array".
        choices = response["choices"]
        # We don't want to modify the response in place
        # but we want to avoid calling deepcopy on an array
        del response["choices"]
        response_copy = response.copy()
        response["choices"] = choices
        response_copy["choices"] = []
        for choice in choices:
            if "array" not in choice:
                raise ValueError(
                    f"Choice with keys {choice.keys()} does not have array key."
                )
            arr = choice["array"]
            # Avoid copying an array
            del choice["array"]
            new_choice = choice.copy()
            choice["array"] = arr

            self.hash.update(arr)
            hash_str = self.hash.hexdigest()
            self.hash.reset()
            new_choice["array"] = hash_str
            response_copy["choices"].append(new_choice)
            if not self.writer.contains_key(hash_str):
                self.writer.put(hash_str, arr)
        return json.dumps(response_copy, sort_keys=True)

    def key_to_response(self, key: str) -> Dict:
        """
        Convert the normalized version to the response.

        Convert the hash string keys to the arrays.

        Args:
            key: normalized key to convert.

        Returns:
            unnormalized response dict.
        """
        response = json.loads(key)
        for choice in response["choices"]:
            hash_str = choice["array"]
            choice["array"] = self.writer.get(hash_str)
        return response
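
To make the behavior above concrete: response_to_key swaps each choice's array for an xxhash digest inside the JSON cache key and writes the raw array to the ArrayCache; key_to_response reverses the substitution. A minimal round trip (the same idea as the serializer test at the bottom of this diff) might look like the following sketch.

import json

import numpy as np

from manifest.caches.serializers import ArraySerializer

# Note: ArraySerializer writes its memmap files under $MANIFEST_HOME (or the home
# directory) as set up in __init__ above.
serializer = ArraySerializer()
response = {"choices": [{"array": np.random.rand(2, 3)}]}

key = serializer.response_to_key(response)
# The key is plain JSON; the array has been replaced by its hash string.
assert isinstance(json.loads(key)["choices"][0]["array"], str)

restored = serializer.key_to_response(key)
assert np.allclose(restored["choices"][0]["array"], response["choices"][0]["array"])
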
@ -0,0 +1,92 @@
"""Diffuser client."""
import logging
from typing import Any, Dict, Optional

import numpy as np
import requests

from manifest.clients.client import Client
from manifest.request import DiffusionRequest

logger = logging.getLogger(__name__)


class DiffuserClient(Client):
    """Diffuser client."""

    # User param -> (client param, default value)
    PARAMS = {
        "num_inference_steps": ("num_inference_steps", 50),
        "height": ("height", 512),
        "width": ("width", 512),
        "n": ("num_images_per_prompt", 1),
        "guidance_scale": ("guidance_scale", 7.5),
        "eta": ("eta", 0.0),
    }
    REQUEST_CLS = DiffusionRequest

    def connect(
        self,
        connection_str: Optional[str] = None,
        client_args: Dict[str, Any] = {},
    ) -> None:
        """
        Connect to the Diffuser url.

        Args:
            connection_str: connection string.
            client_args: client arguments.
        """
        self.host = connection_str.rstrip("/")
        for key in self.PARAMS:
            setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
        self.model_params = self.get_model_params()

    def close(self) -> None:
        """Close the client."""
        pass

    def get_generation_url(self) -> str:
        """Get generation URL."""
        return self.host + "/completions"

    def get_generation_header(self) -> Dict[str, str]:
        """
        Get generation header.

        Returns:
            header.
        """
        return {}

    def supports_batch_inference(self) -> bool:
        """Return whether the client supports batch inference."""
        return True

    def get_model_params(self) -> Dict:
        """
        Get model params.

        By getting model params from the server, we can add to request
        and make sure cache keys are unique to model.

        Returns:
            model params.
        """
        res = requests.post(self.host + "/params")
        return res.json()

    def format_response(self, response: Dict) -> Dict[str, Any]:
        """
        Format response to dict.

        Args:
            response: response.

        Returns:
            response as dict.
        """
        # Convert array to np.array
        for choice in response["choices"]:
            choice["array"] = np.array(choice["array"])
        return response
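
The PARAMS table above is what lets callers use short, model-agnostic names (for example "n") while the server receives the diffusers keyword ("num_images_per_prompt"). A standalone sketch of that translation follows; the real plumbing lives in the shared Client base class, so the helper name here is hypothetical.

from typing import Any, Dict, Tuple

# Copied from DiffuserClient.PARAMS above: user name -> (pipeline name, default).
PARAMS: Dict[str, Tuple[str, Any]] = {
    "num_inference_steps": ("num_inference_steps", 50),
    "height": ("height", 512),
    "width": ("width", 512),
    "n": ("num_images_per_prompt", 1),
    "guidance_scale": ("guidance_scale", 7.5),
    "eta": ("eta", 0.0),
}


def to_pipeline_kwargs(user_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Translate user-facing names to diffusers keywords, filling in defaults."""
    return {
        pipeline_name: user_kwargs.get(user_name, default)
        for user_name, (pipeline_name, default) in PARAMS.items()
    }


assert to_pipeline_kwargs({"n": 2, "height": 768})["num_images_per_prompt"] == 2
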
@ -0,0 +1,76 @@
"""Array cache test."""
from pathlib import Path

import numpy as np
import pytest

from manifest.caches.array_cache import ArrayCache


def test_init(tmpdir):
    """Test cache initialization."""
    tmpdir = Path(tmpdir)
    cache = ArrayCache(tmpdir)
    assert (tmpdir / "hash2arrloc.sqlite").exists()
    assert cache.cur_file_idx == 0
    assert cache.cur_offset == 0


def test_put_get(tmpdir):
    """Test putting and getting."""
    cache = ArrayCache(tmpdir)
    cache.max_memmap_size = 5
    arr = np.random.rand(10, 10)

    with pytest.raises(ValueError) as exc_info:
        cache.put("key", arr)
    assert str(exc_info.value) == ("Array is too large to be cached. Max is 5")

    cache.max_memmap_size = 120
    cache.put("key", arr)
    assert np.allclose(cache.get("key"), arr)
    assert cache.cur_file_idx == 0
    assert cache.cur_offset == 100
    assert cache.hash2arrloc["key"] == {
        "file_idx": 0,
        "offset": 0,
        "flatten_size": 100,
        "shape": (10, 10),
    }

    arr2 = np.random.rand(10, 10)
    cache.put("key2", arr2)
    assert np.allclose(cache.get("key2"), arr2)
    assert cache.cur_file_idx == 1
    assert cache.cur_offset == 100
    assert cache.hash2arrloc["key2"] == {
        "file_idx": 1,
        "offset": 0,
        "flatten_size": 100,
        "shape": (10, 10),
    }

    cache = ArrayCache(tmpdir)
    assert cache.hash2arrloc["key"] == {
        "file_idx": 0,
        "offset": 0,
        "flatten_size": 100,
        "shape": (10, 10),
    }
    assert cache.hash2arrloc["key2"] == {
        "file_idx": 1,
        "offset": 0,
        "flatten_size": 100,
        "shape": (10, 10),
    }
    assert np.allclose(cache.get("key"), arr)
    assert np.allclose(cache.get("key2"), arr2)


def test_contains_key(tmpdir):
    """Test contains key."""
    cache = ArrayCache(tmpdir)
    assert not cache.contains_key("key")
    arr = np.random.rand(10, 10)
    cache.put("key", arr)
    assert cache.contains_key("key")
@ -1,33 +1,63 @@
 """Request test."""
-from manifest import Request
+from manifest.request import DiffusionRequest, LMRequest


-def test_init():
+def test_llm_init():
     """Test request initialization."""
-    request = Request()
+    request = LMRequest()
     assert request.temperature == 0.7

-    request = Request(temperature=0.5)
+    request = LMRequest(temperature=0.5)
     assert request.temperature == 0.5

-    request = Request(**{"temperature": 0.5})
+    request = LMRequest(**{"temperature": 0.5})
     assert request.temperature == 0.5

-    request = Request(**{"temperature": 0.5, "prompt": "test"})
+    request = LMRequest(**{"temperature": 0.5, "prompt": "test"})
     assert request.temperature == 0.5
     assert request.prompt == "test"


+def test_diff_init():
+    """Test request initialization."""
+    request = DiffusionRequest()
+    assert request.height == 512
+
+    request = DiffusionRequest(height=128)
+    assert request.height == 128
+
+    request = DiffusionRequest(**{"height": 128})
+    assert request.height == 128
+
+    request = DiffusionRequest(**{"height": 128, "prompt": "test"})
+    assert request.height == 128
+    assert request.prompt == "test"
+
+
 def test_to_dict():
     """Test request to dict."""
-    request = Request()
+    request = LMRequest()
     dct = request.to_dict()

     assert dct == {k: v for k, v in request.dict().items() if v is not None}

-    keys = {"temperature": ("temp", 0.5)}
+    # Note the second value is a placeholder for the default value
+    # It's unused in to_dict
+    keys = {"temperature": ("temp", 0.7)}
     dct = request.to_dict(allowable_keys=keys)
     assert dct == {"temp": 0.7, "prompt": ""}

     dct = request.to_dict(allowable_keys=keys, add_prompt=False)
     assert dct == {"temp": 0.7}
+
+    request = DiffusionRequest()
+    dct = request.to_dict()
+
+    assert dct == {k: v for k, v in request.dict().items() if v is not None}
+
+    keys = {"height": ("hgt", 512)}
+    dct = request.to_dict(allowable_keys=keys)
+    assert dct == {"hgt": 512, "prompt": ""}
+
+    dct = request.to_dict(allowable_keys=keys, add_prompt=False)
+    assert dct == {"hgt": 512}
@ -0,0 +1,19 @@
"""Cache test."""
import json

import numpy as np

from manifest.caches.serializers import ArraySerializer


def test_response_to_key(session_cache):
    """Test array serializer initialization."""
    serializer = ArraySerializer()
    arr = np.random.rand(4, 4)
    res = {"choices": [{"array": arr}]}
    key = serializer.response_to_key(res)
    key_dct = json.loads(key)
    assert isinstance(key_dct["choices"][0]["array"], str)

    res2 = serializer.key_to_response(key)
    assert np.allclose(arr, res2["choices"][0]["array"])