diff --git a/imaginairy/cli/imagine.py b/imaginairy/cli/imagine.py
index 573006b..6132c57 100644
--- a/imaginairy/cli/imagine.py
+++ b/imaginairy/cli/imagine.py
@@ -62,6 +62,7 @@ from imaginairy.cli.shared import (
             "inpaint",
             "colorize",
             "qrcode",
+            "densepose",
         ]
     ),
     help="how the control image is used as signal",
diff --git a/imaginairy/config.py b/imaginairy/config.py
index 76bc7ec..75af302 100644
--- a/imaginairy/config.py
+++ b/imaginairy/config.py
@@ -168,10 +168,17 @@ MODEL_WEIGHT_CONFIGS = [
         defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
     ),
     ModelWeightsConfig(
-        name="Redshift Diffusion",
-        aliases=["redshift-diffusion", "red", "redshift-diffusion-15", "red15"],
+        name="Miniaturus Potentia V1.2",
+        aliases=[
+            "miniaturuspotentia",
+            "potentia",
+            "miniaturuspotentia12",
+            "mp12",
+            "mp",
+            "potentia12",
+        ],
         architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
-        weights_location="https://huggingface.co/nitrosocke/redshift-diffusion/tree/80837fe18df05807861ab91c3bad3693c9342e4c/",
+        weights_location="https://huggingface.co/dataautogpt3/Miniaturus_PotentiaV1.2/tree/7ef539518ad5ad591c45f0b920050883f7e51e83/",
         defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
     ),
     # SDXL Weights
@@ -338,6 +345,13 @@ CONTROL_CONFIGS = [
         weights_location="https://huggingface.co/monster-labs/control_v1p_sd15_qrcode_monster/resolve/4a946e610f670c4cd6cf46b8641fca190e4f56c4/diffusion_pytorch_model.safetensors",
         aliases=["qrcode"],
     ),
+    ControlConfig(
+        name="DensePose",
+        control_type="densepose",
+        config_path="configs/control-net-v15.yaml",
+        weights_location="https://huggingface.co/zcxu-eric/MagicAnimate/resolve/3d80ae8c50b289e55ee68deecc83afaab9c6a382/densepose_controlnet/diffusion_pytorch_model.safetensors?download=true",
+        aliases=["densepose"],
+    ),
 ]

 CONTROL_CONFIG_SHORTCUTS: dict[str, ControlConfig] = {}
@@ -398,7 +412,7 @@ IP_ADAPTER_WEIGHT_LOCATIONS = {
     },
 }
 SD21_UNCLIP_WEIGHTS_URL = "https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/resolve/e99f66a92bdcd1b0fb0d4b6a9b81b3b37d8bea44/image_encoder/model.fp16.safetensors"
-
+DENSEPOSE_REPO_URL = "https://huggingface.co/LayerNorm/DensePose-TorchScript-with-hint-image/resolve/65446422ea6225b9d72f93f3d2e2ad55e78b0b78"

 SOLVER_TYPE_NAMES = [s.aliases[0] for s in SOLVER_CONFIGS]
diff --git a/imaginairy/img_processors/control_modes.py b/imaginairy/img_processors/control_modes.py
index 3acfb9f..699ccbd 100644
--- a/imaginairy/img_processors/control_modes.py
+++ b/imaginairy/img_processors/control_modes.py
@@ -139,6 +139,24 @@ def create_pose_map(img_t: "Tensor"):
     return pose_t


+def create_densepose_map(img_t: "Tensor") -> "Tensor":
+    import torch
+
+    from imaginairy.img_processors.densepose import generate_densepose_image
+
+    img_np = generate_densepose_image(img_t)
+
+    img_t = (
+        torch.tensor(img_np, dtype=torch.float)
+        if not isinstance(img_np, torch.Tensor)
+        else img_np.float()
+    )
+    img_t /= 255.0
+    img_t = img_t.permute(2, 0, 1).unsqueeze(0)
+
+    return img_t
+
+
 def make_noise_disk(H: int, W: int, C: int, F: int) -> "np.ndarray":
     import cv2
     import numpy as np
@@ -312,4 +330,5 @@ CONTROL_MODES: Dict[str, FunctionType] = {
     "details": noop,
     "colorize": to_grayscale,
     "qrcode": adaptive_threshold_binarize,
+    "densepose": create_densepose_map,
 }
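For orientation, the CONTROL_MODES entry above is the entry point for the new annotator. Below is a minimal sketch of calling it directly; the image path is hypothetical, and the tensor contract (a float (B, 3, H, W) input in [-1, 1], a float (1, 3, H, W) output in [0, 1]) follows from the asserts and conversions in the new densepose.py further down:

    import numpy as np
    import torch
    from PIL import Image

    from imaginairy.img_processors.control_modes import CONTROL_MODES

    # load a photo and scale it into the [-1, 1] float range the annotator expects
    img = Image.open("person.jpg").convert("RGB")  # hypothetical input image
    img_np = np.asarray(img).astype(np.float32) / 127.5 - 1.0
    img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0)  # (1, 3, H, W)

    # produces the DensePose hint image as a float (1, 3, H, W) tensor in [0, 1]
    control_t = CONTROL_MODES["densepose"](img_t)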
diff --git a/imaginairy/img_processors/densepose.py b/imaginairy/img_processors/densepose.py
new file mode 100644
index 0000000..da0a41b
--- /dev/null
+++ b/imaginairy/img_processors/densepose.py
@@ -0,0 +1,653 @@
+# adapted from https://github.com/Mikubill/sd-webui-controlnet/blob/0b90426254debf78bfc09d88c064d2caf0935282/annotator/densepose/densepose.py
+import logging
+import math
+from enum import IntEnum
+from functools import lru_cache
+from typing import List, Tuple, Union
+
+import cv2
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+from imaginairy import config
+from imaginairy.utils.downloads import get_cached_url_path
+
+logger = logging.getLogger(__name__)
+
+N_PART_LABELS = 24
+
+
+_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]
+IntTupleBox = Tuple[int, int, int, int]
+
+
+def safer_memory(x):
+    # works around contiguity problems on some Mac/AMD setups
+    return np.ascontiguousarray(x.copy()).copy()
+
+
+def pad64(x):
+    return int(np.ceil(float(x) / 64.0) * 64 - x)
+
+
+def resize_image_with_pad_torch(
+    img, resolution, upscale_method="bicubic", mode="constant"
+):
+    B, C, H_raw, W_raw = img.shape
+    k = float(resolution) / float(min(H_raw, W_raw))
+    H_target = int(math.ceil(float(H_raw) * k))
+    W_target = int(math.ceil(float(W_raw) * k))
+
+    if k > 1:
+        img = F.interpolate(
+            img,
+            size=(H_target, W_target),
+            mode=upscale_method,
+            align_corners=False,
+        )
+    else:
+        img = F.interpolate(img, size=(H_target, W_target), mode="area")
+
+    H_pad, W_pad = pad64(H_target), pad64(W_target)
+    img_padded = F.pad(img, (0, W_pad, 0, H_pad), mode=mode)
+
+    def remove_pad(x):
+        # x is (H, W, C) here; crop off the bottom/right padding
+        return safer_memory(x[:H_target, :W_target, ...])
+
+    return img_padded, remove_pad
+
+
+def HWC3(x: np.ndarray) -> np.ndarray:
+    assert x.dtype == np.uint8
+    if x.ndim == 2:
+        x = x[:, :, None]
+    assert x.ndim == 3
+    H, W, C = x.shape
+    assert C == 1 or C == 3 or C == 4
+    if C == 3:
+        return x
+    if C == 1:
+        return np.concatenate([x, x, x], axis=2)
+    if C == 4:
+        color = x[:, :, 0:3].astype(np.float32)
+        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+        y = color * alpha + 255.0 * (1.0 - alpha)
+        y = y.clip(0, 255).astype(np.uint8)
+        return y
+    raise RuntimeError("unreachable")
+
+
+@lru_cache(maxsize=1)
+def get_densepose_model(
+    filename="densepose_r101_fpn_dl.torchscript", base_url=config.DENSEPOSE_REPO_URL
+):
+    # imported for its side effect of registering the torchvision ops
+    # (e.g. torchvision::nms) that the TorchScript model calls
+    import torchvision  # noqa
+
+    url = f"{base_url}/{filename}"
+    torchscript_model_path = get_cached_url_path(url)
+    logger.info(f"Loading densepose model {url} from {torchscript_model_path}")
+    densepose = torch.jit.load(torchscript_model_path, map_location="cpu")
+    return densepose
+
+
+@lru_cache(maxsize=1)
+def get_segment_result_visualizer():
+    return DensePoseMaskedColormapResultsVisualizer(
+        alpha=1,
+        data_extractor=_extract_i_from_iuvarr,
+        segm_extractor=_extract_i_from_iuvarr,
+        val_scale=255.0 / N_PART_LABELS,
+    )
+
+
+def mask_to_bbox(mask_img_t):
+    m = mask_img_t.nonzero()
+    if m.numel() == 0:
+        return None
+    y0 = torch.min(m[:, 0])
+    y1 = torch.max(m[:, 0])
+    x0 = torch.min(m[:, 1])
+    x1 = torch.max(m[:, 1])
+    return x0, y0, x1, y1
+
+
+def pad_bbox(bbox, max_height, max_width, pad=1):
+    x0, y0, x1, y1 = bbox
+    x0 = max(0, x0 - pad)
+    y0 = max(0, y0 - pad)
+    x1 = min(max_width, x1 + pad)
+    y1 = min(max_height, y1 + pad)
+    return x0, y0, x1, y1
+
+
+def square_bbox(bbox, max_height, max_width):
+    """
+    Adjusts the bounding box to make it as close to a square as possible while
+    ensuring it does not exceed the image bounds and still includes
+    the original bounding box contents.
+
+    Args:
+    - bbox: A tuple of (x0, y0, x1, y1) for the original bounding box.
+    - max_height, max_width: The height and width of the image.
+
+    Returns:
+    - A tuple of (x0, y0, x1, y1) for the adjusted bounding box.
+    """
+    x0, y0, x1, y1 = bbox
+    width = x1 - x0
+    height = y1 - y0
+
+    # Determine how much to adjust to make the bounding box square
+    if width > height:
+        diff = width - height
+        half_diff = diff // 2
+        y0 = max(0, y0 - half_diff)
+        y1 = min(max_height, y1 + half_diff + (diff % 2))  # Add 1 if diff is odd
+    elif height > width:
+        diff = height - width
+        half_diff = diff // 2
+        x0 = max(0, x0 - half_diff)
+        x1 = min(max_width, x1 + half_diff + (diff % 2))  # Add 1 if diff is odd
+
+    # Ensure the bounding box is within the image boundaries
+    x0 = max(0, min(x0, max_width - 1))
+    y0 = max(0, min(y0, max_height - 1))
+    x1 = max(0, min(x1, max_width))
+    y1 = max(0, min(y1, max_height))
+
+    return x0, y0, x1, y1
+
+
+def _np_to_t(img_np):
+    img_t = torch.from_numpy(img_np) / 255.0
+    img_t = img_t.permute(2, 0, 1)
+    img_t = img_t.unsqueeze(0)
+    return img_t
+
+
+def generate_densepose_image(
+    img: torch.Tensor,
+    detect_resolution=512,
+    upscale_method="bicubic",
+    cmap="viridis",
+    double_pass=False,
+):
+    assert_tensor_float_11_bchw(img)
+    input_h, input_w = img.shape[-2:]
+    if double_pass:
+        first_densepose_img_np = _generate_densepose_image(
+            img, detect_resolution, upscale_method, cmap, adapt_viridis_bg=False
+        )
+        first_densepose_img_t = _np_to_t(first_densepose_img_np)
+        # convert the densepose image into a mask (every color other than black is part of the mask)
+        densepose_img_mask = first_densepose_img_t[0].sum(dim=0) > 0
+        bbox = mask_to_bbox(densepose_img_mask)
+
+        if bbox is None:
+            densepose_np = first_densepose_img_np
+        else:
+            bbox = pad_bbox(bbox, max_height=input_h, max_width=input_w, pad=10)
+            bbox = square_bbox(bbox, max_height=input_h, max_width=input_w)
+            x0, y0, x1, y1 = bbox
+
+            cropped_img = img[:, :, y0:y1, x0:x1]
+
+            # run detection again on just the detected region for extra detail
+            densepose_np = _generate_densepose_image(
+                cropped_img,
+                detect_resolution,
+                upscale_method,
+                cmap,
+                adapt_viridis_bg=False,
+            )
+            # paste densepose_np back into first_densepose_img_np using the bbox
+            first_densepose_img_np[y0:y1, x0:x1] = densepose_np
+            densepose_np = first_densepose_img_np
+    else:
+        densepose_np = _generate_densepose_image(
+            img, detect_resolution, upscale_method, cmap, adapt_viridis_bg=False
+        )
+
+    if cmap == "viridis":
+        # map zero (background) channel values to the viridis zero color (68, 1, 84)
+        densepose_np[:, :, 0][densepose_np[:, :, 0] == 0] = 68
+        densepose_np[:, :, 1][densepose_np[:, :, 1] == 0] = 1
+        densepose_np[:, :, 2][densepose_np[:, :, 2] == 0] = 84
+
+    return densepose_np
+
+
+def _generate_densepose_image(
+    img: torch.Tensor,
+    detect_resolution=512,
+    upscale_method="bicubic",
+    cmap="viridis",
+    adapt_viridis_bg=True,
+) -> np.ndarray:
+    assert_tensor_float_11_bchw(img)
+    input_h, input_w = img.shape[-2:]
+    img, remove_pad = resize_image_with_pad_torch(
+        img, detect_resolution, upscale_method
+    )
+    img = ((img + 1.0) * 127.5).to(torch.uint8)
+    assert_tensor_uint8_255_bchw(img)
+    H, W = img.shape[-2:]
+    hint_image_canvas = np.zeros([H, W], dtype=np.uint8)
+    hint_image_canvas = np.tile(hint_image_canvas[:, :, np.newaxis], [1, 1, 3])
+    densepose_model = get_densepose_model()
+    pred_boxes, coarse_seg, fine_segm, u, v = densepose_model(img.squeeze(0))
+    densepose_results = list(
+        map(
+            densepose_chart_predictor_output_to_result,
+            pred_boxes,
+            coarse_seg,
+            fine_segm,
+            u,
+            v,
+        )
+    )
+    cmaps = {
+        "viridis": cv2.COLORMAP_VIRIDIS,
+        "parula": cv2.COLORMAP_PARULA,
+        "jet": cv2.COLORMAP_JET,
+    }
+    cv2_cmap = cmaps.get(cmap, cv2.COLORMAP_PARULA)
+    result_visualizer = get_segment_result_visualizer()
+    result_visualizer.mask_visualizer.cmap = cv2_cmap
+    hint_image = result_visualizer.visualize(hint_image_canvas, densepose_results)
+    hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB)
+
+    if cv2_cmap == cv2.COLORMAP_VIRIDIS and adapt_viridis_bg:
+        hint_image[:, :, 0][hint_image[:, :, 0] == 0] = 68
+        hint_image[:, :, 1][hint_image[:, :, 1] == 0] = 1
+        hint_image[:, :, 2][hint_image[:, :, 2] == 0] = 84
+    detected_map = remove_pad(HWC3(hint_image))
+    # if the map is smaller than the input size, scale it up
+    if detected_map.shape[0] < input_h or detected_map.shape[1] < input_w:
+        detected_map = cv2.resize(
+            detected_map, (input_w, input_h), interpolation=cv2.INTER_NEAREST
+        )
+    else:
+        # scale it down
+        detected_map = cv2.resize(
+            detected_map, (input_w, input_h), interpolation=cv2.INTER_AREA
+        )
+    return detected_map
+
+
+def assert_ndarray_uint8_255_hwc(img):
+    # assert img is an ndarray with colors 0-255 and dimensions (H, W, C)
+    assert img.dtype == np.uint8
+    assert img.ndim == 3
+    assert img.shape[2] == 3
+    assert img.max() <= 255
+    assert img.min() >= 0
+
+
+def assert_tensor_uint8_255_bchw(img):
+    # assert img is a PyTorch tensor with colors 0-255 and dimensions (B, C, H, W)
+    assert isinstance(img, torch.Tensor)
+    assert img.dtype == torch.uint8
+    assert img.ndim == 4
+    assert img.shape[1] == 3
+    assert img.max() <= 255
+    assert img.min() >= 0
+
+
+def assert_tensor_float_11_bchw(img):
+    # assert img is a float PyTorch tensor with colors -1 to 1 and dimensions (B, C, H, W)
+    if not isinstance(img, torch.Tensor):
+        msg = f"Input image must be a PyTorch tensor, but got {type(img)}"
+        raise TypeError(msg)
+
+    if img.dtype not in (torch.float32, torch.float64, torch.float16):
+        msg = f"Input image must be a float tensor, but got {img.dtype}"
+        raise ValueError(msg)
+
+    if img.ndim != 4:
+        msg = f"Input image must be 4D (B, C, H, W), but got {img.ndim}D"
+        raise ValueError(msg)
+
+    if img.shape[1] != 3:
+        msg = f"Input image must have 3 channels, but got {img.shape[1]}"
+        raise ValueError(msg)
+    if img.max() > 1 or img.min() < -1:
+        msg = f"Input image must have values in [-1, 1], but got {img.min()} .. {img.max()}"
+        raise ValueError(msg)
+
+
+class BoxMode(IntEnum):
+    """
+    Enum of different ways to represent a box.
+    """
+
+    XYXY_ABS = 0
+    """
+    (x0, y0, x1, y1) in absolute floating point coordinates.
+    The coordinates are in the range [0, width or height].
+    """
+    XYWH_ABS = 1
+    """
+    (x0, y0, w, h) in absolute floating point coordinates.
+    """
+    XYXY_REL = 2
+    """
+    Not yet supported!
+    (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
+    """
+    XYWH_REL = 3
+    """
+    Not yet supported!
+    (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
+    """
+    XYWHA_ABS = 4
+    """
+    (xc, yc, w, h, a) in absolute floating point coordinates.
+    (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
+    """
+
+    @staticmethod
+    def convert(
+        box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode"
+    ) -> _RawBoxType:
+        """
+        Args:
+            box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
+            from_mode, to_mode (BoxMode)
+
+        Returns:
+            The converted box of the same type.
+        """
+        if from_mode == to_mode:
+            return box
+
+        original_type = type(box)
+        is_numpy = isinstance(box, np.ndarray)
+        single_box = isinstance(box, (list, tuple))
+        if single_box:
+            assert len(box) == 4 or len(box) == 5, (
+                "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
+                " where k == 4 or 5"
+            )
+            arr = torch.tensor(box)[None, :]
+        else:
+            # avoid modifying the input box
+            arr = torch.from_numpy(np.asarray(box)).clone() if is_numpy else box.clone()  # type: ignore
+
+        assert to_mode not in [
+            BoxMode.XYXY_REL,
+            BoxMode.XYWH_REL,
+        ], "Relative mode not yet supported!"
+        assert from_mode not in [
+            BoxMode.XYXY_REL,
+            BoxMode.XYWH_REL,
+        ], "Relative mode not yet supported!"
+
+        if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
+            assert (
+                arr.shape[-1] == 5
+            ), "The last dimension of input shape must be 5 for XYWHA format"
+            original_dtype = arr.dtype
+            arr = arr.double()
+
+            w = arr[:, 2]
+            h = arr[:, 3]
+            a = arr[:, 4]
+            c = torch.abs(torch.cos(a * math.pi / 180.0))
+            s = torch.abs(torch.sin(a * math.pi / 180.0))
+            # This basically computes the horizontal bounding rectangle of the rotated box
+            new_w = c * w + s * h
+            new_h = c * h + s * w
+
+            # convert center to top-left corner
+            arr[:, 0] -= new_w / 2.0
+            arr[:, 1] -= new_h / 2.0
+            # bottom-right corner
+            arr[:, 2] = arr[:, 0] + new_w
+            arr[:, 3] = arr[:, 1] + new_h
+
+            arr = arr[:, :4].to(dtype=original_dtype)
+        elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
+            original_dtype = arr.dtype
+            arr = arr.double()
+            arr[:, 0] += arr[:, 2] / 2.0
+            arr[:, 1] += arr[:, 3] / 2.0
+            angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype)
+            arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype)  # type: ignore
+        else:
+            if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
+                arr[:, 2] += arr[:, 0]
+                arr[:, 3] += arr[:, 1]
+            elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
+                arr[:, 2] -= arr[:, 0]
+                arr[:, 3] -= arr[:, 1]
+            else:
+                msg = f"Conversion from BoxMode {from_mode} to {to_mode} is not supported yet"
+                raise NotImplementedError(msg)
+
+        if single_box:
+            return original_type(arr.flatten().tolist())
+        if is_numpy:
+            return arr.numpy()
+        else:
+            return arr
+
+
+class MatrixVisualizer:
+    def __init__(
+        self,
+        inplace=True,
+        cmap=cv2.COLORMAP_PARULA,
+        val_scale=1.0,
+        alpha=0.7,
+        interp_method_matrix=cv2.INTER_LINEAR,
+        interp_method_mask=cv2.INTER_NEAREST,
+    ):
+        self.inplace = inplace
+        self.cmap = cmap
+        self.val_scale = val_scale
+        self.alpha = alpha
+        self.interp_method_matrix = interp_method_matrix
+        self.interp_method_mask = interp_method_mask
+
+    def visualize(self, image_bgr: np.ndarray, mask: np.ndarray, matrix, bbox_xywh):
+        self._check_image(image_bgr)
+        self._check_mask_matrix(mask, matrix)
+        image_target_bgr = image_bgr if self.inplace else image_bgr * 0
+
+        x, y, w, h = (int(v) for v in bbox_xywh)
+        if w <= 0 or h <= 0:
+            return image_bgr
+        mask, matrix = self._resize(mask, matrix, w, h)
+        mask_bg = np.tile((mask == 0)[:, :, np.newaxis], [1, 1, 3])
+        matrix_scaled = matrix.astype(np.float32) * self.val_scale
+        _EPSILON = 1e-6
+        if np.any(matrix_scaled > 255 + _EPSILON):
+            logger.warning(
+                f"Matrix has values > {255 + _EPSILON} after "
+                f"scaling, clipping to [0..255]"
+            )
+        matrix_scaled_8u = matrix_scaled.clip(0, 255).astype(np.uint8)
+        matrix_vis = cv2.applyColorMap(matrix_scaled_8u, self.cmap)
+        matrix_vis[mask_bg] = image_target_bgr[y : y + h, x : x + w, :][mask_bg]
+        image_target_bgr[y : y + h, x : x + w, :] = (
+            image_target_bgr[y : y + h, x : x + w, :] * (1.0 - self.alpha)
+            + matrix_vis * self.alpha
+        )
+        return image_target_bgr.astype(np.uint8)
+
+    def _resize(self, mask, matrix, w, h):
+        if (w != mask.shape[1]) or (h != mask.shape[0]):
+            mask = cv2.resize(mask, (w, h), interpolation=self.interp_method_mask)
+        if (w != matrix.shape[1]) or (h != matrix.shape[0]):
+            matrix = cv2.resize(
+                matrix, (w, h), interpolation=self.interp_method_matrix
+            )
+        return mask, matrix
+
+    def _check_image(self, image_bgr):
+        assert len(image_bgr.shape) == 3
+        assert image_bgr.shape[2] == 3
+        assert image_bgr.dtype == np.uint8
+
+    def _check_mask_matrix(self, mask, matrix):
+        assert len(matrix.shape) == 2
+        assert len(mask.shape) == 2
+        assert mask.dtype == np.uint8
+
+
+class DensePoseMaskedColormapResultsVisualizer:
+    def __init__(
+        self,
+        data_extractor,
+        segm_extractor,
+        inplace=True,
+        cmap=cv2.COLORMAP_PARULA,
+        alpha=0.7,
+        val_scale=1.0,
+    ):
+        self.mask_visualizer = MatrixVisualizer(
+            inplace=inplace, cmap=cmap, val_scale=val_scale, alpha=alpha
+        )
+        self.data_extractor = data_extractor
+        self.segm_extractor = segm_extractor
+
+    def visualize(
+        self,
+        image_bgr: np.ndarray,
+        results,
+    ) -> np.ndarray:
+        for result in results:
+            boxes_xywh, labels, uv = result
+            iuv_array = torch.cat((labels[None].type(torch.float32), uv * 255.0)).type(
+                torch.uint8
+            )
+            self.visualize_iuv_arr(image_bgr, iuv_array.cpu().numpy(), boxes_xywh)
+        return image_bgr
+
+    def visualize_iuv_arr(self, image_bgr, iuv_arr: np.ndarray, bbox_xywh) -> None:
+        matrix = self.data_extractor(iuv_arr)
+        segm = self.segm_extractor(iuv_arr)
+        mask = (segm > 0).astype(np.uint8)
+        self.mask_visualizer.visualize(image_bgr, mask, matrix, bbox_xywh)
+
+
+def _extract_i_from_iuvarr(iuv_arr):
+    return iuv_arr[0, :, :]
+
+
+def _extract_u_from_iuvarr(iuv_arr):
+    return iuv_arr[1, :, :]
+
+
+def _extract_v_from_iuvarr(iuv_arr):
+    return iuv_arr[2, :, :]
+
+
+def make_int_box(box: torch.Tensor) -> IntTupleBox:
+    int_box = [0, 0, 0, 0]
+    int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist())
+    return int_box[0], int_box[1], int_box[2], int_box[3]
+
+
+def densepose_chart_predictor_output_to_result(
+    boxes: torch.Tensor, coarse_segm: torch.Tensor, fine_segm, u, v
+):
+    boxes = boxes.unsqueeze(0)
+    coarse_segm = coarse_segm.unsqueeze(0)
+    fine_segm = fine_segm.unsqueeze(0)
+    u = u.unsqueeze(0)
+    v = v.unsqueeze(0)
+    boxes_xyxy_abs = boxes.clone()
+    boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+    box_xywh = make_int_box(boxes_xywh_abs[0])  # type: ignore
+
+    labels = resample_fine_and_coarse_segm_tensors_to_bbox(
+        fine_segm, coarse_segm, box_xywh
+    ).squeeze(0)
+    uv = resample_uv_tensors_to_bbox(u, v, labels, box_xywh)
+    return box_xywh, labels, uv
+
+
+def resample_fine_and_coarse_segm_tensors_to_bbox(
+    fine_segm: torch.Tensor,
+    coarse_segm: torch.Tensor,
+    box_xywh_abs: IntTupleBox,
+):
+    """
+    Resample fine and coarse segmentation tensors to the given bounding box
+    and derive labels for each pixel of the bounding box.
+
+    Args:
+        fine_segm: float tensor of shape [1, C, Hout, Wout]
+        coarse_segm: float tensor of shape [1, K, Hout, Wout]
+        box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
+            corner coordinates, width (W) and height (H)
+    Returns:
+        Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
+    """
+    x, y, w, h = box_xywh_abs
+    w = max(int(w), 1)
+    h = max(int(h), 1)
+    # coarse segmentation
+    coarse_segm_bbox = F.interpolate(
+        coarse_segm,
+        (h, w),
+        mode="bilinear",
+        align_corners=False,
+    ).argmax(dim=1)
+    # combined coarse and fine segmentation
+    labels = (
+        F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(
+            dim=1
+        )
+        * (coarse_segm_bbox > 0).long()
+    )
+    return labels
+
+
+def resample_uv_tensors_to_bbox(
+    u: torch.Tensor,
+    v: torch.Tensor,
+    labels: torch.Tensor,
+    box_xywh_abs: IntTupleBox,
+) -> torch.Tensor:
+    """
+    Resamples U and V coordinate estimates for the given bounding box.
+
+    Args:
+        u (tensor [1, C, H, W] of float): U coordinates
+        v (tensor [1, C, H, W] of float): V coordinates
+        labels (tensor [H, W] of long): labels obtained by resampling segmentation
+            outputs for the given bounding box
+        box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
+    Returns:
+        Resampled U and V coordinates - a tensor [2, H, W] of float
+    """
+    x, y, w, h = box_xywh_abs
+    w = max(int(w), 1)
+    h = max(int(h), 1)
+    u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
+    v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
+    uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
+    for part_id in range(1, u_bbox.size(1)):
+        uv[0][labels == part_id] = u_bbox[0, part_id][labels == part_id]
+        uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id]
+    return uv
diff --git a/imaginairy/utils/downloads.py b/imaginairy/utils/downloads.py
index d048139..927a007 100644
--- a/imaginairy/utils/downloads.py
+++ b/imaginairy/utils/downloads.py
@@ -91,7 +91,9 @@ def huggingface_cached_path(url: str) -> str:
     dest_path = try_to_load_from_cache(
         repo_id=repo, revision=commit_hash, filename=filepath
     )
-    if not dest_path:
+    from huggingface_hub.file_download import _CACHED_NO_EXIST
+
+    if not dest_path or dest_path == _CACHED_NO_EXIST:
         check_huggingface_url_authorized(url)
         token = HfFolder.get_token()
         logger.info(f"Downloading {url} from huggingface")
diff --git a/tests/data/cuda-tests.csv b/tests/data/cuda-tests.csv
index 7a91904..aa8e0af 100644
--- a/tests/data/cuda-tests.csv
+++ b/tests/data/cuda-tests.csv
@@ -1,3 +1,4 @@
+tests/img_processors/test_control_modes.py::test_control_images[densepose-create_densepose_map]
 tests/img_processors/test_control_modes.py::test_control_images[depth-create_depth_map]
 tests/img_processors/test_control_modes.py::test_control_images[hed-create_hed_edges]
 tests/img_processors/test_control_modes.py::test_control_images[normal-create_normal_map]
diff --git a/tests/expected_output/test_control_images[densepose-create_densepose_map]_.png b/tests/expected_output/test_control_images[densepose-create_densepose_map]_.png
new file mode 100644
index 0000000..1cf4c30
Binary files /dev/null and b/tests/expected_output/test_control_images[densepose-create_densepose_map]_.png differ
diff --git a/tests/test_http_app/test_routes.py b/tests/test_http_app/test_routes.py
index 97306b5..e8f151e 100644
--- a/tests/test_http_app/test_routes.py
+++ b/tests/test_http_app/test_routes.py
@@ -57,12 +57,12 @@ async def test_list_models():
     assert response.status_code == 200

     expected_model_ids = {
+        "miniaturuspotentia",
         "sd15",
         "openjourney-v1",
         "openjourney-v2",
         "openjourney-v4",
         "modern-disney",
-        "redshift-diffusion",
         "sdxl",
         "opendalle11",
     }
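With the CLI choice, the ControlConfig entry, and the CONTROL_MODES wiring above, the mode should be reachable end to end. Assuming imaginairy's existing --control-image/--control-mode flags (the choice list edited at the top of this patch), an invocation would look like `imagine --control-image person.jpg --control-mode densepose "photo of a superhero"`, with the DensePose ControlNet weights and the TorchScript detector downloaded and cached on first use.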