feature: densepose controlnet (#481)
parent
ce37e60b11
commit
df86aa6668
@ -0,0 +1,653 @@
|
||||
# adapted from https://github.com/Mikubill/sd-webui-controlnet/blob/0b90426254debf78bfc09d88c064d2caf0935282/annotator/densepose/densepose.py
|
||||
import logging
|
||||
import math
|
||||
from enum import IntEnum
|
||||
from functools import lru_cache
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from imaginairy import config
|
||||
from imaginairy.utils.downloads import get_cached_url_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
N_PART_LABELS = 24
|
||||
|
||||
|
||||
_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]
|
||||
IntTupleBox = Tuple[int, int, int, int]
|
||||
|
||||
|
||||
def safer_memory(x):
    """Return a freshly allocated, C-contiguous copy of *x*.

    Works around buffer-sharing problems observed on some MAC/AMD setups by
    forcing two copies: one through ``np.ascontiguousarray`` and one explicit.
    """
    contiguous = np.ascontiguousarray(x.copy())
    return contiguous.copy()
|
||||
|
||||
|
||||
def pad64(x):
    """Number of cells needed to round *x* up to the next multiple of 64."""
    rounded_up = math.ceil(float(x) / 64.0) * 64.0
    return int(rounded_up - x)
|
||||
|
||||
|
||||
def resize_image_with_pad_torch(
    img, resolution, upscale_method="bicubic", mode="constant"
):
    """Resize a BCHW image so its short side equals ``resolution``, then pad
    height and width up to the next multiples of 64.

    Args:
        img: tensor (B, C, H, W).
        resolution: target length for the shorter image side.
        upscale_method: F.interpolate mode used when enlarging.
        mode: padding mode passed to ``F.pad``.

    Returns:
        Tuple of the padded tensor and a ``remove_pad`` callback that crops an
        HWC numpy array back to the unpadded (resized) size.
    """
    _, _, raw_h, raw_w = img.shape
    scale = float(resolution) / float(min(raw_h, raw_w))
    resized_h = int(math.ceil(float(raw_h) * scale))
    resized_w = int(math.ceil(float(raw_w) * scale))

    # Enlarging uses the caller-selected filter; shrinking uses area
    # averaging, which avoids aliasing artifacts.
    if scale > 1:
        img = F.interpolate(
            img,
            size=(resized_h, resized_w),
            mode=upscale_method,
            align_corners=False,
        )
    else:
        img = F.interpolate(img, size=(resized_h, resized_w), mode="area")

    # Pad bottom/right so both dimensions are multiples of 64.
    pad_h = int(math.ceil(resized_h / 64.0) * 64 - resized_h)
    pad_w = int(math.ceil(resized_w / 64.0) * 64 - resized_w)
    img_padded = F.pad(img, (0, pad_w, 0, pad_h), mode=mode)

    def remove_pad(x):
        # Crop back to the unpadded size and force a fresh contiguous buffer.
        cropped = x[:resized_h, :resized_w, ...]
        return np.ascontiguousarray(cropped.copy()).copy()

    return img_padded, remove_pad
