feature: controlnet
@ -0,0 +1,80 @@
model:
  target: imaginairy.modules.cldm.ControlLDM
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "image"
    cond_stage_key: "txt"
    control_key: "hint"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False

    unet_config:
      target: imaginairy.modules.cldm.ControlledUnetModel
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        legacy: False

    first_stage_config:
      target: imaginairy.modules.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: imaginairy.modules.clip_embedders.FrozenCLIPEmbedder

    control_stage_config:
      target: imaginairy.modules.cldm.ControlNet
      params:
        image_size: 32 # unused
        in_channels: 4
        hint_channels: 3
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False
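For context, configs in this shape are typically loaded with OmegaConf and built through the same instantiate_from_config helper that the cldm module below imports. A minimal sketch, assuming a hypothetical path for the YAML above:

# Sketch: turning a config like the one above into a model instance.
# The YAML path is hypothetical, not part of this diff.
from omegaconf import OmegaConf

from imaginairy.utils import instantiate_from_config

config = OmegaConf.load("configs/control-net-v15.yaml")  # hypothetical path
model = instantiate_from_config(config.model)  # builds ControlLDM and its sub-models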
@ -0,0 +1,157 @@
"""Functions to create hint images for controlnet."""


def create_canny_edges(img):
    import cv2
    import einops
    import numpy as np
    import torch

    img = torch.clamp((img + 1.0) / 2.0, min=0.0, max=1.0)
    img = einops.rearrange(img[0], "c h w -> h w c")
    img = (255.0 * img).cpu().numpy().astype(np.uint8).squeeze()
    blurred = cv2.GaussianBlur(img, (5, 5), 0).astype(np.uint8)

    if len(blurred.shape) > 2:
        blurred = cv2.cvtColor(blurred, cv2.COLOR_BGR2GRAY)

    # pick the upper canny threshold automatically via Otsu's method
    threshold2, _ = cv2.threshold(
        blurred, thresh=0, maxval=255, type=(cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    )
    canny_image = cv2.Canny(
        blurred, threshold1=(threshold2 * 0.5), threshold2=threshold2
    )

    # canny_image = cv2.Canny(blur, 100, 200)
    canny_image = canny_image[:, :, None]
    # controlnet requires three channels
    canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
    canny_image = torch.from_numpy(canny_image).to(dtype=torch.float32) / 255.0
    canny_image = einops.rearrange(canny_image, "h w c -> c h w").clone()
    canny_image = canny_image.unsqueeze(0)

    return canny_image
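A quick way to eyeball the hint this returns is to convert it back to an image. A small sketch, assuming an input tensor in [-1, 1] with shape 1x3xHxW; the stand-in input and preview filename are made up:

# Sketch: visualize the canny hint for debugging.
import torch
from PIL import Image

img_t = torch.rand(1, 3, 512, 512) * 2 - 1  # stand-in image tensor in [-1, 1]
hint = create_canny_edges(img_t)  # 1x3xHxW, values in [0, 1]
hint_np = (hint[0].permute(1, 2, 0).numpy() * 255).astype("uint8")
Image.fromarray(hint_np).save("canny_hint_preview.png")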


def create_depth_map(img):
    import torch

    orig_size = img.shape[2:]

    depth_pt = _create_depth_map_raw(img)
    # copy the depth map to the other channels
    depth_pt = torch.cat([depth_pt, depth_pt, depth_pt], dim=0)

    depth_pt -= torch.min(depth_pt)
    depth_pt /= torch.max(depth_pt)
    depth_pt = depth_pt.unsqueeze(0)
    # depth_pt = depth_pt.cpu().numpy()

    depth_pt = torch.nn.functional.interpolate(
        depth_pt,
        size=orig_size,
        mode="bilinear",
    )

    return depth_pt


def _create_depth_map_raw(img):
    import torch

    from imaginairy.modules.midas.api import MiDaSInference, midas_device

    model = MiDaSInference(model_type="dpt_hybrid").to(midas_device())
    img = img.to(midas_device())
    max_size = 512

    # calculate a new size such that the image fits within 512x512 but keeps its aspect ratio
    if img.shape[2] > img.shape[3]:
        new_size = (max_size, int(max_size * img.shape[3] / img.shape[2]))
    else:
        new_size = (int(max_size * img.shape[2] / img.shape[3]), max_size)

    # resize the torch image so each dimension is a multiple of 32
    img = torch.nn.functional.interpolate(
        img,
        size=(new_size[0] // 32 * 32, new_size[1] // 32 * 32),
        mode="bilinear",
        align_corners=False,
    )

    depth_pt = model(img)[0]  # noqa
    return depth_pt


def create_normal_map(img):
    import cv2
    import numpy as np
    import torch

    depth = _create_depth_map_raw(img)
    depth = depth[0]

    depth_pt = depth.clone()
    depth_pt -= torch.min(depth_pt)
    depth_pt /= torch.max(depth_pt)
    depth_pt = depth_pt.cpu().numpy()

    bg_th = 0.1
    a = np.pi * 2.0
    depth_np = depth.cpu().float().numpy()
    x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
    y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
    z = np.ones_like(x) * a
    x[depth_pt < bg_th] = 0
    y[depth_pt < bg_th] = 0
    normal = np.stack([x, y, z], axis=2)
    normal /= np.sum(normal**2.0, axis=2, keepdims=True) ** 0.5
    normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)

    normal_image = torch.from_numpy(normal_image[:, :, ::-1].copy()).float() / 255.0
    normal_image = normal_image.permute(2, 0, 1).unsqueeze(0)
    return normal_image


def create_hed_edges(img_t):
    import torch

    from imaginairy.img_processors.hed_boundary import create_hed_map
    from imaginairy.utils import get_device

    img_t = img_t.to(get_device())
    # rgb to bgr
    img_t = img_t[:, [2, 1, 0], :, :]

    hint_t = create_hed_map(img_t)
    hint_t = hint_t.unsqueeze(0)
    hint_t = torch.cat([hint_t, hint_t, hint_t], dim=0)

    hint_t -= torch.min(hint_t)
    hint_t /= torch.max(hint_t)
    hint_t = (hint_t * 255).clip(0, 255).to(dtype=torch.uint8).float() / 255.0

    hint_t = hint_t.unsqueeze(0)
    # hint_t = hint_t[:, [2, 0, 1], :, :]
    return hint_t


def create_pose_map(img_t):
    from imaginairy.img_processors.openpose import create_body_pose_img
    from imaginairy.utils import get_device

    img_t = img_t.to(get_device())
    pose_t = create_body_pose_img(img_t) / 255
    # pose_t = pose_t[:, [2, 1, 0], :, :]
    return pose_t


CONTROL_MODES = {
    "canny": create_canny_edges,
    "depth": create_depth_map,
    "normal": create_normal_map,
    "hed": create_hed_edges,
    # "mlsd": create_mlsd_edges,
    "openpose": create_pose_map,
    # "scribble": None,
}
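This dict gives callers one dispatch point for hint creation. A hedged sketch of what a lookup wrapper might look like; make_hint is illustrative, not repo API:

# Sketch: dispatch a control-mode name to its hint-creation function.
def make_hint(control_mode, img_t):
    try:
        hint_fn = CONTROL_MODES[control_mode]
    except KeyError:
        raise ValueError(
            f"Unknown control mode {control_mode!r}; "
            f"expected one of {sorted(CONTROL_MODES)}"
        ) from None
    return hint_fn(img_t)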
@ -0,0 +1,210 @@
from functools import lru_cache

import cv2
import numpy as np
import torch

from imaginairy.model_manager import get_cached_url_path
from imaginairy.utils import get_device


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.netVggOne = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(
                in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(inplace=False),
        )

        self.netVggTwo = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(
                in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(
                in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(inplace=False),
        )

        self.netVggThr = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(
                in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(
                in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(
                in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(inplace=False),
        )

        self.netVggFou = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(
                in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(
                in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(
                in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(inplace=False),
        )

        self.netVggFiv = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(
                in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(
                in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(
                in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1
            ),
            torch.nn.ReLU(inplace=False),
        )

        self.netScoreOne = torch.nn.Conv2d(
            in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0
        )
        self.netScoreTwo = torch.nn.Conv2d(
            in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0
        )
        self.netScoreThr = torch.nn.Conv2d(
            in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0
        )
        self.netScoreFou = torch.nn.Conv2d(
            in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0
        )
        self.netScoreFiv = torch.nn.Conv2d(
            in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0
        )

        self.netCombine = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0
            ),
            torch.nn.Sigmoid(),
        )

    def forward(self, img_t):
        img_t = (img_t + 1) * 127.5
        img_t = img_t - torch.tensor(
            data=[104.00698793, 116.66876762, 122.67891434],
            dtype=img_t.dtype,
            device=img_t.device,
        ).view(1, 3, 1, 1)

        ten_vgg_one = self.netVggOne(img_t)
        ten_vgg_two = self.netVggTwo(ten_vgg_one)
        ten_vgg_thr = self.netVggThr(ten_vgg_two)
        ten_vgg_fou = self.netVggFou(ten_vgg_thr)
        ten_vgg_fiv = self.netVggFiv(ten_vgg_fou)

        ten_score_one = self.netScoreOne(ten_vgg_one)
        ten_score_two = self.netScoreTwo(ten_vgg_two)
        ten_score_thr = self.netScoreThr(ten_vgg_thr)
        ten_score_fou = self.netScoreFou(ten_vgg_fou)
        ten_score_fiv = self.netScoreFiv(ten_vgg_fiv)

        ten_score_one = torch.nn.functional.interpolate(
            input=ten_score_one,
            size=(img_t.shape[2], img_t.shape[3]),
            mode="bilinear",
            align_corners=False,
        )
        ten_score_two = torch.nn.functional.interpolate(
            input=ten_score_two,
            size=(img_t.shape[2], img_t.shape[3]),
            mode="bilinear",
            align_corners=False,
        )
        ten_score_thr = torch.nn.functional.interpolate(
            input=ten_score_thr,
            size=(img_t.shape[2], img_t.shape[3]),
            mode="bilinear",
            align_corners=False,
        )
        ten_score_fou = torch.nn.functional.interpolate(
            input=ten_score_fou,
            size=(img_t.shape[2], img_t.shape[3]),
            mode="bilinear",
            align_corners=False,
        )
        ten_score_fiv = torch.nn.functional.interpolate(
            input=ten_score_fiv,
            size=(img_t.shape[2], img_t.shape[3]),
            mode="bilinear",
            align_corners=False,
        )

        return self.netCombine(
            torch.cat(
                [
                    ten_score_one,
                    ten_score_two,
                    ten_score_thr,
                    ten_score_fou,
                    ten_score_fiv,
                ],
                1,
            )
        )


@lru_cache(maxsize=1)
def hed_model():
    model = Network().to(get_device()).eval()
    model_path = get_cached_url_path(
        "https://huggingface.co/lllyasviel/ControlNet/resolve/38a62cbf79862c1bac73405ec8dc46133aee3e36/annotator/ckpts/network-bsds500.pth"
    )
    state_dict = torch.load(model_path, map_location="cpu")
    state_dict = {k.replace("module", "net"): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    return model


def create_hed_map(img_t):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = hed_model().to(device)
    img_t = img_t.to(device)
    with torch.no_grad():
        edge = model(img_t)[0]
    return edge[0]


def nms(x, t, s):
    """Make scribbles."""
    x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
    f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
    f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
    f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)

    y = np.zeros_like(x)

    for f in [f1, f2, f3, f4]:
        np.putmask(y, cv2.dilate(x, kernel=f) == x, x)

    z = np.zeros_like(y, dtype=np.uint8)
    z[y > t] = 255
    return z
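ControlNet's scribble-style preprocessing is commonly derived by running nms over a HED map. A sketch under that assumption; the threshold and sigma values are illustrative, not tuned:

# Sketch: turn a HED edge map into a scribble-like binary map via nms().
import numpy as np
import torch

def hed_to_scribble(img_t):
    edge = create_hed_map(img_t)  # HxW tensor in [0, 1]
    edge_np = (edge.cpu().numpy() * 255).astype(np.float32)
    scribble = nms(edge_np, t=127, s=3.0)  # uint8 array of 0s and 255s
    return torch.from_numpy(scribble).float() / 255.0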
@ -0,0 +1,836 @@
import math
from collections import OrderedDict
from functools import lru_cache

import cv2
import matplotlib
import numpy as np
import torch
from scipy.ndimage import gaussian_filter
from torch import nn

from imaginairy.img_utils import torch_image_to_openvcv_img
from imaginairy.model_manager import get_cached_url_path
from imaginairy.utils import get_device


def pad_right_down_corner(img, stride, padValue):
    h = img.shape[0]
    w = img.shape[1]

    pad = 4 * [None]
    pad[0] = 0  # up
    pad[1] = 0  # left
    pad[2] = 0 if (h % stride == 0) else stride - (h % stride)  # down
    pad[3] = 0 if (w % stride == 0) else stride - (w % stride)  # right

    img_padded = img
    pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
    img_padded = np.concatenate((pad_up, img_padded), axis=0)
    pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
    img_padded = np.concatenate((pad_left, img_padded), axis=1)
    pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
    img_padded = np.concatenate((img_padded, pad_down), axis=0)
    pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
    img_padded = np.concatenate((img_padded, pad_right), axis=1)

    return img_padded, pad
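As a concrete check of the padding arithmetic: with stride 8, a 367x481 image should pad down by 1 and right by 7:

# Illustrative check of the padding helper: a 367x481x3 image with stride 8
# comes back padded to 368x488, with pad = [0, 0, 1, 7] (up, left, down, right).
import numpy as np

img = np.zeros((367, 481, 3), dtype=np.float32)
padded, pad = pad_right_down_corner(img, stride=8, padValue=128)
assert padded.shape[:2] == (368, 488)
assert pad == [0, 0, 1, 7]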


def transfer(model, model_weights):
    # map the caffe weight names onto the matching pytorch layer names
    transfered_model_weights = {}
    for weights_name in model.state_dict().keys():
        transfered_model_weights[weights_name] = model_weights[
            ".".join(weights_name.split(".")[1:])
        ]
    return transfered_model_weights


# draw the body keypoints and limbs
def draw_bodypose(canvas, candidate, subset):
    stickwidth = 4
    limbSeq = [
        [2, 3],
        [2, 6],
        [3, 4],
        [4, 5],
        [6, 7],
        [7, 8],
        [2, 9],
        [9, 10],
        [10, 11],
        [2, 12],
        [12, 13],
        [13, 14],
        [2, 1],
        [1, 15],
        [15, 17],
        [1, 16],
        [16, 18],
        [3, 17],
        [6, 18],
    ]

    colors = [
        [255, 0, 0],
        [255, 85, 0],
        [255, 170, 0],
        [255, 255, 0],
        [170, 255, 0],
        [85, 255, 0],
        [0, 255, 0],
        [0, 255, 85],
        [0, 255, 170],
        [0, 255, 255],
        [0, 170, 255],
        [0, 85, 255],
        [0, 0, 255],
        [85, 0, 255],
        [170, 0, 255],
        [255, 0, 255],
        [255, 0, 170],
        [255, 0, 85],
    ]
    for i in range(18):
        for n in range(len(subset)):  # noqa
            index = int(subset[n][i])
            if index == -1:
                continue
            x, y = candidate[index][0:2]
            cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
    for i in range(17):
        for n in range(len(subset)):  # noqa
            index = subset[n][np.array(limbSeq[i]) - 1]
            if -1 in index:
                continue
            cur_canvas = canvas.copy()
            Y = candidate[index.astype(int), 0]
            X = candidate[index.astype(int), 1]
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly(
                (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1
            )
            cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
            canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
    # plt.imsave("preview.jpg", canvas[:, :, [2, 1, 0]])
    # plt.imshow(canvas[:, :, [2, 1, 0]])
    return canvas


# hand images drawn by opencv alone don't look great
def draw_handpose(canvas, all_hand_peaks, show_number=False):
    edges = [
        [0, 1],
        [1, 2],
        [2, 3],
        [3, 4],
        [0, 5],
        [5, 6],
        [6, 7],
        [7, 8],
        [0, 9],
        [9, 10],
        [10, 11],
        [11, 12],
        [0, 13],
        [13, 14],
        [14, 15],
        [15, 16],
        [0, 17],
        [17, 18],
        [18, 19],
        [19, 20],
    ]

    for peaks in all_hand_peaks:
        for ie, e in enumerate(edges):
            if np.sum(np.all(peaks[e], axis=1) == 0) == 0:
                x1, y1 = peaks[e[0]]
                x2, y2 = peaks[e[1]]
                cv2.line(
                    canvas,
                    (x1, y1),
                    (x2, y2),
                    matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0])
                    * 255,
                    thickness=2,
                )

        for i, keypoint in enumerate(peaks):
            x, y = keypoint
            cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
            if show_number:
                cv2.putText(
                    canvas,
                    str(i),
                    (x, y),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.3,
                    (0, 0, 0),
                    lineType=cv2.LINE_AA,
                )
    return canvas


# detect hands according to the body pose keypoints
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
def handDetect(candidate, subset, oriImg):
    # right hand: wrist 4, elbow 3, shoulder 2
    # left hand: wrist 7, elbow 6, shoulder 5
    ratioWristElbow = 0.33
    detect_result = []
    image_height, image_width = oriImg.shape[0:2]
    for person in subset.astype(int):
        # skip a hand if any of its three keypoints is missing
        has_left = np.sum(person[[5, 6, 7]] == -1) == 0
        has_right = np.sum(person[[2, 3, 4]] == -1) == 0
        if not (has_left or has_right):
            continue
        hands = []
        # left hand
        if has_left:
            left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
            x1, y1 = candidate[left_shoulder_index][:2]
            x2, y2 = candidate[left_elbow_index][:2]
            x3, y3 = candidate[left_wrist_index][:2]
            hands.append([x1, y1, x2, y2, x3, y3, True])
        # right hand
        if has_right:
            right_shoulder_index, right_elbow_index, right_wrist_index = person[
                [2, 3, 4]
            ]
            x1, y1 = candidate[right_shoulder_index][:2]
            x2, y2 = candidate[right_elbow_index][:2]
            x3, y3 = candidate[right_wrist_index][:2]
            hands.append([x1, y1, x2, y2, x3, y3, False])

        for x1, y1, x2, y2, x3, y3, is_left in hands:
            # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
            # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
            # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
            # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
            # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
            # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
            x = x3 + ratioWristElbow * (x3 - x2)
            y = y3 + ratioWristElbow * (y3 - y2)
            distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
            distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
            width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
            # x-y refers to the center --> offset to the topLeft point
            # handRectangle.x -= handRectangle.width / 2.f;
            # handRectangle.y -= handRectangle.height / 2.f;
            x -= width / 2
            y -= width / 2  # width = height
            # clip the box so it doesn't overflow the image
            x = max(x, 0)
            y = max(y, 0)

            width1 = width
            width2 = width
            if x + width > image_width:
                width1 = image_width - x
            if y + width > image_height:
                width2 = image_height - y
            width = min(width1, width2)
            # discard hand boxes smaller than 20 pixels
            if width >= 20:
                detect_result.append([int(x), int(y), int(width), is_left])

    # return value: [[x, y, w, True if left hand else False]].
    # width=height since the network requires a square input.
    # x, y is the coordinate of the top left corner
    return detect_result


# get the max index of a 2d array
def npmax(array):
    arrayindex = array.argmax(1)
    arrayvalue = array.max(1)
    i = arrayvalue.argmax()
    j = arrayindex[i]
    return i, j


def make_layers(block, no_relu_layers):
    layers = []
    for layer_name, v in block.items():
        if "pool" in layer_name:
            layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
            layers.append((layer_name, layer))
        else:
            conv2d = nn.Conv2d(
                in_channels=v[0],
                out_channels=v[1],
                kernel_size=v[2],
                stride=v[3],
                padding=v[4],
            )
            layers.append((layer_name, conv2d))
            if layer_name not in no_relu_layers:
                layers.append(("relu_" + layer_name, nn.ReLU(inplace=True)))

    return nn.Sequential(OrderedDict(layers))
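make_layers turns these name-to-params blocks into an nn.Sequential, inserting a ReLU after every conv whose name isn't in no_relu_layers. A tiny made-up example (the layer names and shapes here are illustrative only):

# Sketch: build a small block with make_layers. Conv entries are
# [in_ch, out_ch, kernel, stride, padding]; pool entries are [kernel, stride, padding].
from collections import OrderedDict

tiny_block = OrderedDict(
    [
        ("conv_demo_1", [3, 16, 3, 1, 1]),
        ("pool_demo_1", [2, 2, 0]),
        ("conv_demo_2", [16, 1, 1, 1, 0]),  # listed in no_relu_layers -> no ReLU
    ]
)
net = make_layers(tiny_block, no_relu_layers=["conv_demo_2"])
# net is an nn.Sequential: conv -> relu -> pool -> conv (no trailing relu)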


class bodypose_model(nn.Module):
    def __init__(self):
        super().__init__()

        # these layers have no relu layer
        no_relu_layers = [
            "conv5_5_CPM_L1",
            "conv5_5_CPM_L2",
            "Mconv7_stage2_L1",
            "Mconv7_stage2_L2",
            "Mconv7_stage3_L1",
            "Mconv7_stage3_L2",
            "Mconv7_stage4_L1",
            "Mconv7_stage4_L2",
            "Mconv7_stage5_L1",
            "Mconv7_stage5_L2",
            "Mconv7_stage6_L1",
            "Mconv7_stage6_L2",
        ]
        blocks = {}
        block0 = OrderedDict(
            [
                ("conv1_1", [3, 64, 3, 1, 1]),
                ("conv1_2", [64, 64, 3, 1, 1]),
                ("pool1_stage1", [2, 2, 0]),
                ("conv2_1", [64, 128, 3, 1, 1]),
                ("conv2_2", [128, 128, 3, 1, 1]),
                ("pool2_stage1", [2, 2, 0]),
                ("conv3_1", [128, 256, 3, 1, 1]),
                ("conv3_2", [256, 256, 3, 1, 1]),
                ("conv3_3", [256, 256, 3, 1, 1]),
                ("conv3_4", [256, 256, 3, 1, 1]),
                ("pool3_stage1", [2, 2, 0]),
                ("conv4_1", [256, 512, 3, 1, 1]),
                ("conv4_2", [512, 512, 3, 1, 1]),
                ("conv4_3_CPM", [512, 256, 3, 1, 1]),
                ("conv4_4_CPM", [256, 128, 3, 1, 1]),
            ]
        )

        # Stage 1
        block1_1 = OrderedDict(
            [
                ("conv5_1_CPM_L1", [128, 128, 3, 1, 1]),
                ("conv5_2_CPM_L1", [128, 128, 3, 1, 1]),
                ("conv5_3_CPM_L1", [128, 128, 3, 1, 1]),
                ("conv5_4_CPM_L1", [128, 512, 1, 1, 0]),
                ("conv5_5_CPM_L1", [512, 38, 1, 1, 0]),
            ]
        )

        block1_2 = OrderedDict(
            [
                ("conv5_1_CPM_L2", [128, 128, 3, 1, 1]),
                ("conv5_2_CPM_L2", [128, 128, 3, 1, 1]),
                ("conv5_3_CPM_L2", [128, 128, 3, 1, 1]),
                ("conv5_4_CPM_L2", [128, 512, 1, 1, 0]),
                ("conv5_5_CPM_L2", [512, 19, 1, 1, 0]),
            ]
        )
        blocks["block1_1"] = block1_1
        blocks["block1_2"] = block1_2

        self.model0 = make_layers(block0, no_relu_layers)

        # Stages 2 - 6
        for i in range(2, 7):
            blocks[f"block{i}_1"] = OrderedDict(
                [
                    (f"Mconv1_stage{i}_L1", [185, 128, 7, 1, 3]),
                    (f"Mconv2_stage{i}_L1", [128, 128, 7, 1, 3]),
                    (f"Mconv3_stage{i}_L1", [128, 128, 7, 1, 3]),
                    (f"Mconv4_stage{i}_L1", [128, 128, 7, 1, 3]),
                    (f"Mconv5_stage{i}_L1", [128, 128, 7, 1, 3]),
                    (f"Mconv6_stage{i}_L1", [128, 128, 1, 1, 0]),
                    (f"Mconv7_stage{i}_L1", [128, 38, 1, 1, 0]),
                ]
            )

            blocks[f"block{i}_2"] = OrderedDict(
                [
                    (f"Mconv1_stage{i}_L2", [185, 128, 7, 1, 3]),
                    (f"Mconv2_stage{i}_L2", [128, 128, 7, 1, 3]),
                    (f"Mconv3_stage{i}_L2", [128, 128, 7, 1, 3]),
                    (f"Mconv4_stage{i}_L2", [128, 128, 7, 1, 3]),
                    (f"Mconv5_stage{i}_L2", [128, 128, 7, 1, 3]),
                    (f"Mconv6_stage{i}_L2", [128, 128, 1, 1, 0]),
                    (f"Mconv7_stage{i}_L2", [128, 19, 1, 1, 0]),
                ]
            )

        for k in blocks.keys():
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        self.model1_1 = blocks["block1_1"]
        self.model2_1 = blocks["block2_1"]
        self.model3_1 = blocks["block3_1"]
        self.model4_1 = blocks["block4_1"]
        self.model5_1 = blocks["block5_1"]
        self.model6_1 = blocks["block6_1"]

        self.model1_2 = blocks["block1_2"]
        self.model2_2 = blocks["block2_2"]
        self.model3_2 = blocks["block3_2"]
        self.model4_2 = blocks["block4_2"]
        self.model5_2 = blocks["block5_2"]
        self.model6_2 = blocks["block6_2"]

    def forward(self, x):
        out1 = self.model0(x)

        out1_1 = self.model1_1(out1)
        out1_2 = self.model1_2(out1)
        out2 = torch.cat([out1_1, out1_2, out1], 1)

        out2_1 = self.model2_1(out2)
        out2_2 = self.model2_2(out2)
        out3 = torch.cat([out2_1, out2_2, out1], 1)

        out3_1 = self.model3_1(out3)
        out3_2 = self.model3_2(out3)
        out4 = torch.cat([out3_1, out3_2, out1], 1)

        out4_1 = self.model4_1(out4)
        out4_2 = self.model4_2(out4)
        out5 = torch.cat([out4_1, out4_2, out1], 1)

        out5_1 = self.model5_1(out5)
        out5_2 = self.model5_2(out5)
        out6 = torch.cat([out5_1, out5_2, out1], 1)

        out6_1 = self.model6_1(out6)
        out6_2 = self.model6_2(out6)

        return out6_1, out6_2


class handpose_model(nn.Module):
    def __init__(self):
        super().__init__()

        # these layers have no relu layer
        no_relu_layers = [
            "conv6_2_CPM",
            "Mconv7_stage2",
            "Mconv7_stage3",
            "Mconv7_stage4",
            "Mconv7_stage5",
            "Mconv7_stage6",
        ]
        # stage 1
        block1_0 = OrderedDict(
            [
                ("conv1_1", [3, 64, 3, 1, 1]),
                ("conv1_2", [64, 64, 3, 1, 1]),
                ("pool1_stage1", [2, 2, 0]),
                ("conv2_1", [64, 128, 3, 1, 1]),
                ("conv2_2", [128, 128, 3, 1, 1]),
                ("pool2_stage1", [2, 2, 0]),
                ("conv3_1", [128, 256, 3, 1, 1]),
                ("conv3_2", [256, 256, 3, 1, 1]),
                ("conv3_3", [256, 256, 3, 1, 1]),
                ("conv3_4", [256, 256, 3, 1, 1]),
                ("pool3_stage1", [2, 2, 0]),
                ("conv4_1", [256, 512, 3, 1, 1]),
                ("conv4_2", [512, 512, 3, 1, 1]),
                ("conv4_3", [512, 512, 3, 1, 1]),
                ("conv4_4", [512, 512, 3, 1, 1]),
                ("conv5_1", [512, 512, 3, 1, 1]),
                ("conv5_2", [512, 512, 3, 1, 1]),
                ("conv5_3_CPM", [512, 128, 3, 1, 1]),
            ]
        )

        block1_1 = OrderedDict(
            [("conv6_1_CPM", [128, 512, 1, 1, 0]), ("conv6_2_CPM", [512, 22, 1, 1, 0])]
        )

        blocks = {}
        blocks["block1_0"] = block1_0
        blocks["block1_1"] = block1_1

        # stage 2-6
        for i in range(2, 7):
            blocks[f"block{i}"] = OrderedDict(
                [
                    (f"Mconv1_stage{i}", [150, 128, 7, 1, 3]),
                    (f"Mconv2_stage{i}", [128, 128, 7, 1, 3]),
                    (f"Mconv3_stage{i}", [128, 128, 7, 1, 3]),
                    (f"Mconv4_stage{i}", [128, 128, 7, 1, 3]),
                    (f"Mconv5_stage{i}", [128, 128, 7, 1, 3]),
                    (f"Mconv6_stage{i}", [128, 128, 1, 1, 0]),
                    (f"Mconv7_stage{i}", [128, 22, 1, 1, 0]),
                ]
            )

        for k in blocks.keys():
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        self.model1_0 = blocks["block1_0"]
        self.model1_1 = blocks["block1_1"]
        self.model2 = blocks["block2"]
        self.model3 = blocks["block3"]
        self.model4 = blocks["block4"]
        self.model5 = blocks["block5"]
        self.model6 = blocks["block6"]

    def forward(self, x):
        out1_0 = self.model1_0(x)
        out1_1 = self.model1_1(out1_0)
        concat_stage2 = torch.cat([out1_1, out1_0], 1)
        out_stage2 = self.model2(concat_stage2)
        concat_stage3 = torch.cat([out_stage2, out1_0], 1)
        out_stage3 = self.model3(concat_stage3)
        concat_stage4 = torch.cat([out_stage3, out1_0], 1)
        out_stage4 = self.model4(concat_stage4)
        concat_stage5 = torch.cat([out_stage4, out1_0], 1)
        out_stage5 = self.model5(concat_stage5)
        concat_stage6 = torch.cat([out_stage5, out1_0], 1)
        out_stage6 = self.model6(concat_stage6)
        return out_stage6


@lru_cache(maxsize=1)
def openpose_model():
    model = bodypose_model()
    weights_url = "https://huggingface.co/lllyasviel/ControlNet/resolve/38a62cbf79862c1bac73405ec8dc46133aee3e36/annotator/ckpts/body_pose_model.pth"
    model_path = get_cached_url_path(weights_url)
    model_dict = transfer(model, torch.load(model_path, map_location="cpu"))
    model.load_state_dict(model_dict)
    model.to(get_device())
    model.eval()
    return model


def create_body_pose_img(original_img_t):
    candidate, subset = create_body_pose(original_img_t)
    canvas = np.zeros((original_img_t.shape[2], original_img_t.shape[3], 3))
    canvas = draw_bodypose(canvas, candidate, subset)
    canvas = torch.from_numpy(canvas).to(dtype=torch.float32)
    # canvas = canvas.unsqueeze(0)
    canvas = canvas.permute(2, 0, 1).unsqueeze(0)
    return canvas


def create_body_pose(original_img_t):
    original_img = torch_image_to_openvcv_img(original_img_t)

    model = openpose_model()
    # scale_search = [0.5, 1.0, 1.5, 2.0]
    scale_search = [0.5]
    boxsize = 368
    stride = 8
    padValue = 128
    thre1 = 0.1
    thre2 = 0.05
    multiplier = [x * boxsize / original_img.shape[0] for x in scale_search]
    heatmap_avg = np.zeros((original_img.shape[0], original_img.shape[1], 19))
    paf_avg = np.zeros((original_img.shape[0], original_img.shape[1], 38))

    for m, scale in enumerate(multiplier):
        imageToTest = cv2.resize(
            original_img, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC
        )
        imageToTest_padded, pad = pad_right_down_corner(imageToTest, stride, padValue)
        im = (
            np.transpose(
                np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)
            )
            / 256
            - 0.5
        )
        im = np.ascontiguousarray(im)

        data = torch.from_numpy(im).float()
        data = data.to(get_device())

        # data = data.permute([2, 0, 1]).unsqueeze(0).float()
        with torch.no_grad():
            Mconv7_stage6_L1, Mconv7_stage6_L2 = model(data)
        Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
        Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()

        # extract outputs, resize, and remove padding
        # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps
        heatmap = np.transpose(
            np.squeeze(Mconv7_stage6_L2), (1, 2, 0)
        )  # output 1 is heatmaps
        heatmap = cv2.resize(
            heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC
        )
        heatmap = heatmap[
            : imageToTest_padded.shape[0] - pad[2],
            : imageToTest_padded.shape[1] - pad[3],
            :,
        ]
        heatmap = cv2.resize(
            heatmap,
            (original_img.shape[1], original_img.shape[0]),
            interpolation=cv2.INTER_CUBIC,
        )

        # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
        paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0))  # output 0 is PAFs
        paf = cv2.resize(
            paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC
        )
        paf = paf[
            : imageToTest_padded.shape[0] - pad[2],
            : imageToTest_padded.shape[1] - pad[3],
            :,
        ]
        paf = cv2.resize(
            paf,
            (original_img.shape[1], original_img.shape[0]),
            interpolation=cv2.INTER_CUBIC,
        )

        # average the outputs over all tested scales
        heatmap_avg = heatmap_avg + heatmap / len(multiplier)
        paf_avg = paf_avg + paf / len(multiplier)

    all_peaks = []
    peak_counter = 0

    for part in range(18):
        map_ori = heatmap_avg[:, :, part]
        one_heatmap = gaussian_filter(map_ori, sigma=3)

        map_left = np.zeros(one_heatmap.shape)
        map_left[1:, :] = one_heatmap[:-1, :]
        map_right = np.zeros(one_heatmap.shape)
        map_right[:-1, :] = one_heatmap[1:, :]
        map_up = np.zeros(one_heatmap.shape)
        map_up[:, 1:] = one_heatmap[:, :-1]
        map_down = np.zeros(one_heatmap.shape)
        map_down[:, :-1] = one_heatmap[:, 1:]

        peaks_binary = np.logical_and.reduce(
            (
                one_heatmap >= map_left,
                one_heatmap >= map_right,
                one_heatmap >= map_up,
                one_heatmap >= map_down,
                one_heatmap > thre1,
            )
        )
        peaks = list(
            zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])
        )  # note reverse
        peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
        peak_id = range(peak_counter, peak_counter + len(peaks))
        peaks_with_score_and_id = [
            peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))
        ]

        all_peaks.append(peaks_with_score_and_id)
        peak_counter += len(peaks)

    # find connections in the specified sequence, center 29 is in position 15
    limbSeq = [
        [2, 3],
        [2, 6],
        [3, 4],
        [4, 5],
        [6, 7],
        [7, 8],
        [2, 9],
        [9, 10],
        [10, 11],
        [2, 12],
        [12, 13],
        [13, 14],
        [2, 1],
        [1, 15],
        [15, 17],
        [1, 16],
        [16, 18],
        [3, 17],
        [6, 18],
    ]
    # the middle joints heatmap correspondence
    mapIdx = [
        [31, 32],
        [39, 40],
        [33, 34],
        [35, 36],
        [41, 42],
        [43, 44],
        [19, 20],
        [21, 22],
        [23, 24],
        [25, 26],
        [27, 28],
        [29, 30],
        [47, 48],
        [49, 50],
        [53, 54],
        [51, 52],
        [55, 56],
        [37, 38],
        [45, 46],
    ]

    connection_all = []
    special_k = []
    mid_num = 10

    for k in range(len(mapIdx)):
        score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
        candA = all_peaks[limbSeq[k][0] - 1]
        candB = all_peaks[limbSeq[k][1] - 1]
        nA = len(candA)
        nB = len(candB)
        indexA, indexB = limbSeq[k]
        if nA != 0 and nB != 0:
            connection_candidate = []
            for i in range(nA):
                for j in range(nB):
                    vec = np.subtract(candB[j][:2], candA[i][:2])
                    norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
                    norm = max(0.001, norm)
                    vec = np.divide(vec, norm)

                    startend = list(
                        zip(
                            np.linspace(candA[i][0], candB[j][0], num=mid_num),
                            np.linspace(candA[i][1], candB[j][1], num=mid_num),
                        )
                    )

                    vec_x = np.array(
                        [
                            score_mid[
                                int(round(startend[I][1])),
                                int(round(startend[I][0])),
                                0,
                            ]
                            for I in range(len(startend))  # noqa
                        ]
                    )
                    vec_y = np.array(
                        [
                            score_mid[
                                int(round(startend[I][1])),
                                int(round(startend[I][0])),
                                1,
                            ]
                            for I in range(len(startend))  # noqa
                        ]
                    )

                    score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(
                        vec_y, vec[1]
                    )
                    score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
                        0.5 * original_img.shape[0] / norm - 1, 0
                    )
                    criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(
                        score_midpts
                    )
                    criterion2 = score_with_dist_prior > 0
                    if criterion1 and criterion2:
                        connection_candidate.append(
                            [
                                i,
                                j,
                                score_with_dist_prior,
                                score_with_dist_prior + candA[i][2] + candB[j][2],
                            ]
                        )

            connection_candidate = sorted(
                connection_candidate, key=lambda x: x[2], reverse=True
            )
            connection = np.zeros((0, 5))
            for c in range(len(connection_candidate)):  # noqa
                i, j, s = connection_candidate[c][0:3]
                if i not in connection[:, 3] and j not in connection[:, 4]:
                    connection = np.vstack(
                        [connection, [candA[i][3], candB[j][3], s, i, j]]
                    )
                    if len(connection) >= min(nA, nB):
                        break

            connection_all.append(connection)
        else:
            special_k.append(k)
            connection_all.append([])

    # the last number in each row is the total parts count for that person
    # the second to last number in each row is the score of the overall configuration
    subset = -1 * np.ones((0, 20))
    candidate = np.array([item for sublist in all_peaks for item in sublist])

    for k in range(len(mapIdx)):
        if k not in special_k:
            partAs = connection_all[k][:, 0]
            partBs = connection_all[k][:, 1]
            indexA, indexB = np.array(limbSeq[k]) - 1

            for i in range(len(connection_all[k])):  # = 1:size(temp,1)
                found = 0
                subset_idx = [-1, -1]
                for j, row in enumerate(subset):  # 1:size(subset,1):
                    if row[indexA] == partAs[i] or row[indexB] == partBs[i]:
                        subset_idx[found] = j
                        found += 1

                if found == 1:
                    j = subset_idx[0]
                    if subset[j][indexB] != partBs[i]:
                        subset[j][indexB] = partBs[i]
                        subset[j][-1] += 1
                        subset[j][-2] += (
                            candidate[partBs[i].astype(int), 2]
                            + connection_all[k][i][2]
                        )
                elif found == 2:  # if found 2 and disjoint, merge them
                    j1, j2 = subset_idx
                    membership = (
                        (subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int)
                    )[:-2]
                    if len(np.nonzero(membership == 2)[0]) == 0:  # merge
                        subset[j1][:-2] += subset[j2][:-2] + 1
                        subset[j1][-2:] += subset[j2][-2:]
                        subset[j1][-2] += connection_all[k][i][2]
                        subset = np.delete(subset, j2, 0)
                    else:  # same as the found == 1 case
                        subset[j1][indexB] = partBs[i]
                        subset[j1][-1] += 1
                        subset[j1][-2] += (
                            candidate[partBs[i].astype(int), 2]
                            + connection_all[k][i][2]
                        )

                # if partA was not found in any subset, create a new one
                elif not found and k < 17:
                    row = -1 * np.ones(20)
                    row[indexA] = partAs[i]
                    row[indexB] = partBs[i]
                    row[-1] = 2
                    row[-2] = (
                        sum(candidate[connection_all[k][i, :2].astype(int), 2])
                        + connection_all[k][i][2]
                    )
                    subset = np.vstack([subset, row])
    # delete rows of subset that have too few detected parts
    deleteIdx = []

    for i, s in enumerate(subset):
        if s[-1] < 4 or s[-2] / s[-1] < 0.4:
            deleteIdx.append(i)
    subset = np.delete(subset, deleteIdx, axis=0)

    # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
    # candidate: x, y, score, id
    return candidate, subset
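A sketch of how the candidate/subset return values can be unpacked into per-person keypoints; the helper name and input are illustrative, not repo API:

# Sketch: unpack create_body_pose() outputs into per-person keypoint lists.
def poses_from_image(img_t):
    candidate, subset = create_body_pose(img_t)
    people = []
    for person in subset:
        keypoints = []
        for part in range(18):
            idx = int(person[part])
            # -1 means this body part was not detected for this person
            keypoints.append(None if idx == -1 else tuple(candidate[idx][:2]))
        people.append(keypoints)
    return people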
@ -0,0 +1,529 @@
import einops
import torch
from einops import rearrange, repeat
from torch import nn
from torchvision.utils import make_grid

from imaginairy.modules.attention import SpatialTransformer
from imaginairy.modules.diffusion.ddpm import LatentDiffusion, log_txt_as_img
from imaginairy.modules.diffusion.openaimodel import (
    AttentionBlock,
    Downsample,
    ResBlock,
    TimestepEmbedSequential,
    UNetModel,
)
from imaginairy.modules.diffusion.util import (
    conv_nd,
    linear,
    timestep_embedding,
    zero_module,
)
from imaginairy.samplers import DDIMSampler
from imaginairy.utils import instantiate_from_config


class ControlledUnetModel(UNetModel):
    def forward(  # noqa
        self,
        x,
        timesteps=None,
        context=None,
        control=None,  # noqa
        only_mid_control=False,
        **kwargs,
    ):
        hs = []
        with torch.no_grad():
            t_emb = timestep_embedding(
                timesteps, self.model_channels, repeat_only=False
            )
            emb = self.time_embed(t_emb)
            h = x.type(self.dtype)
            for module in self.input_blocks:
                h = module(h, emb, context)
                hs.append(h)
            h = self.middle_block(h, emb, context)
        ctrl = control.pop()
        h += ctrl

        for i, module in enumerate(self.output_blocks):
            # allows us to work with sizes that are multiples of 8 instead of just 32
            if h.shape[-2:] != hs[-1].shape[-2:]:
                h = nn.functional.interpolate(h, hs[-1].shape[-2:], mode="nearest")
            if only_mid_control:
                h = torch.cat([h, hs.pop()], dim=1)
            else:
                ctrl = control.pop()
                if ctrl.shape[-2:] != hs[-1].shape[-2:]:
                    ctrl = nn.functional.interpolate(
                        ctrl, hs[-1].shape[-2:], mode="nearest"
                    )
                h = torch.cat([h, hs.pop() + ctrl], dim=1)
            h = module(h, emb, context)

        h = h.type(x.dtype)
        return self.out(h)


class ControlNet(nn.Module):
    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        hint_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

        if context_dim is not None:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
            from omegaconf.listconfig import ListConfig

            if isinstance(context_dim, ListConfig):
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError(
                    "provide num_res_blocks either as an int (globally constant) or "
                    "as a list/tuple (per-level) with the same length as channel_mult"
                )
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(
                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
                    range(len(num_attention_blocks)),
                )
            )
            print(
                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                f"attention will still not be set."
            )

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)),
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    if disable_self_attentions is not None:
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if num_attention_blocks is None or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            )
                            if not use_spatial_transformer
                            else SpatialTransformer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch))
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            )
            if not use_spatial_transformer
            else SpatialTransformer(  # always uses a self-attn
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth,
                context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn,
                use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self.middle_block_out = self.make_zero_conv(ch)
        self._feature_size += ch

    def make_zero_conv(self, channels):
        return TimestepEmbedSequential(
            zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))
        )

    def forward(self, x, hint, timesteps, context, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)
        hint = hint.to(dtype=emb.dtype).to(device=emb.device)

        guided_hint = self.input_hint_block(hint, emb, context)

        outs = []

        h = x.type(self.dtype)
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
                # the hint may come in at a different resolution than the latent,
                # which lets us handle image sizes that are multiples of 8 instead of 32
                if h.shape[-2:] != guided_hint.shape[-2:]:
                    guided_hint = nn.functional.interpolate(
                        guided_hint, h.shape[-2:], mode="nearest"
                    )
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))

        return outs
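The zero convs are the load-bearing trick here: they start with all-zero weights and biases, so at initialization the control branch contributes nothing and the frozen UNet's behavior is unchanged. A minimal demonstration of that property:

# Minimal demonstration: a zero-initialized 1x1 conv (the effect of
# zero_module(conv_nd(...))) outputs all zeros, so at init the control
# branch adds nothing to the frozen UNet's activations.
import torch
from torch import nn

conv = nn.Conv2d(320, 320, kernel_size=1, padding=0)
nn.init.zeros_(conv.weight)
nn.init.zeros_(conv.bias)

x = torch.randn(1, 320, 64, 64)
assert torch.all(conv(x) == 0)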
|
||||
|
||||
class ControlLDM(LatentDiffusion):
|
||||
def __init__(
|
||||
self, control_stage_config, control_key, only_mid_control, *args, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.control_model = instantiate_from_config(control_stage_config)
|
||||
self.control_key = control_key
|
||||
self.only_mid_control = only_mid_control
|
||||
|
||||
@torch.no_grad()
|
||||
def get_input(self, batch, k, bs=None, *args, **kwargs): # noqa
|
||||
x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
|
||||
control = batch[self.control_key]
|
||||
if bs is not None:
|
||||
control = control[:bs]
|
||||
control = control.to(self.device)
|
||||
control = einops.rearrange(control, "b h w c -> b c h w")
|
||||
control = control.to(memory_format=torch.contiguous_format).float()
|
||||
return x, {"c_crossattn": [c], "c_concat": [control]}
|
||||
|
||||
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
|
||||
assert isinstance(cond, dict)
|
||||
diffusion_model = self.model.diffusion_model
|
||||
cond_txt = torch.cat(cond["c_crossattn"], 1)
|
||||
cond_hint = torch.cat(cond["c_concat"], 1)
|
||||
|
||||
control = self.control_model(
|
||||
x=x_noisy, hint=cond_hint, timesteps=t, context=cond_txt
|
||||
)
|
||||
eps = diffusion_model(
|
||||
x=x_noisy,
|
||||
timesteps=t,
|
||||
context=cond_txt,
|
||||
control=control,
|
||||
only_mid_control=self.only_mid_control,
|
||||
)
|
||||
|
||||
return eps
|
||||
|
||||
@torch.no_grad()
|
||||
def get_unconditional_conditioning(self, N):
|
||||
return self.get_learned_conditioning([""] * N)
|
||||
|
||||

    @torch.no_grad()
    def log_images(
        self,
        batch,
        N=4,
        n_row=2,
        sample=False,
        ddim_steps=50,
        ddim_eta=0.0,
        return_keys=None,
        quantize_denoised=True,
        inpaint=True,
        plot_denoise_rows=False,
        plot_progressive_rows=True,
        plot_diffusion_rows=False,
        unconditional_guidance_scale=9.0,
        unconditional_guidance_label=None,
        use_ema_scope=True,
        **kwargs,
    ):
        use_ddim = ddim_steps is not None

        log = {}
        z, c = self.get_input(batch, self.first_stage_key, bs=N)
        c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)
        log["reconstruction"] = self.decode_first_stage(z)
        log["control"] = c_cat * 2.0 - 1.0
        log["conditioning"] = log_txt_as_img(
            (512, 512), batch[self.cond_stage_key], size=16
        )

        if plot_diffusion_rows:
            # visualize the forward diffusion process at the logged timesteps
            diffusion_row = []
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            samples, z_denoise_row = self.sample_log(
                cond={"c_concat": [c_cat], "c_crossattn": [c]},
                batch_size=N,
                ddim=use_ddim,
                ddim_steps=ddim_steps,
                eta=ddim_eta,
            )
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_cross = self.get_unconditional_conditioning(N)
            # only the text prompt is dropped for the unconditional pass; the hint stays
            uc_cat = c_cat  # torch.zeros_like(c_cat)
            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
            samples_cfg, _ = self.sample_log(
                cond={"c_concat": [c_cat], "c_crossattn": [c]},
                batch_size=N,
                ddim=use_ddim,
                ddim_steps=ddim_steps,
                eta=ddim_eta,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=uc_full,
            )
            x_samples_cfg = self.decode_first_stage(samples_cfg)
            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        return log
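
    # For the defaults above, the returned dict contains "reconstruction",
    # "control", and "conditioning"; "diffusion_row" is added when
    # plot_diffusion_rows=True, "samples" (plus "denoise_row" when
    # plot_denoise_rows=True) when sample=True, and
    # "samples_cfg_scale_9.00" when unconditional_guidance_scale > 1.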

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        ddim_sampler = DDIMSampler(self)
        b, c, h, w = cond["c_concat"][0].shape
        # the hint is at full image resolution; latents are 8x smaller per side
        shape = (self.channels, h // 8, w // 8)
        samples, intermediates = ddim_sampler.sample(
            ddim_steps, batch_size, shape, cond, verbose=False, **kwargs
        )
        return samples, intermediates
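
    # Worked example, assuming a 512x512 hint and the config's channels=4:
    # cond["c_concat"][0].shape == (N, 3, 512, 512), so shape == (4, 64, 64)
    # and the sampler returns latents of shape (N, 4, 64, 64).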

    def configure_optimizers(self):
        lr = self.learning_rate
        # by default only the ControlNet copy is trained; the base UNet stays frozen
        params = list(self.control_model.parameters())
        if not self.sd_locked:
            params += list(self.model.diffusion_model.output_blocks.parameters())
            params += list(self.model.diffusion_model.out.parameters())
        opt = torch.optim.AdamW(params, lr=lr)
        return opt
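
    # With sd_locked left truthy, only the ControlNet copy receives gradients.
    # A quick way to sanity-check what would train (a sketch, not part of the
    # class):
    #
    #   trainable = sum(p.numel() for p in model.control_model.parameters())
    #   total = trainable + sum(
    #       p.numel() for p in model.model.diffusion_model.parameters()
    #   )
    #   print(f"optimizing {trainable:,} of {total:,} UNet-side parameters")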

    def low_vram_shift(self, is_diffusing):
        if is_diffusing:
            self.model = self.model.cuda()
            self.control_model = self.control_model.cuda()
            self.first_stage_model = self.first_stage_model.cpu()  # noqa
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            self.control_model = self.control_model.cpu()
            self.first_stage_model = self.first_stage_model.cuda()  # noqa
            self.cond_stage_model = self.cond_stage_model.cuda()
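
    # A hedged usage sketch: keep only the stage that is actively working on
    # the GPU at any moment (the method names match this class; the overall
    # flow is illustrative):
    #
    #   model.low_vram_shift(is_diffusing=False)  # VAE + CLIP on GPU
    #   z, cond = model.get_input(batch, "image")
    #   model.low_vram_shift(is_diffusing=True)   # UNet + ControlNet on GPU
    #   samples, _ = model.sample_log(cond, batch_size=1, ddim=True, ddim_steps=50)
    #   model.low_vram_shift(is_diffusing=False)  # back to the VAE for decoding
    #   images = model.decode_first_stage(samples)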
@ -0,0 +1,131 @@
import os
import time

import torch
from safetensors.torch import load_file, save_file

from imaginairy.model_manager import get_cached_url_path
from imaginairy.paths import PKG_ROOT

sd15_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt"


def main():
    """Convert the ControlNet weights into diffs that can be applied to any SD 1.5-derived weights."""
    control_types = [
        "canny",
        "depth",
        "hed",
        "mlsd",
        "normal",
        "openpose",
        "scribble",
        "seg",
    ]
    url_template = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/control_sd15_{control_type}.pth"
    urls = {
        control_type: url_template.format(control_type=control_type)
        for control_type in control_types
    }
    dest = f"{PKG_ROOT}/../other/weights/controlnet"

    for control_type, url in urls.items():
        print(f"Downloading {control_type} weights from {url}")

        out_filepath = extract_controlnet_essence(
            control_type=control_type,
            controlnet_url=url,
            dest_folder=dest,
        )

        # rebuild the full controlnet weights from sd1.5 + the saved diff
        sd15_path = get_cached_url_path(sd15_url)
        sd15_state_dict = torch.load(sd15_path, map_location="cpu")
        sd15_state_dict = sd15_state_dict.get("state_dict", sd15_state_dict)
        reconstituted_controlnet_statedict = apply_controlnet(
            base_state_dict=sd15_state_dict,
            controlnet_state_dict=load_file(out_filepath),
        )

        # sanity-check the reconstruction against the original checkpoint
        controlnet_path = get_cached_url_path(url)
        time.sleep(1)
        controlnet_statedict = torch.load(controlnet_path, map_location="cpu")
        print("\n\nComparing reconstructed controlnet with original")
        for k in controlnet_statedict.keys():
            if k not in reconstituted_controlnet_statedict.keys():
                print(f"Key {k} not in reconstituted")
            elif (
                controlnet_statedict[k].shape
                != reconstituted_controlnet_statedict[k].shape
            ):
                print(f"Key {k} has different shape")
                print(controlnet_statedict[k].shape)
                print(reconstituted_controlnet_statedict[k].shape)
            else:
                diff = controlnet_statedict[k] - reconstituted_controlnet_statedict[k]
                diff_sum = torch.abs(diff).sum()
                if diff_sum > 3.467949682089966e-05:
                    print(f"Key {k} has different values {diff_sum}")


def extract_controlnet_essence(control_type, controlnet_url, dest_folder):
    print(f"Extracting essence of {control_type} weights from {controlnet_url}")
    outpath = f"{dest_folder}/controlnet15_diff_{control_type}.safetensors"
    if os.path.exists(outpath):
        print(f"File {outpath} already exists, skipping")
        return outpath
    os.makedirs(dest_folder, exist_ok=True)
    sd15_path = get_cached_url_path(sd15_url)
    controlnet_path = get_cached_url_path(controlnet_url)
    print(f"sd15_path: {sd15_path}")
    print(f"controlnet_path: {controlnet_path}")

    sd15_state_dict = torch.load(sd15_path, map_location="cpu")
    sd15_state_dict = sd15_state_dict.get("state_dict", sd15_state_dict)

    controlnet_state_dict = torch.load(controlnet_path, map_location="cpu")
    controlnet_state_dict = controlnet_state_dict.get(
        "state_dict", controlnet_state_dict
    )

    final_state_dict = {}
    # the VAE and text-encoder weights are unchanged copies of sd1.5, so skip them
    skip_prefixes = ("first_stage_model", "cond_stage_model")
    for key in controlnet_state_dict:
        if key.startswith(skip_prefixes):
            continue

        # control_* layers mirror model.diffusion_* layers in the base model
        if key.startswith("control_"):
            sd15_key_name = "model.diffusion_" + key[len("control_") :]
        else:
            sd15_key_name = key

        if sd15_key_name in sd15_state_dict:
            # shared weights are stored only as their difference from the base
            diff_value = controlnet_state_dict[key] - sd15_state_dict[sd15_key_name]
            final_state_dict[key] = diff_value
        else:
            # layers with no base counterpart (e.g. the hint encoder) are kept whole
            final_state_dict[key] = controlnet_state_dict[key]
    save_file(final_state_dict, outpath)
    return outpath
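
# Illustration of the prefix mapping above: a ControlNet checkpoint key such as
#   control_model.input_blocks.1.0.in_layers.0.weight
# lines up with the SD 1.5 key
#   model.diffusion_model.input_blocks.1.0.in_layers.0.weight
# so for shared weights only the (controlnet - sd15) difference is saved, which
# is what lets the diff be reapplied to any SD 1.5-derived checkpoint.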


def apply_controlnet(base_state_dict, controlnet_state_dict):
    """Layer a controlnet diff over SD 1.5-style base weights (mutates base_state_dict)."""
    for key in controlnet_state_dict:
        if key.startswith("control_"):
            sd15_key_name = "model.diffusion_" + key[len("control_") :]
        else:
            sd15_key_name = key

        if sd15_key_name in base_state_dict:
            # shared weights were saved as diffs; add them back onto the base
            b = base_state_dict[sd15_key_name]
            c_diff = controlnet_state_dict[key]
            new_c = b + c_diff
            base_state_dict[key] = new_c
        else:
            base_state_dict[key] = controlnet_state_dict[key]
    return base_state_dict
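
# A sketch of reapplying a saved diff to some other SD 1.5-derived checkpoint
# (the filenames here are hypothetical):
#
#   base = torch.load("my-finetune.ckpt", map_location="cpu")
#   base = base.get("state_dict", base)
#   diff = load_file("controlnet15_diff_canny.safetensors")
#   controlnet_state_dict = apply_controlnet(base, diff)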


if __name__ == "__main__":
    main()
[10 binary image files added: 41 KiB, 68 KiB, 224 KiB, 354 KiB, 4.3 KiB, 548 KiB, 507 KiB, 608 KiB, 500 KiB, 481 KiB]
@ -0,0 +1,26 @@
import pytest

from imaginairy import LazyLoadingImage
from imaginairy.img_processors.control_modes import CONTROL_MODES
from imaginairy.img_utils import pillow_img_to_torch_image, torch_img_to_pillow_img
from tests import TESTS_FOLDER
from tests.utils import assert_image_similar_to_expectation


def control_img_to_pillow_img(img_t):
    # control hint tensors live in [0, 1]; rescale to the [-1, 1] range the
    # pillow conversion expects
    return torch_img_to_pillow_img((img_t - 0.5) * 2)


control_mode_params = list(CONTROL_MODES.items())


@pytest.mark.parametrize("control_name,control_func", control_mode_params)
def test_control_images(filename_base_for_outputs, control_func, control_name):
    img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png")
    img_t = pillow_img_to_torch_image(img)

    control_t = control_func(img_t.clone())
    control_img = control_img_to_pillow_img(control_t)
    img_path = f"{filename_base_for_outputs}.png"

    assert_image_similar_to_expectation(control_img, img_path, threshold=3500)
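
# To run just these comparisons (the test file path is illustrative):
#   pytest tests/img_processors/test_control_modes.py -k "canny"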