2023-02-12 07:42:19 +00:00
|
|
|
"""Functions to create hint images for controlnet."""
|
2023-12-12 06:29:36 +00:00
|
|
|
from typing import TYPE_CHECKING, Callable, Dict, Union
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
import numpy as np
|
|
|
|
from torch import Tensor # noqa
|
2023-02-12 07:42:19 +00:00
|
|
|
|
|
|
|
|
2023-12-12 06:29:36 +00:00
|
|
|
def create_canny_edges(img: "Tensor") -> "Tensor":
    """
    Build a three-channel Canny edge hint from the first image of a batch.

    The input is rescaled with ``(img + 1) / 2`` and clamped to [0, 1] before
    being converted to 8-bit pixels. Only ``img[0]`` is processed. The edge map
    is replicated into three channels, divided by 255 and returned as a float32
    tensor laid out as ``(1, c, h, w)``.
    """
    import cv2
    import numpy as np
    import torch
    from einops import einops

    # Rescale, then switch to channels-last 8-bit pixels for OpenCV.
    unit_img = torch.clamp((img + 1.0) / 2.0, min=0.0, max=1.0)
    hwc = einops.rearrange(unit_img[0], "c h w -> h w c")
    pixels = (255.0 * hwc).cpu().numpy().astype(np.uint8).squeeze()

    smoothed = cv2.GaussianBlur(pixels, (5, 5), 0).astype(np.uint8)
    has_channel_axis = len(smoothed.shape) > 2
    if has_channel_axis:
        # NOTE(review): the tensor is presumably RGB while the BGR conversion
        # code is used here; confirm the channel weighting is intended.
        smoothed = cv2.cvtColor(smoothed, cv2.COLOR_BGR2GRAY)

    # The Otsu threshold becomes the upper hysteresis bound; the lower bound
    # is half of it.
    otsu_mode = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    upper, _ = cv2.threshold(smoothed, thresh=0, maxval=255, type=otsu_mode)
    edges = cv2.Canny(smoothed, threshold1=(upper * 0.5), threshold2=upper)

    # controlnet requires three channels
    edges = edges[:, :, None]
    edges_hwc = np.concatenate([edges, edges, edges], axis=2)
    edges_t = torch.from_numpy(edges_hwc).to(dtype=torch.float32) / 255.0
    edges_chw = einops.rearrange(edges_t, "h w c -> c h w").clone()
    return edges_chw.unsqueeze(0)
|
|
|
|
|
|
|
|
|
2023-12-18 20:13:42 +00:00
|
|
|
def create_depth_map(
    img: "Tensor", model_type="dpt_hybrid_384", max_size=512
) -> "Tensor":
    """
    Create a three-channel, min-max normalized depth hint for ``img``.

    Args:
        img: image batch; ``img.shape[2:]`` is used as the output size.
        model_type: model name forwarded to ``_create_depth_map_raw``.
        max_size: inference size limit forwarded to ``_create_depth_map_raw``.

    Returns:
        The raw depth map repeated three times along dim 0, shifted and scaled
        into [0, 1], given a leading batch axis and bilinearly resized to
        ``img.shape[2:]``. A constant depth map yields all zeros.
    """
    import torch

    orig_size = img.shape[2:]

    depth_pt = _create_depth_map_raw(img, max_size=max_size, model_type=model_type)
    # copy the depth map to the other channels
    depth_pt = torch.cat([depth_pt, depth_pt, depth_pt], dim=0)

    depth_pt -= torch.min(depth_pt)
    # After the shift the minimum is 0, so a zero maximum means the map is
    # constant; skip the division instead of producing 0/0 = NaN.
    peak = torch.max(depth_pt)
    if peak > 0:
        depth_pt /= peak
    depth_pt = depth_pt.unsqueeze(0)

    depth_pt = torch.nn.functional.interpolate(
        depth_pt,
        size=orig_size,
        mode="bilinear",
    )

    return depth_pt