|
||||
|
||||
|
||||
def HWC3(x: np.ndarray) -> np.ndarray:
    """Normalize a uint8 image to 3-channel HxWx3.

    Grayscale inputs are replicated across channels; RGBA inputs are
    alpha-composited over a white background.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)
    if channels == 3:
        return x
    if channels == 1:
        return np.repeat(x, 3, axis=2)
    # channels == 4: blend onto a white background using the alpha channel.
    rgb = x[:, :, :3].astype(np.float32)
    alpha = x[:, :, 3:].astype(np.float32) / 255.0
    blended = rgb * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_densepose_model(
    filename="densepose_r101_fpn_dl.torchscript", base_url=config.DENSEPOSE_REPO_URL
):
    """Download (if needed) and load the TorchScript DensePose model, cached.

    Args:
        filename: model file name, appended to ``base_url`` to form the URL.
        base_url: base URL of the model repository.

    Returns:
        The TorchScript module, loaded onto the CPU.
    """
    # torchvision must be imported before torch.jit.load so the custom
    # detection ops used by the scripted model are registered.
    import torchvision  # noqa

    # Interpolate the filename into the URL (previously the filename
    # parameter was not used, producing a broken download URL).
    url = f"{base_url}/{filename}"
    torchscript_model_path = get_cached_url_path(url)
    logger.info(f"Loading densepose model {url} from {torchscript_model_path}")
    densepose = torch.jit.load(torchscript_model_path, map_location="cpu")
    return densepose
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_segment_result_visualizer():
    """Build (once) the visualizer that paints part-index segmentation maps.

    Both the value matrix and the mask come from the I (part index) channel;
    values are scaled so the 24 part labels span the 0-255 colormap range.
    """
    label_to_byte_scale = 255.0 / N_PART_LABELS
    return DensePoseMaskedColormapResultsVisualizer(
        data_extractor=_extract_i_from_iuvarr,
        segm_extractor=_extract_i_from_iuvarr,
        val_scale=label_to_byte_scale,
        alpha=1,
    )
|
||||
|
||||
|
||||
def mask_to_bbox(mask_img_t):
    """Bounding box (x0, y0, x1, y1) of the nonzero pixels of a 2D mask.

    Coordinates are inclusive and returned as 0-dim tensors; returns ``None``
    when the mask has no nonzero pixels.
    """
    nonzero_yx = mask_img_t.nonzero()
    if nonzero_yx.numel() == 0:
        return None
    ys = nonzero_yx[:, 0]
    xs = nonzero_yx[:, 1]
    return xs.min(), ys.min(), xs.max(), ys.max()
|
||||
|
||||
|
||||
def pad_bbox(bbox, max_height, max_width, pad=1):
    """Grow *bbox* by *pad* pixels on every side, clamped to the image size."""
    x0, y0, x1, y1 = bbox
    return (
        max(0, x0 - pad),
        max(0, y0 - pad),
        min(max_width, x1 + pad),
        min(max_height, y1 + pad),
    )
|
||||
|
||||
|
||||
def square_bbox(bbox, max_height, max_width):
    """Grow the short side of *bbox* so the box becomes (close to) square.

    The box is expanded symmetrically along its shorter axis — any odd
    remainder goes to the bottom/right — then clamped to the image bounds,
    so the result may still be non-square near the borders.

    Args:
        bbox: (x0, y0, x1, y1) of the original bounding box.
        max_height: image height used for clamping.
        max_width: image width used for clamping.

    Returns:
        (x0, y0, x1, y1) of the adjusted bounding box.
    """
    x0, y0, x1, y1 = bbox
    width = x1 - x0
    height = y1 - y0

    if width != height:
        shortfall = abs(width - height)
        grow_before = shortfall // 2
        grow_after = shortfall - grow_before  # odd remainder lands here
        if width > height:
            y0 = max(0, y0 - grow_before)
            y1 = min(max_height, y1 + grow_after)
        else:
            x0 = max(0, x0 - grow_before)
            x1 = min(max_width, x1 + grow_after)

    # Final clamp so the box stays inside the image boundaries.
    x0 = min(max(x0, 0), max_width - 1)
    y0 = min(max(y0, 0), max_height - 1)
    x1 = min(max(x1, 0), max_width)
    y1 = min(max(y1, 0), max_height)

    return x0, y0, x1, y1
|
||||
|
||||
|
||||
def _np_to_t(img_np):
|
||||
img_t = torch.from_numpy(img_np) / 255.0
|
||||
img_t = img_t.permute(2, 0, 1)
|
||||
img_t = img_t.unsqueeze(0)
|
||||
return img_t
|
||||
|
||||
|
||||
def generate_densepose_image(
    img: torch.Tensor,
    detect_resolution=512,
    upscale_method="bicubic",
    cmap="viridis",
    double_pass=False,
):
    """Produce a DensePose annotation image for a batched image tensor.

    Args:
        img: float tensor (B, 3, H, W) with values in [-1, 1].
        detect_resolution: working resolution for the detector pass.
        upscale_method: interpolation mode used when upscaling the input.
        cmap: colormap name ("viridis", "parula" or "jet").
        double_pass: when True, detect once, crop a padded/squared region
            around the detected pixels, detect again on that crop at full
            detector resolution, and paste the refined result back.

    Returns:
        numpy uint8 array (H, W, 3) matching the input spatial size.
    """
    assert_tensor_float_11_bchw(img)
    input_h, input_w = img.shape[-2:]
    if double_pass:
        first_densepose_img_np = _generate_densepose_image(
            img, detect_resolution, upscale_method, cmap, adapt_viridis_bg=False
        )
        first_densepose_img_t = _np_to_t(first_densepose_img_np)
        # Every non-black pixel of the first pass counts as detected content.
        densepose_img_mask = first_densepose_img_t[0].sum(dim=0) > 0
        bbox = mask_to_bbox(densepose_img_mask)

        if bbox is None:
            # Nothing detected: keep the first-pass output as-is.
            densepose_np = first_densepose_img_np
        else:
            # Pad the detection box a little, then square it so the second
            # pass sees the subject at a consistent aspect ratio.
            bbox = pad_bbox(bbox, max_height=input_h, max_width=input_w, pad=10)
            bbox = square_bbox(bbox, max_height=input_h, max_width=input_w)
            x0, y0, x1, y1 = bbox

            cropped_img = img[:, :, y0:y1, x0:x1]

            densepose_np = _generate_densepose_image(
                cropped_img,
                detect_resolution,
                upscale_method,
                cmap,
                adapt_viridis_bg=False,
            )
            # Paste the refined crop result back into the first-pass image.
            first_densepose_img_np[y0:y1, x0:x1] = densepose_np
            densepose_np = first_densepose_img_np
    else:
        densepose_np = _generate_densepose_image(
            img, detect_resolution, upscale_method, cmap, adapt_viridis_bg=False
        )

    if cmap == "viridis":
        # Replace pure-black background components with the viridis
        # background color (RGB 68, 1, 84).
        densepose_np[:, :, 0][densepose_np[:, :, 0] == 0] = 68
        densepose_np[:, :, 1][densepose_np[:, :, 1] == 0] = 1
        densepose_np[:, :, 2][densepose_np[:, :, 2] == 0] = 84

    return densepose_np
|
||||
|
||||
|
||||
def _generate_densepose_image(
    img: torch.Tensor,
    detect_resolution=512,
    upscale_method="bicubic",
    cmap="viridis",
    adapt_viridis_bg=True,
) -> np.ndarray:
    """Run the DensePose model once and render its results as an RGB image.

    Args:
        img: float tensor (B, 3, H, W) with values in [-1, 1].
        detect_resolution: short-side resolution the detector runs at.
        upscale_method: interpolation mode used when upscaling the input.
        cmap: colormap name ("viridis", "parula" or "jet"; unknown names
            fall back to parula).
        adapt_viridis_bg: when True and the viridis colormap is active,
            recolor black background components to the viridis background.

    Returns:
        numpy uint8 array (H, W, 3) resized back to the input spatial size.
    """
    assert_tensor_float_11_bchw(img)
    input_h, input_w = img.shape[-2:]
    # Resize to detector resolution, padded to multiples of 64.
    img, remove_pad = resize_image_with_pad_torch(
        img, detect_resolution, upscale_method
    )
    # [-1, 1] floats -> [0, 255] uint8 expected by the scripted model.
    img = ((img + 1.0) * 127.5).to(torch.uint8)
    assert_tensor_uint8_255_bchw(img)
    H, W = img.shape[-2:]
    # Black canvas the visualizer draws onto.
    hint_image_canvas = np.zeros([H, W], dtype=np.uint8)
    hint_image_canvas = np.tile(hint_image_canvas[:, :, np.newaxis], [1, 1, 3])
    densepose_model = get_densepose_model()
    # The TorchScript model takes a CHW uint8 image and returns per-detection
    # boxes plus coarse/fine segmentation and UV outputs.
    pred_boxes, coarse_seg, fine_segm, u, v = densepose_model(img.squeeze(0))
    densepose_results = list(
        map(
            densepose_chart_predictor_output_to_result,
            pred_boxes,
            coarse_seg,
            fine_segm,
            u,
            v,
        )
    )
    cmaps = {
        "viridis": cv2.COLORMAP_VIRIDIS,
        "parula": cv2.COLORMAP_PARULA,
        "jet": cv2.COLORMAP_JET,
    }
    cv2_cmap = cmaps.get(cmap, cv2.COLORMAP_PARULA)
    result_visualizer = get_segment_result_visualizer()
    # The visualizer is cached, so set the requested colormap on each call.
    result_visualizer.mask_visualizer.cmap = cv2_cmap
    hint_image = result_visualizer.visualize(hint_image_canvas, densepose_results)
    hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB)

    if cv2_cmap == cv2.COLORMAP_VIRIDIS and adapt_viridis_bg:
        # Replace black background components with the viridis background
        # color (RGB 68, 1, 84).
        hint_image[:, :, 0][hint_image[:, :, 0] == 0] = 68
        hint_image[:, :, 1][hint_image[:, :, 1] == 0] = 1
        hint_image[:, :, 2][hint_image[:, :, 2] == 0] = 84
    # Strip the 64-multiple padding, then resize back to the input size:
    # nearest-neighbor when upscaling (keeps label colors crisp), area
    # averaging when downscaling.
    detected_map = remove_pad(HWC3(hint_image))
    if detected_map.shape[0] < input_h or detected_map.shape[1] < input_w:
        detected_map = cv2.resize(
            detected_map, (input_w, input_h), interpolation=cv2.INTER_NEAREST
        )
    else:
        detected_map = cv2.resize(
            detected_map, (input_w, input_h), interpolation=cv2.INTER_AREA
        )
    return detected_map
|
||||
|
||||
|
||||
def assert_ndarray_uint8_255_hwc(img):
    """Raise AssertionError unless *img* is an HxWx3 uint8 ndarray in [0, 255]."""
    assert img.dtype == np.uint8
    assert img.ndim == 3
    assert img.shape[2] == 3
    assert img.min() >= 0
    assert img.max() <= 255
|
||||
|
||||
|
||||
def assert_tensor_uint8_255_bchw(img):
    """Raise AssertionError unless *img* is a uint8 BxCxHxW tensor with C == 3."""
    assert isinstance(img, torch.Tensor)
    assert img.dtype == torch.uint8
    assert img.ndim == 4
    assert img.shape[1] == 3
    assert img.min() >= 0
    assert img.max() <= 255
|
||||
|
||||
|
||||
def assert_tensor_float_11_bchw(img):
    """Validate that *img* is a float BCHW tensor with 3 channels in [-1, 1].

    Raises:
        TypeError: if *img* is not a torch.Tensor.
        ValueError: on wrong dtype, rank, channel count, or value range.
    """
    if not isinstance(img, torch.Tensor):
        msg = f"Input image must be a PyTorch tensor, but got {type(img)}"
        raise TypeError(msg)

    float_dtypes = (torch.float32, torch.float64, torch.float16)
    if img.dtype not in float_dtypes:
        msg = f"Input image must be a float tensor, but got {img.dtype}"
        raise ValueError(msg)

    if img.ndim != 4:
        msg = f"Input image must be 4D (B, C, H, W), but got {img.ndim}D"
        raise ValueError(msg)

    if img.shape[1] != 3:
        msg = f"Input image must have 3 channels, but got {img.shape[1]}"
        raise ValueError(msg)

    if img.max() > 1 or img.min() < -1:
        msg = f"Input image must have values in [-1, 1], but got {img.min()} .. {img.max()}"
        raise ValueError(msg)
|
||||
|
||||
|
||||
class BoxMode(IntEnum):
    """
    Enum of different ways to represent a box.
    """

    XYXY_ABS = 0
    """
    (x0, y0, x1, y1) in absolute floating points coordinates.
    The coordinates in range [0, width or height].
    """
    XYWH_ABS = 1
    """
    (x0, y0, w, h) in absolute floating points coordinates.
    """
    XYXY_REL = 2
    """
    Not yet supported!
    (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
    """
    XYWH_REL = 3
    """
    Not yet supported!
    (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
    """
    XYWHA_ABS = 4
    """
    (xc, yc, w, h, a) in absolute floating points coordinates.
    (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
    """

    @staticmethod
    def convert(
        box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode"
    ) -> _RawBoxType:
        """
        Convert *box* between box representations without modifying the input.

        Args:
            box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
            from_mode, to_mode (BoxMode)

        Returns:
            The converted box of the same type.

        Raises:
            NotImplementedError: for unsupported mode pairs.
        """
        if from_mode == to_mode:
            return box

        # Remember the input container so the result can be returned in kind.
        original_type = type(box)
        is_numpy = isinstance(box, np.ndarray)
        single_box = isinstance(box, (list, tuple))
        if single_box:
            assert len(box) == 4 or len(box) == 5, (
                "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
                " where k == 4 or 5"
            )
            # Promote the single box to a 1xk tensor for uniform handling.
            arr = torch.tensor(box)[None, :]
        else:
            # avoid modifying the input box
            arr = torch.from_numpy(np.asarray(box)).clone() if is_numpy else box.clone()  # type: ignore

        assert to_mode not in [
            BoxMode.XYXY_REL,
            BoxMode.XYWH_REL,
        ], "Relative mode not yet supported!"
        assert from_mode not in [
            BoxMode.XYXY_REL,
            BoxMode.XYWH_REL,
        ], "Relative mode not yet supported!"

        if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
            assert (
                arr.shape[-1] == 5
            ), "The last dimension of input shape must be 5 for XYWHA format"
            # Compute in double precision, restore the input dtype at the end.
            original_dtype = arr.dtype
            arr = arr.double()

            w = arr[:, 2]
            h = arr[:, 3]
            a = arr[:, 4]
            c = torch.abs(torch.cos(a * math.pi / 180.0))
            s = torch.abs(torch.sin(a * math.pi / 180.0))
            # This basically computes the horizontal bounding rectangle of the rotated box
            new_w = c * w + s * h
            new_h = c * h + s * w

            # convert center to top-left corner
            arr[:, 0] -= new_w / 2.0
            arr[:, 1] -= new_h / 2.0
            # bottom-right corner
            arr[:, 2] = arr[:, 0] + new_w
            arr[:, 3] = arr[:, 1] + new_h

            arr = arr[:, :4].to(dtype=original_dtype)
        elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
            original_dtype = arr.dtype
            arr = arr.double()
            # Shift top-left corner to the box center, then append angle 0.
            arr[:, 0] += arr[:, 2] / 2.0
            arr[:, 1] += arr[:, 3] / 2.0
            angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype)
            arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype)  # type: ignore
        else:
            # The axis-aligned XYXY <-> XYWH conversions are simple offsets.
            if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
                arr[:, 2] += arr[:, 0]
                arr[:, 3] += arr[:, 1]
            elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
                arr[:, 2] -= arr[:, 0]
                arr[:, 3] -= arr[:, 1]
            else:
                msg = f"Conversion from BoxMode {from_mode} to {to_mode} is not supported yet"
                raise NotImplementedError(msg)

        # Return the result in the same container type that was passed in.
        if single_box:
            return original_type(arr.flatten().tolist())
        if is_numpy:
            return arr.numpy()
        else:
            return arr
|
||||
|
||||
|
||||
class MatrixVisualizer:
    """Paints a scalar matrix into an image region through a cv2 colormap.

    Values are scaled by ``val_scale``, clipped to [0, 255], colorized with
    ``cmap`` and alpha-blended into the target image inside ``bbox_xywh``;
    pixels where the mask is zero keep the underlying image content.
    """

    def __init__(
        self,
        inplace=True,
        cmap=cv2.COLORMAP_PARULA,
        val_scale=1.0,
        alpha=0.7,
        interp_method_matrix=cv2.INTER_LINEAR,
        interp_method_mask=cv2.INTER_NEAREST,
    ):
        # inplace: blend into the passed-in image instead of a black canvas.
        self.inplace = inplace
        self.cmap = cmap
        self.val_scale = val_scale
        self.alpha = alpha
        self.interp_method_matrix = interp_method_matrix
        self.interp_method_mask = interp_method_mask

    def visualize(self, image_bgr: np.ndarray, mask: np.ndarray, matrix, bbox_xywh):
        """Render *matrix* (masked by *mask*) into *image_bgr* at *bbox_xywh*.

        Args:
            image_bgr: HxWx3 uint8 target image.
            mask: 2D uint8 mask; zero pixels are left untouched.
            matrix: 2D scalar values to colorize.
            bbox_xywh: (x, y, w, h) region the mask/matrix are resized to.

        Returns:
            The blended image as uint8 (the same array when ``inplace``).
        """
        self._check_image(image_bgr)
        self._check_mask_matrix(mask, matrix)
        image_target_bgr = image_bgr if self.inplace else image_bgr * 0

        x, y, w, h = (int(v) for v in bbox_xywh)
        # A degenerate box has nothing to render into; return input untouched.
        if w <= 0 or h <= 0:
            return image_bgr
        mask, matrix = self._resize(mask, matrix, w, h)
        mask_bg = np.tile((mask == 0)[:, :, np.newaxis], [1, 1, 3])
        matrix_scaled = matrix.astype(np.float32) * self.val_scale
        _EPSILON = 1e-6
        if np.any(matrix_scaled > 255 + _EPSILON):
            logger = logging.getLogger(__name__)
            logger.warning(
                f"Matrix has values > {255 + _EPSILON} after "
                f"scaling, clipping to [0..255]"
            )
        matrix_scaled_8u = matrix_scaled.clip(0, 255).astype(np.uint8)
        matrix_vis = cv2.applyColorMap(matrix_scaled_8u, self.cmap)
        # Background pixels keep the original image content, not the colormap.
        matrix_vis[mask_bg] = image_target_bgr[y : y + h, x : x + w, :][mask_bg]
        image_target_bgr[y : y + h, x : x + w, :] = (
            image_target_bgr[y : y + h, x : x + w, :] * (1.0 - self.alpha)
            + matrix_vis * self.alpha
        )
        return image_target_bgr.astype(np.uint8)

    def _resize(self, mask, matrix, w, h):
        # Resize mask/matrix to the bbox size. The interpolation flag must be
        # passed by keyword: positionally, the third cv2.resize argument is
        # ``dst``, so the configured interpolation was silently ignored
        # (cv2 fell back to its default INTER_LINEAR).
        if (w != mask.shape[1]) or (h != mask.shape[0]):
            mask = cv2.resize(mask, (w, h), interpolation=self.interp_method_mask)
        if (w != matrix.shape[1]) or (h != matrix.shape[0]):
            matrix = cv2.resize(matrix, (w, h), interpolation=self.interp_method_matrix)
        return mask, matrix

    def _check_image(self, image_rgb):
        # Accept only HxWx3 uint8 images.
        assert len(image_rgb.shape) == 3
        assert image_rgb.shape[2] == 3
        assert image_rgb.dtype == np.uint8

    def _check_mask_matrix(self, mask, matrix):
        # Mask and matrix must both be 2D; the mask must be uint8.
        assert len(matrix.shape) == 2
        assert len(mask.shape) == 2
        assert mask.dtype == np.uint8
|
||||
|
||||
|
||||
class DensePoseMaskedColormapResultsVisualizer:
    """Draws DensePose chart results onto an image via a MatrixVisualizer."""

    def __init__(
        self,
        data_extractor,
        segm_extractor,
        inplace=True,
        cmap=cv2.COLORMAP_PARULA,
        alpha=0.7,
        val_scale=1.0,
    ):
        # Extractors pull the value matrix / segmentation plane out of an
        # IUV array (3, H, W).
        self.data_extractor = data_extractor
        self.segm_extractor = segm_extractor
        self.mask_visualizer = MatrixVisualizer(
            inplace=inplace, cmap=cmap, val_scale=val_scale, alpha=alpha
        )

    def visualize(
        self,
        image_bgr: np.ndarray,
        results,
    ) -> np.ndarray:
        """Render every (boxes_xywh, labels, uv) result into *image_bgr*."""
        for boxes_xywh, labels, uv in results:
            # Stack labels (I) with scaled U/V into a uint8 IUV array.
            iuv_array = torch.cat(
                (labels[None].type(torch.float32), uv * 255.0)
            ).type(torch.uint8)
            self.visualize_iuv_arr(image_bgr, iuv_array.cpu().numpy(), boxes_xywh)
        return image_bgr

    def visualize_iuv_arr(self, image_bgr, iuv_arr: np.ndarray, bbox_xywh) -> None:
        """Paint a single IUV array into *image_bgr* at *bbox_xywh*."""
        matrix = self.data_extractor(iuv_arr)
        segm = self.segm_extractor(iuv_arr)
        # Any nonzero segmentation value counts as foreground.
        mask = (segm > 0).astype(np.uint8)
        self.mask_visualizer.visualize(image_bgr, mask, matrix, bbox_xywh)
|
||||
|
||||
|
||||
def _extract_i_from_iuvarr(iuv_arr):
|
||||
return iuv_arr[0, :, :]
|
||||
|
||||
|
||||
def _extract_u_from_iuvarr(iuv_arr):
|
||||
return iuv_arr[1, :, :]
|
||||
|
||||
|
||||
def _extract_v_from_iuvarr(iuv_arr):
|
||||
return iuv_arr[2, :, :]
|
||||
|
||||
|
||||
def make_int_box(box: torch.Tensor) -> IntTupleBox:
    """Convert a 4-element box tensor to a tuple of plain Python ints.

    Float coordinates are truncated toward zero (via ``Tensor.long``).

    Args:
        box: tensor holding 4 coordinates (x0, y0, x1, y1).

    Returns:
        (x0, y0, x1, y1) as ints.
    """
    # Direct unpack replaces the previous mutable-list shuffle.
    x0, y0, x1, y1 = box.long().tolist()
    return x0, y0, x1, y1
|
||||
|
||||
|
||||
def densepose_chart_predictor_output_to_result(
    boxes: torch.Tensor, coarse_segm: torch.Tensor, fine_segm, u, v
):
    """Convert raw DensePose head outputs for one detection into a result.

    Args:
        boxes: XYXY_ABS box tensor for the detection.
        coarse_segm: coarse (fg/bg) segmentation logits.
        fine_segm: fine part segmentation logits.
        u, v: per-part UV coordinate estimates.

    Returns:
        Tuple of (box_xywh as ints, per-pixel part labels, UV tensor [2, H, W]).
    """
    # Add the batch dimension the resampling helpers expect.
    boxes, coarse_segm, fine_segm, u, v = (
        t.unsqueeze(0) for t in (boxes, coarse_segm, fine_segm, u, v)
    )
    boxes_xyxy_abs = boxes.clone()
    boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
    box_xywh = make_int_box(boxes_xywh_abs[0])  # type: ignore

    labels = resample_fine_and_coarse_segm_tensors_to_bbox(
        fine_segm, coarse_segm, box_xywh
    ).squeeze(0)
    uv = resample_uv_tensors_to_bbox(u, v, labels, box_xywh)
    return box_xywh, labels, uv
|
||||
|
||||
|
||||
def resample_fine_and_coarse_segm_tensors_to_bbox(
    fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
):
    """Resample segmentation logits to a bounding box and derive pixel labels.

    Args:
        fine_segm: float tensor [1, C, Hout, Wout] of fine-part logits.
        coarse_segm: float tensor [1, K, Hout, Wout] of coarse logits;
            argmax channel 0 is treated as background.
        box_xywh_abs: (x, y, w, h) integer bounding box; w/h are clamped
            to at least 1.

    Returns:
        Long tensor [1, H, W] with the fine part index per pixel, zeroed
        wherever the coarse segmentation says background.
    """
    _x, _y, w, h = box_xywh_abs
    size = (max(int(h), 1), max(int(w), 1))
    # Foreground wherever the upsampled coarse argmax is a nonzero channel.
    coarse_labels = F.interpolate(
        coarse_segm, size, mode="bilinear", align_corners=False
    ).argmax(dim=1)
    fine_labels = F.interpolate(
        fine_segm, size, mode="bilinear", align_corners=False
    ).argmax(dim=1)
    # Zero out fine labels on background pixels.
    return fine_labels * (coarse_labels > 0).long()
|
||||
|
||||
|
||||
def resample_uv_tensors_to_bbox(
    u: torch.Tensor,
    v: torch.Tensor,
    labels: torch.Tensor,
    box_xywh_abs: IntTupleBox,
) -> torch.Tensor:
    """Resample per-part U/V coordinate estimates to a bounding box.

    Args:
        u: float tensor [1, C, H, W] of U coordinates, one channel per part.
        v: float tensor [1, C, H, W] of V coordinates.
        labels: long tensor [H, W] of per-pixel part indices.
        box_xywh_abs: (x, y, w, h) integer bounding box; w/h are clamped
            to at least 1.

    Returns:
        Float tensor [2, H, W] — U in channel 0, V in channel 1, each pixel
        taken from its own part's channel; background pixels stay zero.
    """
    _x, _y, w, h = box_xywh_abs
    w = max(int(w), 1)
    h = max(int(h), 1)
    u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
    v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
    uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
    # Copy each pixel's U/V from the channel matching its part label.
    # Part 0 is background and keeps the zero initialization.
    for part_id in range(1, u_bbox.size(1)):
        part_pixels = labels == part_id
        uv[0][part_pixels] = u_bbox[0, part_id][part_pixels]
        uv[1][part_pixels] = v_bbox[0, part_id][part_pixels]
    return uv
|
|
Binary file not shown.
After Width: | Height: | Size: 5.6 KiB |
Loading…
Reference in New Issue