|
|
|
|
|
|
|
|
|
2023-12-18 20:13:42 +00:00
|
|
|
def _create_depth_map_raw(
    img: "Tensor", max_size=512, model_type="dpt_large_384"
) -> "Tensor":
    """
    Run MiDaS depth inference on ``img`` and return the unnormalized result.

    The image is moved to ``midas_device()`` and resized so that its longer
    side (dims 2 and 3) becomes ``max_size`` while keeping the aspect ratio;
    each side is then rounded down to a multiple of 32, but never below 32.

    Args:
        img: 4-d image tensor; dims 2 and 3 are treated as height and width.
        max_size: target length of the longer side before rounding.
        model_type: model name passed to ``MiDaSInference``.

    Returns:
        The first element of the model output.
    """
    import torch

    from imaginairy.modules.midas.api import MiDaSInference, midas_device

    model = MiDaSInference(model_type=model_type).to(midas_device())
    img = img.to(midas_device())

    # calculate new size such that image fits within max_size x max_size
    # but keeps aspect ratio
    height, width = img.shape[2], img.shape[3]
    if height > width:
        new_size = (max_size, int(max_size * width / height))
    else:
        new_size = (int(max_size * height / width), max_size)

    # resize torch image to be multiple of 32. Clamp to at least 32 so that an
    # extreme aspect ratio (or max_size < 32) cannot round a side down to 0,
    # which interpolate rejects.
    target_size = tuple(max(32, side // 32 * 32) for side in new_size)
    img = torch.nn.functional.interpolate(
        img,
        size=target_size,
        mode="bilinear",
        align_corners=False,
    )

    depth_pt = model(img)[0]
    return depth_pt
|
|
|
|
|
|
|
|
|
2023-12-12 06:29:36 +00:00
|
|
|
def create_normal_map(img: "Tensor") -> "Tensor":
    """
    Create a min-max normalized normal-map hint for ``img``.

    The map comes from ``create_normal_map_torch_img`` and is shifted and
    scaled in place so its values span [0, 1]. A constant map yields all
    zeros.
    """
    import torch

    from imaginairy.vendored.imaginairy_normal_map.model import (
        create_normal_map_torch_img,
    )

    normal_img_t = create_normal_map_torch_img(img)
    normal_img_t -= torch.min(normal_img_t)
    # After the shift the minimum is 0, so a zero maximum means the map is
    # constant; skip the division instead of producing 0/0 = NaN.
    peak = torch.max(normal_img_t)
    if peak > 0:
        normal_img_t /= peak

    return normal_img_t
|
|
|
|
|
|
|
|
|
2023-12-12 06:29:36 +00:00
|
|
|
def create_hed_edges(img_t: "Tensor") -> "Tensor":
    """
    Create a three-channel HED boundary hint for ``img_t``.

    The image is moved to ``get_device()``, its channel order is reversed and
    it is passed to ``create_hed_map``. The resulting map is replicated three
    times along dim 0, min-max normalized, quantized to 8-bit levels and
    returned with a leading batch axis (presumably ``(1, 3, H, W)``; confirm
    against ``create_hed_map``). A constant map yields all zeros.
    """
    import torch

    from imaginairy.img_processors.hed_boundary import create_hed_map
    from imaginairy.utils import get_device

    img_t = img_t.to(get_device())
    # rgb to bgr
    img_t = img_t[:, [2, 1, 0], :, :]

    hint_t = create_hed_map(img_t)
    hint_t = hint_t.unsqueeze(0)
    hint_t = torch.cat([hint_t, hint_t, hint_t], dim=0)

    hint_t -= torch.min(hint_t)
    # After the shift the minimum is 0, so a zero maximum means the map is
    # constant; skip the division so NaNs never reach the uint8 cast below.
    peak = torch.max(hint_t)
    if peak > 0:
        hint_t /= peak
    # Round-trip through uint8 to snap values onto 8-bit levels.
    hint_t = (hint_t * 255).clip(0, 255).to(dtype=torch.uint8).float() / 255.0

    hint_t = hint_t.unsqueeze(0)
    return hint_t
|
|
|
|
|
|
|
|
|
2023-12-12 06:29:36 +00:00
|
|
|
def create_pose_map(img_t: "Tensor"):
    """
    Create a body-pose hint for ``img_t``.

    The image is moved to ``get_device()`` and handed to
    ``create_body_pose_img``; the result is divided by 255 before being
    returned.
    """
    from imaginairy.img_processors.openpose import create_body_pose_img
    from imaginairy.utils import get_device

    on_device = img_t.to(get_device())
    pose_img = create_body_pose_img(on_device)
    return pose_img / 255
|
|
|
|
|
|
|
|
|
2023-12-12 06:29:36 +00:00
|
|
|
def make_noise_disk(H: int, W: int, C: int, F: int) -> "np.ndarray":
    """
    Generate an ``H`` x ``W`` field of smooth random noise scaled to [0, 1].

    A coarse uniform-noise grid of ``(H // F + 2, W // F + 2, C)`` samples is
    upscaled with cubic interpolation to ``(H + 2F, W + 2F)``, cropped by ``F``
    on the top and left to the requested size, then min-max normalized.
    """
    import cv2
    import numpy as np

    coarse_shape = ((H // F) + 2, (W // F) + 2, C)
    coarse = np.random.uniform(low=0, high=1, size=coarse_shape)

    padded_wh = (W + 2 * F, H + 2 * F)
    upscaled = cv2.resize(coarse, padded_wh, interpolation=cv2.INTER_CUBIC)

    noise = upscaled[F : F + H, F : F + W]
    noise -= np.min(noise)
    noise /= np.max(noise)

    # Restore the trailing channel axis for single-channel noise (presumably
    # dropped by cv2.resize; confirm).
    return noise[:, :, None] if C == 1 else noise
|
|
|
|
|
|
|
|
|
2023-12-12 06:29:36 +00:00
|
|
|
def shuffle_map_np(img: "np.ndarray", h=None, w=None, f=256) -> "np.ndarray":
    """
    Warp ``img`` through a smooth random coordinate field.

    Two single-channel noise fields from ``make_noise_disk`` are scaled to the
    source width and height and used as the x and y lookup maps for
    ``cv2.remap``. ``h`` and ``w`` set the size of the lookup maps and default
    to the size of ``img``; ``f`` is passed on to ``make_noise_disk``.
    """
    import cv2
    import numpy as np

    src_h, src_w, _channels = img.shape
    out_h = src_h if h is None else h
    out_w = src_w if w is None else w

    # x is drawn before y so the random stream is consumed in the same order.
    map_x = make_noise_disk(out_h, out_w, 1, f) * float(src_w - 1)
    map_y = make_noise_disk(out_h, out_w, 1, f) * float(src_h - 1)
    lookup = np.concatenate([map_x, map_y], axis=2).astype(np.float32)

    return cv2.remap(img, lookup, None, cv2.INTER_LINEAR)
|
|
|
|
|
|
|
|
|
2023-12-12 06:29:36 +00:00
|
|
|
def shuffle_map_torch(tensor: "Tensor", h=None, w=None, f=256) -> "Tensor":
    """
    Apply ``shuffle_map_np`` to every image of a batch.

    Args:
        tensor: batch assumed to be in shape (B, C, H, W).
        h: output height; defaults to the input height.
        w: output width; defaults to the input width.
        f: noise scale forwarded to ``shuffle_map_np``.

    Returns:
        The shuffled batch, rescaled with ``(x + 1) / 2`` and moved back to
        the device of ``tensor``.
    """
    import torch

    # Assuming the input tensor is in shape (B, C, H, W)
    B, C, H, W = tensor.shape
    out_h = H if h is None else h
    out_w = W if w is None else w
    device = tensor.device
    tensor = tensor.cpu()

    # Size the output from the requested (h, w) rather than the input, so that
    # overriding h or w does not fail on a shape mismatch when storing results.
    shuffled_tensor = torch.empty((B, C, out_h, out_w), dtype=tensor.dtype)

    # Iterate over the batch and apply the shuffle_map function to each image
    for b in range(B):
        # Convert the input torch tensor to a numpy array
        img_np = tensor[b].numpy().transpose(1, 2, 0)  # Shape (H, W, C)

        # Call the shuffle_map function with the numpy array as input
        shuffled_np = shuffle_map_np(img_np, h, w, f)
        if shuffled_np.ndim == 2:
            # A 2-d result has no channel axis (single-channel input); restore
            # it so the transpose below is valid.
            shuffled_np = shuffled_np[:, :, None]

        # Convert the shuffled numpy array back to a torch tensor and store it
        shuffled_tensor[b] = torch.from_numpy(
            shuffled_np.transpose(2, 0, 1)
        )  # Shape (C, H, W)

    shuffled_tensor = (shuffled_tensor + 1.0) / 2.0
    return shuffled_tensor.to(device)
|
|
|
|
|
|
|
|
|
2023-12-12 06:29:36 +00:00
|
|
|
def inpaint_prep(mask_image_t: "Tensor", target_image_t: "Tensor") -> "Tensor":
    """
    Merge an inpainting mask and its target image into one hint tensor.

    Wherever ``mask_image_t`` equals -1 the output keeps that -1; everywhere
    else it holds the target pixel rescaled with ``(x + 1) / 2``.

    mask_image_t is a 3-channel torch tensor of shape (B, C, H, W) with pixel
    values in range [-1, 1], where -1 indicates masked areas.
    target_image_t is a 3-channel torch tensor of shape (B, C, H, W) with pixel
    values in range [-1, 1].
    """
    import torch

    # Target goes from [-1, 1] to [0, 1]; masked pixels stay at -1.
    rescaled_target = (target_image_t + 1.0) / 2.0
    is_masked = mask_image_t == -1
    return torch.where(is_masked, mask_image_t, rescaled_target)
|
|
|
|
|
|
|
|
|
2023-12-12 06:29:36 +00:00
|
|
|
def to_grayscale(img: "Tensor") -> "Tensor":
    """
    Convert a (batch_size, 3, height, width) tensor to 3-channel grayscale.

    Channels 0, 1 and 2 are combined with weights 0.2989, 0.5870 and 0.1140,
    the result is copied into three identical channels and rescaled with
    ``(x + 1) / 2``.

    Raises:
        ValueError: if ``img`` is not 4-dimensional or dim 1 is not of size 3.
    """
    if img.dim() != 4:
        raise ValueError("Input should be a 4d tensor")
    if img.size(1) != 3:
        raise ValueError("Input should have 3 channels")

    first = img[:, 0, :, :]
    second = img[:, 1, :, :]
    third = img[:, 2, :, :]
    luma = 0.2989 * first + 0.5870 * second + 0.1140 * third

    # Re-insert the channel axis, then tile it three times.
    luma_3ch = luma.unsqueeze(1).repeat(1, 3, 1, 1)
    return (luma_3ch + 1.0) / 2.0
|
|
|
|
|
|
|
|
|
2023-12-12 06:29:36 +00:00
|
|
|
def noop(img: "Tensor") -> "Tensor":
    """Return ``img`` rescaled with ``(x + 1) / 2`` and otherwise unchanged."""
    rescaled = (img + 1.0) / 2.0
    return rescaled
|
2023-05-05 07:29:43 +00:00
|
|
|
|
|
|
|
|
2023-12-12 06:29:36 +00:00
|
|
|
# Signature of a control-mode preprocessor: it takes either two tensors or a
# single tensor and returns a tensor. The members are written as strings
# because Tensor is only imported under TYPE_CHECKING.
FunctionType = Union[
    "Callable[[Tensor, Tensor], Tensor]",
    "Callable[[Tensor], Tensor]",
]
|
|
|
|
|
2023-12-05 23:34:36 +00:00
|
|
|
|
|
|
|
def adaptive_threshold_binarize(img: "Tensor") -> "Tensor":
    """
    Binarize an image batch with OpenCV's Gaussian adaptive threshold.

    The input must be a (batch, 3, height, width) tensor with values in
    [-1, 1]. If the three channels are already (nearly) equal, the image is
    treated as grayscale and returned rescaled to [0, 1] without thresholding.
    Otherwise each image is converted to grayscale, thresholded, and returned
    as three identical channels with values between 0 and 1 on
    ``get_device()``.

    Raises:
        ValueError: if ``img`` is not 4-d, does not have 3 channels, or has
            values outside [-1, 1].
    """
    import cv2
    import numpy as np
    import torch

    from imaginairy.utils import get_device

    if img.dim() != 4:
        raise ValueError("Input should be a 4d tensor")
    if img.size(1) != 3:
        raise ValueError("Input should have 3 channels")

    in_range = (img >= -1) & (img <= 1)
    if not torch.all(in_range):
        raise ValueError("All tensor values must be between -1 and 1")

    normalized = (img + 1) / 2

    # An image whose channels all match is already grayscale: return it as is.
    first_matches_second = torch.allclose(
        normalized[:, 0, :, :], normalized[:, 1, :, :]
    )
    if first_matches_second and torch.allclose(
        normalized[:, 1, :, :], normalized[:, 2, :, :]
    ):
        return normalized

    # to_grayscale returns three identical channels; one is enough.
    single_channel = to_grayscale(img)[:, 0:1, :, :]
    gray_planes = single_channel.squeeze(1).to("cpu").numpy()

    block_size = 129
    offset = 2
    for index, plane in enumerate(gray_planes):
        as_bytes = (plane * 255).astype(np.uint8)
        gray_planes[index] = cv2.adaptiveThreshold(
            as_bytes,
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            block_size,
            offset,
        )

    # Thresholded planes use a maximum of 255; bring them back to [0, 1].
    gray_planes = gray_planes / 255

    binary = torch.from_numpy(gray_planes).unsqueeze(1).to(get_device()).float()
    return binary.repeat(1, 3, 1, 1)
|
|
|
|
|
|
|
|
|
2023-12-12 06:29:36 +00:00
|
|
|
# Registry mapping each control-mode name to the function that builds its
# hint image.
CONTROL_MODES: Dict[str, FunctionType] = {
    # hints extracted from image structure: edges, depth, normals, pose
    "canny": create_canny_edges,
    "depth": create_depth_map,
    "normal": create_normal_map,
    "hed": create_hed_edges,
    "openpose": create_pose_map,
    # hints derived directly from the pixel values
    "shuffle": shuffle_map_torch,
    "edit": noop,
    "inpaint": inpaint_prep,
    "details": noop,
    "colorize": to_grayscale,
    "qrcode": adaptive_threshold_binarize,
    # not registered: "mlsd" (create_mlsd_edges) and "scribble"
}
|