feature: update midas (depth maps)

pull/421/head
Bryce 6 months ago committed by Bryce Drennan
parent bf14ee6ee6
commit 7880ee1389

@ -2,7 +2,7 @@
### v14 todo
- configurable composition cutoff
- ✅ rename model parameter weights
- ✅ rename model_config parameter to architecture and make it case insensitive
- ✅ add --size parameter that accepts strings (e.g. 256x256, 4k, uhd, 8k, etc)
@ -10,7 +10,7 @@
- ✅ add method to install correct torch version
- ✅ make cli run faster again
- ✅ add tests for cli commands
- add type checker
- only output the main image unless some flag is set
- allow selection of output video format
- chain multiple operations together imggen => videogen

@ -16,8 +16,6 @@ DEFAULT_NEGATIVE_PROMPT = (
"grainy, blurred, blurry, writing, calligraphy, signature, text, watermark, bad art,"
)
midas_url = "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt"
@dataclass
class ModelArchitecture:

@ -38,12 +38,14 @@ def create_canny_edges(img: "Tensor") -> "Tensor":
return canny_image
def create_depth_map(img: "Tensor") -> "Tensor":
def create_depth_map(
img: "Tensor", model_type="dpt_hybrid_384", max_size=512
) -> "Tensor":
import torch
orig_size = img.shape[2:]
depth_pt = _create_depth_map_raw(img, max_size=1024)
depth_pt = _create_depth_map_raw(img, max_size=max_size, model_type=model_type)
# copy the depth map to the other channels
depth_pt = torch.cat([depth_pt, depth_pt, depth_pt], dim=0)
@ -61,12 +63,14 @@ def create_depth_map(img: "Tensor") -> "Tensor":
return depth_pt
def _create_depth_map_raw(img: "Tensor", max_size=512) -> "Tensor":
def _create_depth_map_raw(
img: "Tensor", max_size=512, model_type="dpt_large_384"
) -> "Tensor":
import torch
from imaginairy.modules.midas.api import MiDaSInference, midas_device
model = MiDaSInference(model_type="dpt_hybrid").to(midas_device())
model = MiDaSInference(model_type=model_type).to(midas_device())
img = img.to(midas_device())
# calculate new size such that image fits within max_size x max_size but keeps aspect ratio
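For reference, a minimal usage sketch of the updated create_depth_map signature (the dummy tensor, the chosen model_type, and the max_size value are illustrative assumptions, not part of the change):

import torch
from imaginairy.img_processors.control_modes import create_depth_map

# dummy NCHW image in [-1, 1]; a real caller would pass a preprocessed photo
img = torch.rand(1, 3, 512, 384) * 2 - 1
depth = create_depth_map(img, model_type="dpt_beit_large_384", max_size=512)
print(depth.shape)  # depth map resized back to the input's spatial size (orig_size above)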

@ -20,10 +20,19 @@ from imaginairy.modules.midas.midas.transforms import (
from imaginairy.utils import get_device
ISL_PATHS = {
"dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
"dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
"midas_v21": "",
"midas_v21_small": "",
"dpt_beit_large_512": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt",
"dpt_beit_large_384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt",
"dpt_beit_base_384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt",
# "dpt_swin2_large_384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt",
# "dpt_swin2_base_384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt",
# "dpt_swin2_tiny_256": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt",
# "dpt_swin_large_384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt",
# "dpt_next_vit_large_384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt",
# "dpt_levit_224": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt",
"dpt_large_384": "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt",
"dpt_hybrid_384": "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt",
"midas_v21_384": "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt",
# "midas_v21_small_256": "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt",
}
@ -38,12 +47,12 @@ def disabled_train(self, mode=True):
def load_midas_transform(model_type="dpt_hybrid"):
# https://github.com/isl-org/MiDaS/blob/master/run.py
# load transform only
if model_type == "dpt_large": # DPT-Large
if model_type in ("dpt_large_384", "dpt_large_384_v1"): # DPT-Large
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_hybrid": # DPT-Hybrid
elif model_type in ("dpt_hybrid_384", "dpt_hybrid_384_v1"): # DPT-Hybrid
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
@ -86,11 +95,134 @@ def load_midas_transform(model_type="dpt_hybrid"):
@lru_cache(maxsize=1)
def load_model(model_type):
# https://github.com/isl-org/MiDaS/blob/master/run.py
# load network
def load_model(
# device,
# model_path,
model_type="dpt_large_384",
optimize=True,
height=None,
square=False,
):
"""Load the specified network.
Args:
device (device): the torch device used
model_path (str): path to saved model
model_type (str): the type of the model to be loaded
optimize (bool): optimize the model to half-floats on CUDA?
height (int): inference encoder image height
square (bool): resize to a square resolution?
Returns:
The loaded network, the transform which prepares images as input to the network and the dimensions of the
network input
"""
model_path = ISL_PATHS[model_type]
if model_type == "dpt_large": # DPT-Large
keep_aspect_ratio = not square
if model_type == "dpt_beit_large_512":
model = DPTDepthModel(
path=model_path,
backbone="beitl16_512",
non_negative=True,
)
net_w, net_h = 512, 512
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_beit_large_384":
model = DPTDepthModel(
path=model_path,
backbone="beitl16_384",
non_negative=True,
)
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_beit_base_384":
model = DPTDepthModel(
path=model_path,
backbone="beitb16_384",
non_negative=True,
)
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_swin2_large_384":
model = DPTDepthModel(
path=model_path,
backbone="swin2l24_384",
non_negative=True,
)
net_w, net_h = 384, 384
keep_aspect_ratio = False
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_swin2_base_384":
model = DPTDepthModel(
path=model_path,
backbone="swin2b24_384",
non_negative=True,
)
net_w, net_h = 384, 384
keep_aspect_ratio = False
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_swin2_tiny_256":
model = DPTDepthModel(
path=model_path,
backbone="swin2t16_256",
non_negative=True,
)
net_w, net_h = 256, 256
keep_aspect_ratio = False
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_swin_large_384":
model = DPTDepthModel(
path=model_path,
backbone="swinl12_384",
non_negative=True,
)
net_w, net_h = 384, 384
keep_aspect_ratio = False
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_next_vit_large_384":
model = DPTDepthModel(
path=model_path,
backbone="next_vit_large_6m",
non_negative=True,
)
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
# We change the notation from dpt_levit_224 (MiDaS notation) to levit_384 (timm notation) here, where the 224 refers
# to the resolution 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of
# https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py
# (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e)
elif model_type == "dpt_levit_224":
model = DPTDepthModel(
path=model_path,
backbone="levit_384",
non_negative=True,
head_features_1=64,
head_features_2=8,
)
net_w, net_h = 224, 224
keep_aspect_ratio = False
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type in ("dpt_large_384", "dpt_large_384_v1"):
model = DPTDepthModel(
path=model_path,
backbone="vitl16_384",
@ -100,7 +232,7 @@ def load_model(model_type):
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_hybrid": # DPT-Hybrid
elif model_type in ("dpt_hybrid_384", "dpt_hybrid_384_v1"):
model = DPTDepthModel(
path=model_path,
backbone="vitb_rn50_384",
@ -110,7 +242,7 @@ def load_model(model_type):
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "midas_v21":
elif model_type == "midas_v21_384":
model = MidasNet(model_path, non_negative=True)
net_w, net_h = 384, 384
resize_mode = "upper_bound"
@ -118,7 +250,7 @@ def load_model(model_type):
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
elif model_type == "midas_v21_small":
elif model_type == "midas_v21_small_256":
model = MidasNet_small(
model_path,
features=64,
@ -137,13 +269,28 @@ def load_model(model_type):
msg = f"model_type '{model_type}' not implemented, use: --model_type large"
raise NotImplementedError(msg)
if "openvino" not in model_type:
print(
"Model loaded, number of parameters = {:.0f}M".format(
sum(p.numel() for p in model.parameters()) / 1e6
)
)
else:
print("Model loaded, optimized with OpenVINO")
if "openvino" in model_type:
keep_aspect_ratio = False
if height is not None:
net_w, net_h = height, height
transform = Compose(
[
Resize(
net_w,
net_h,
resize_target=None,
keep_aspect_ratio=True,
keep_aspect_ratio=keep_aspect_ratio,
ensure_multiple_of=32,
resize_method=resize_mode,
image_interpolation_method=cv2.INTER_CUBIC,
@ -187,17 +334,9 @@ def torch_image_to_depth_map(image_t: torch.Tensor, model_type="dpt_hybrid"):
class MiDaSInference(nn.Module):
MODEL_TYPES_TORCH_HUB = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"]
MODEL_TYPES_ISL = [
"dpt_large",
"dpt_hybrid",
"midas_v21",
"midas_v21_small",
]
def __init__(self, model_type):
super().__init__()
assert model_type in self.MODEL_TYPES_ISL
# assert model_type in self.MODEL_TYPES_ISL
model, _ = load_model(model_type)
self.model = model
self.model.train = disabled_train
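A hedged sketch of calling load_model directly with one of the newly listed checkpoints; the two-tuple unpacking mirrors MiDaSInference above, while the dict-style transform call is an assumption carried over from the upstream MiDaS transforms:

import numpy as np
from imaginairy.modules.midas.api import load_model, midas_device

model, transform = load_model(model_type="dpt_beit_large_384")
model = model.to(midas_device())

# HxWx3 float image in 0..1; the transform resizes and normalizes it into a CHW array
dummy = np.random.rand(480, 640, 3).astype(np.float32)
net_input = transform({"image": dummy})["image"]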

@ -0,0 +1,238 @@
import types
from typing import Optional
import numpy as np
import timm
import torch
import torch.nn.functional as F
from timm.models.beit import gen_relative_position_index
from torch.utils.checkpoint import checkpoint
from .utils import forward_adapted_unflatten, make_backbone_default
def forward_beit(pretrained, x):
return forward_adapted_unflatten(pretrained, x, "forward_features")
def patch_embed_forward(self, x):
"""
Modification of timm.models.layers.patch_embed.py: PatchEmbed.forward to support arbitrary window sizes.
"""
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x
def _get_rel_pos_bias(self, window_size):
"""
Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
"""
old_height = 2 * self.window_size[0] - 1
old_width = 2 * self.window_size[1] - 1
new_height = 2 * window_size[0] - 1
new_width = 2 * window_size[1] - 1
old_relative_position_bias_table = self.relative_position_bias_table
old_num_relative_distance = self.num_relative_distance
new_num_relative_distance = new_height * new_width + 3
old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]
old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(
0, 3, 1, 2
)
new_sub_table = F.interpolate(
old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
)
new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(
new_num_relative_distance - 3, -1
)
new_relative_position_bias_table = torch.cat(
[
new_sub_table,
old_relative_position_bias_table[old_num_relative_distance - 3 :],
]
)
key = str(window_size[1]) + "," + str(window_size[0])
if key not in self.relative_position_indices:
self.relative_position_indices[key] = gen_relative_position_index(window_size)
relative_position_bias = new_relative_position_bias_table[
self.relative_position_indices[key].view(-1)
].view(
window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
return relative_position_bias.unsqueeze(0)
def attention_forward(
self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None
):
"""
Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes.
"""
B, N, C = x.shape
qkv_bias = (
torch.cat((self.q_bias, self.k_bias, self.v_bias))
if self.q_bias is not None
else None
)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if self.relative_position_bias_table is not None:
window_size = tuple(np.array(resolution) // 16)
attn = attn + self._get_rel_pos_bias(window_size)
if shared_rel_pos_bias is not None:
attn = attn + shared_rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
def block_forward(
self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None
):
"""
Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes.
See https://github.com/isl-org/MiDaS/pull/234
"""
if self.gamma_1 is None:
x = x + self.drop_path1(
self.attn(
self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias
)
)
x = x + self.drop_path2(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path1(
self.gamma_1
* self.attn(
self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias
)
)
x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
return x
def beit_forward_features(self, x):
"""
Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes.
"""
resolution = x.shape[2:]
x = self.patch_embed(x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
else:
x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias)
x = self.norm(x)
return x
def _make_beit_backbone(
model,
features=[96, 192, 384, 768],
size=[384, 384],
hooks=[0, 4, 8, 11],
vit_features=768,
use_readout="ignore",
start_index=1,
start_index_readout=1,
):
backbone = make_backbone_default(
model,
features,
size,
hooks,
vit_features,
use_readout,
start_index,
start_index_readout,
)
backbone.model.patch_embed.forward = types.MethodType(
patch_embed_forward, backbone.model.patch_embed
)
backbone.model.forward_features = types.MethodType(
beit_forward_features, backbone.model
)
for block in backbone.model.blocks:
attn = block.attn
attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn)
attn.forward = types.MethodType(attention_forward, attn)
attn.relative_position_indices = {}
block.forward = types.MethodType(block_forward, block)
return backbone
def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("beit_large_patch16_512", pretrained=pretrained)
hooks = [5, 11, 17, 23] if hooks is None else hooks
features = [256, 512, 1024, 1024]
return _make_beit_backbone(
model,
features=features,
size=[512, 512],
hooks=hooks,
vit_features=1024,
use_readout=use_readout,
)
def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("beit_large_patch16_384", pretrained=pretrained)
hooks = [5, 11, 17, 23] if hooks is None else hooks
return _make_beit_backbone(
model,
features=[256, 512, 1024, 1024],
hooks=hooks,
vit_features=1024,
use_readout=use_readout,
)
def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("beit_base_patch16_384", pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks is None else hooks
return _make_beit_backbone(
model,
features=[96, 192, 384, 768],
hooks=hooks,
use_readout=use_readout,
)
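The _get_rel_pos_bias override above makes BEiT's relative-position-bias table resolution-agnostic by bilinearly interpolating it to the new window size; a standalone sketch of that resize, with arbitrary grid sizes and head count:

import torch
import torch.nn.functional as F

# resize a bias table from a 24x24 token grid (384 px / 16) to 32x32 (512 px / 16);
# the last 3 rows are the cls-token entries and are carried over unchanged
num_heads = 12
old_h = old_w = 2 * 24 - 1
new_h = new_w = 2 * 32 - 1
table = torch.randn(old_h * old_w + 3, num_heads)

grid = table[:-3].reshape(1, old_w, old_h, num_heads).permute(0, 3, 1, 2)
grid = F.interpolate(grid, size=(new_h, new_w), mode="bilinear")
grid = grid.permute(0, 2, 3, 1).reshape(new_h * new_w, num_heads)
print(torch.cat([grid, table[-3:]]).shape)  # torch.Size([3972, 12])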

@ -0,0 +1,249 @@
import torch
import torch.nn as nn
class Slice(nn.Module):
def __init__(self, start_index=1):
super().__init__()
self.start_index = start_index
def forward(self, x):
return x[:, self.start_index :]
class AddReadout(nn.Module):
def __init__(self, start_index=1):
super().__init__()
self.start_index = start_index
def forward(self, x):
if self.start_index == 2: # noqa
readout = (x[:, 0] + x[:, 1]) / 2
else:
readout = x[:, 0]
return x[:, self.start_index :] + readout.unsqueeze(1)
class ProjectReadout(nn.Module):
def __init__(self, in_features, start_index=1):
super().__init__()
self.start_index = start_index
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
def forward(self, x):
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
features = torch.cat((x[:, self.start_index :], readout), -1)
return self.project(features)
class Transpose(nn.Module):
def __init__(self, dim0, dim1):
super().__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
x = x.transpose(self.dim0, self.dim1)
return x
activations = {}
def get_activation(name):
def hook(model, input, output): # noqa
activations[name] = output
return hook
def forward_default(pretrained, x, function_name="forward_features"):
exec(f"pretrained.model.{function_name}(x)")
layer_1 = pretrained.activations["1"]
layer_2 = pretrained.activations["2"]
layer_3 = pretrained.activations["3"]
layer_4 = pretrained.activations["4"]
if hasattr(pretrained, "act_postprocess1"):
layer_1 = pretrained.act_postprocess1(layer_1)
if hasattr(pretrained, "act_postprocess2"):
layer_2 = pretrained.act_postprocess2(layer_2)
if hasattr(pretrained, "act_postprocess3"):
layer_3 = pretrained.act_postprocess3(layer_3)
if hasattr(pretrained, "act_postprocess4"):
layer_4 = pretrained.act_postprocess4(layer_4)
return layer_1, layer_2, layer_3, layer_4
def forward_adapted_unflatten(pretrained, x, function_name="forward_features"):
b, c, h, w = x.shape
exec(f"glob = pretrained.model.{function_name}(x)")
layer_1 = pretrained.activations["1"]
layer_2 = pretrained.activations["2"]
layer_3 = pretrained.activations["3"]
layer_4 = pretrained.activations["4"]
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
unflatten = nn.Sequential(
nn.Unflatten(
2,
torch.Size(
[
h // pretrained.model.patch_size[1],
w // pretrained.model.patch_size[0],
]
),
)
)
if layer_1.ndim == 3:
layer_1 = unflatten(layer_1)
if layer_2.ndim == 3:
layer_2 = unflatten(layer_2)
if layer_3.ndim == 3:
layer_3 = unflatten(layer_3)
if layer_4.ndim == 3:
layer_4 = unflatten(layer_4)
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
return layer_1, layer_2, layer_3, layer_4
def get_readout_oper(vit_features, features, use_readout, start_index=1):
if use_readout == "ignore":
readout_oper = [Slice(start_index)] * len(features)
elif use_readout == "add":
readout_oper = [AddReadout(start_index)] * len(features)
elif use_readout == "project":
readout_oper = [
ProjectReadout(vit_features, start_index) for out_feat in features
]
else:
msg = "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
raise ValueError(msg)
return readout_oper
def make_backbone_default(
model,
features=[96, 192, 384, 768],
size=[384, 384],
hooks=[2, 5, 8, 11],
vit_features=768,
use_readout="ignore",
start_index=1,
start_index_readout=1,
):
pretrained = nn.Module()
pretrained.model = model
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
readout_oper = get_readout_oper(
vit_features, features, use_readout, start_index_readout
)
# 32, 48, 136, 384
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
return pretrained

@ -0,0 +1,238 @@
import math
import types
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import (
Transpose,
activations,
forward_adapted_unflatten,
get_activation,
get_readout_oper,
make_backbone_default,
)
def forward_vit(pretrained, x):
return forward_adapted_unflatten(pretrained, x, "forward_flex")
def _resize_pos_embed(self, posemb, gs_h, gs_w):
posemb_tok, posemb_grid = (
posemb[:, : self.start_index],
posemb[0, self.start_index :],
)
gs_old = int(math.sqrt(len(posemb_grid)))
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
def forward_flex(self, x):
b, c, h, w = x.shape
pos_embed = self._resize_pos_embed(
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
)
B = x.shape[0]
if hasattr(self.patch_embed, "backbone"):
x = self.patch_embed.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
if getattr(self, "dist_token", None) is not None:
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
else:
if self.no_embed_class:
x = x + pos_embed
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if not self.no_embed_class:
x = x + pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x
def _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
size=[384, 384],
hooks=[2, 5, 8, 11],
vit_features=768,
use_readout="ignore",
start_index=1,
start_index_readout=1,
):
pretrained = make_backbone_default(
model,
features,
size,
hooks,
vit_features,
use_readout,
start_index,
start_index_readout,
)
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
hooks = [5, 11, 17, 23] if hooks is None else hooks
return _make_vit_b16_backbone(
model,
features=[256, 512, 1024, 1024],
hooks=hooks,
vit_features=1024,
use_readout=use_readout,
)
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks is None else hooks
return _make_vit_b16_backbone(
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
)
def _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=[0, 1, 8, 11],
vit_features=768,
patch_size=[16, 16],
number_stages=2,
use_vit_only=False,
use_readout="ignore",
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
used_number_stages = 0 if use_vit_only else number_stages
for s in range(used_number_stages):
pretrained.model.patch_embed.backbone.stages[s].register_forward_hook(
get_activation(str(s + 1))
)
for s in range(used_number_stages, 4):
pretrained.model.blocks[hooks[s]].register_forward_hook(
get_activation(str(s + 1))
)
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
for s in range(used_number_stages):
value = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity())
exec(f"pretrained.act_postprocess{s + 1}=value")
for s in range(used_number_stages, 4):
if s < number_stages:
final_layer = nn.ConvTranspose2d(
in_channels=features[s],
out_channels=features[s],
kernel_size=4 // (2**s),
stride=4 // (2**s),
padding=0,
bias=True,
dilation=1,
groups=1,
)
elif s > number_stages:
final_layer = nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
)
else:
final_layer = None
layers = [
readout_oper[s],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[s],
kernel_size=1,
stride=1,
padding=0,
),
]
if final_layer is not None:
layers.append(final_layer)
value = nn.Sequential(*layers)
setattr(pretrained, f"act_postprocess{s + 1}", value)
pretrained.model.start_index = start_index
pretrained.model.patch_size = patch_size
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_pretrained_vitb_rn50_384(
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
hooks = [0, 1, 8, 11] if hooks is None else hooks
return _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
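forward_flex relies on _resize_pos_embed to rescale the learned position embeddings to whatever patch grid the input produces; a self-contained sketch of that resize for an arbitrary 640x480 input with ViT-B/16 sizes:

import torch
import torch.nn.functional as F

# rescale a 24x24 position-embedding grid (plus one cls token) to the
# 40x30 grid produced by a 640x480 input with 16x16 patches
start_index = 1
posemb = torch.randn(1, 24 * 24 + start_index, 768)

tok, grid = posemb[:, :start_index], posemb[0, start_index:]
gs_old = int(grid.shape[0] ** 0.5)
grid = grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
grid = F.interpolate(grid, size=(640 // 16, 480 // 16), mode="bilinear")
grid = grid.permute(0, 2, 3, 1).reshape(1, (640 // 16) * (480 // 16), -1)
print(torch.cat([tok, grid], dim=1).shape)  # torch.Size([1, 1201, 768])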

@ -2,7 +2,6 @@
import torch
from imaginairy import config
from imaginairy.utils.model_manager import get_cached_url_path
@ -14,9 +13,11 @@ class BaseModel(torch.nn.Module):
Args:
path (str): file path
"""
ckpt_path = get_cached_url_path(config.midas_url, category="weights")
ckpt_path = get_cached_url_path(path, category="weights")
parameters = torch.load(ckpt_path, map_location=torch.device("cpu"))
parameters = {
k: v for k, v in parameters.items() if "relative_position_index" not in k
}
if "optimizer" in parameters:
parameters = parameters["model"]

@ -1,12 +1,17 @@
"""Functions and classes for neural network construction"""
import torch
from torch import nn
import torch.nn as nn
from .vit import (
from .backbones.beit import (
_make_pretrained_beitb16_384,
_make_pretrained_beitl16_384,
_make_pretrained_beitl16_512,
forward_beit, # noqa
)
from .backbones.vit import (
_make_pretrained_vitb16_384,
_make_pretrained_vitb_rn50_384,
_make_pretrained_vitl16_384,
forward_vit, # noqa
)
@ -20,8 +25,62 @@ def _make_encoder(
hooks=None,
use_vit_only=False,
use_readout="ignore",
in_features=[96, 256, 512, 1024],
):
if backbone == "vitl16_384":
if backbone == "beitl16_512":
pretrained = _make_pretrained_beitl16_512(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[256, 512, 1024, 1024], features, groups=groups, expand=expand
) # BEiT_512-L (backbone)
elif backbone == "beitl16_384":
pretrained = _make_pretrained_beitl16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[256, 512, 1024, 1024], features, groups=groups, expand=expand
) # BEiT_384-L (backbone)
elif backbone == "beitb16_384":
pretrained = _make_pretrained_beitb16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[96, 192, 384, 768], features, groups=groups, expand=expand
) # BEiT_384-B (backbone)
# elif backbone == "swin2l24_384":
# pretrained = _make_pretrained_swin2l24_384(use_pretrained, hooks=hooks)
# scratch = _make_scratch(
# [192, 384, 768, 1536], features, groups=groups, expand=expand
# ) # Swin2-L/12to24 (backbone)
# elif backbone == "swin2b24_384":
# pretrained = _make_pretrained_swin2b24_384(use_pretrained, hooks=hooks)
# scratch = _make_scratch(
# [128, 256, 512, 1024], features, groups=groups, expand=expand
# ) # Swin2-B/12to24 (backbone)
# elif backbone == "swin2t16_256":
# pretrained = _make_pretrained_swin2t16_256(use_pretrained, hooks=hooks)
# scratch = _make_scratch(
# [96, 192, 384, 768], features, groups=groups, expand=expand
# ) # Swin2-T/16 (backbone)
# elif backbone == "swinl12_384":
# pretrained = _make_pretrained_swinl12_384(use_pretrained, hooks=hooks)
# scratch = _make_scratch(
# [192, 384, 768, 1536], features, groups=groups, expand=expand
# ) # Swin-L/12 (backbone)
# elif backbone == "next_vit_large_6m":
# from .backbones.next_vit import _make_pretrained_next_vit_large_6m
#
# pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks)
# scratch = _make_scratch(
# in_features, features, groups=groups, expand=expand
# ) # Next-ViT-L on ImageNet-1K-6M (backbone)
# elif backbone == "levit_384":
# pretrained = _make_pretrained_levit_384(use_pretrained, hooks=hooks)
# scratch = _make_scratch(
# [384, 512, 768], features, groups=groups, expand=expand
# ) # LeViT 384 (backbone)
elif backbone == "vitl16_384":
pretrained = _make_pretrained_vitl16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
@ -70,12 +129,15 @@ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
if expand is True:
if len(in_shape) >= 4:
out_shape4 = out_shape
if expand:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
out_shape4 = out_shape * 8
if len(in_shape) >= 4:
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(
in_shape[0],
@ -104,15 +166,16 @@ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
bias=False,
groups=groups,
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3],
out_shape4,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
if len(in_shape) >= 4:
scratch.layer4_rn = nn.Conv2d(
in_shape[3],
out_shape4,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
return scratch
@ -162,8 +225,7 @@ class Interpolate(nn.Module):
"""Interpolation module."""
def __init__(self, scale_factor, mode, align_corners=False):
"""
Init.
"""Init.
Args:
scale_factor (float): scaling
@ -177,8 +239,7 @@ class Interpolate(nn.Module):
self.align_corners = align_corners
def forward(self, x):
"""
Forward pass.
"""Forward pass.
Args:
x (tensor): input
@ -201,8 +262,7 @@ class ResidualConvUnit(nn.Module):
"""Residual convolution module."""
def __init__(self, features):
"""
Init.
"""Init.
Args:
features (int): number of features
@ -220,8 +280,7 @@ class ResidualConvUnit(nn.Module):
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
"""
Forward pass.
"""Forward pass.
Args:
x (tensor): input
@ -241,8 +300,7 @@ class FeatureFusionBlock(nn.Module):
"""Feature fusion block."""
def __init__(self, features):
"""
Init.
"""Init.
Args:
features (int): number of features
@ -253,8 +311,7 @@ class FeatureFusionBlock(nn.Module):
self.resConfUnit2 = ResidualConvUnit(features)
def forward(self, *xs):
"""
Forward pass.
"""Forward pass.
Returns:
tensor: output
@ -277,8 +334,7 @@ class ResidualConvUnit_custom(nn.Module):
"""Residual convolution module."""
def __init__(self, features, activation, bn):
"""
Init.
"""Init.
Args:
features (int): number of features
@ -318,8 +374,7 @@ class ResidualConvUnit_custom(nn.Module):
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""
Forward pass.
"""Forward pass.
Args:
x (tensor): input
@ -357,9 +412,9 @@ class FeatureFusionBlock_custom(nn.Module):
bn=False,
expand=False,
align_corners=True,
size=None,
):
"""
Init.
"""Init.
Args:
features (int): number of features
@ -391,9 +446,10 @@ class FeatureFusionBlock_custom(nn.Module):
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, *xs):
"""
Forward pass.
self.size = size
def forward(self, *xs, size=None):
"""Forward pass.
Returns:
tensor: output
@ -407,8 +463,15 @@ class FeatureFusionBlock_custom(nn.Module):
output = self.resConfUnit2(output)
if (size is None) and (self.size is None):
modifier = {"scale_factor": 2}
elif size is None:
modifier = {"size": self.size}
else:
modifier = {"size": size}
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
output, **modifier, mode="bilinear", align_corners=self.align_corners
)
output = self.out_conv(output)
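The new size handling in FeatureFusionBlock_custom.forward prefers an explicit output size and falls back to 2x upsampling; a minimal standalone sketch of that dispatch (function name and test shapes are illustrative):

import torch
import torch.nn.functional as F

def upsample(output, size=None, block_size=None, align_corners=True):
    # a per-call size overrides the block-level size; 2x scaling is the fallback
    if size is None and block_size is None:
        modifier = {"scale_factor": 2}
    elif size is None:
        modifier = {"size": block_size}
    else:
        modifier = {"size": size}
    return F.interpolate(output, **modifier, mode="bilinear", align_corners=align_corners)

print(upsample(torch.randn(1, 8, 16, 16)).shape)                 # torch.Size([1, 8, 32, 32])
print(upsample(torch.randn(1, 8, 16, 16), size=(24, 20)).shape)  # torch.Size([1, 8, 24, 20])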

@ -1,14 +1,17 @@
"""Classes for depth estimation from images"""
import torch
from torch import nn
import torch.nn as nn
from .base_model import BaseModel
from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
from .vit import forward_vit
from .blocks import (
FeatureFusionBlock_custom,
Interpolate,
_make_encoder,
forward_beit,
forward_vit,
)
def _make_fusion_block(features, use_bn):
def _make_fusion_block(features, use_bn, size=None):
return FeatureFusionBlock_custom(
features,
nn.ReLU(False),
@ -16,6 +19,7 @@ def _make_fusion_block(features, use_bn):
bn=use_bn,
expand=False,
align_corners=True,
size=size,
)
@ -28,16 +32,40 @@ class DPT(BaseModel):
readout="project",
channels_last=False,
use_bn=False,
**kwargs
):
super().__init__()
self.channels_last = channels_last
# For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the
# hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments.
hooks = {
"beitl16_512": [5, 11, 17, 23],
"beitl16_384": [5, 11, 17, 23],
"beitb16_384": [2, 5, 8, 11],
"swin2l24_384": [
1,
1,
17,
1,
], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1]
"swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
"swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1]
"swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
"next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39]
"levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21]
"vitb_rn50_384": [0, 1, 8, 11],
"vitb16_384": [2, 5, 8, 11],
"vitl16_384": [5, 11, 17, 23],
}
}[backbone]
if "next_vit" in backbone:
in_features = {
"next_vit_large_6m": [96, 256, 512, 1024],
}[backbone]
else:
in_features = None
# Instantiate backbone and reassemble blocks
self.pretrained, self.scratch = _make_encoder(
@ -47,14 +75,37 @@ class DPT(BaseModel):
groups=1,
expand=False,
exportable=False,
hooks=hooks[backbone],
hooks=hooks,
use_readout=readout,
in_features=in_features,
)
self.number_layers = len(hooks) if hooks is not None else 4
size_refinenet3 = None
self.scratch.stem_transpose = None
if "beit" in backbone:
self.forward_transformer = forward_beit
# elif "swin" in backbone:
# self.forward_transformer = forward_swin
# elif "next_vit" in backbone:
# from .backbones.next_vit import forward_next_vit
#
# self.forward_transformer = forward_next_vit
# elif "levit" in backbone:
# self.forward_transformer = forward_levit
# size_refinenet3 = 7
# self.scratch.stem_transpose = stem_b4_transpose(
# 256, 128, get_act_layer("hard_swish")
# )
else:
self.forward_transformer = forward_vit
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3)
if self.number_layers >= 4:
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
self.scratch.output_conv = head
@ -62,18 +113,31 @@ class DPT(BaseModel):
if self.channels_last is True:
x.contiguous(memory_format=torch.channels_last)
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
layers = self.forward_transformer(self.pretrained, x)
if self.number_layers == 3:
layer_1, layer_2, layer_3 = layers
else:
layer_1, layer_2, layer_3, layer_4 = layers
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
if self.number_layers >= 4:
layer_4_rn = self.scratch.layer4_rn(layer_4)
if self.number_layers == 3:
path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:])
else:
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
path_3 = self.scratch.refinenet3(
path_4, layer_3_rn, size=layer_2_rn.shape[2:]
)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
if self.scratch.stem_transpose is not None:
path_1 = self.scratch.stem_transpose(path_1)
out = self.scratch.output_conv(path_1)
return out
@ -82,13 +146,33 @@ class DPT(BaseModel):
class DPTDepthModel(DPT):
def __init__(self, path=None, non_negative=True, **kwargs):
features = kwargs["features"] if "features" in kwargs else 256
head_features_1 = (
kwargs["head_features_1"] if "head_features_1" in kwargs else features
)
head_features_2 = (
kwargs["head_features_2"] if "head_features_2" in kwargs else 32
)
kwargs.pop("head_features_1", None)
kwargs.pop("head_features_2", None)
head = nn.Sequential(
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
nn.Conv2d(
head_features_1,
head_features_1 // 2,
kernel_size=3,
stride=1,
padding=1,
),
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
nn.Conv2d(
head_features_1 // 2,
head_features_2,
kernel_size=3,
stride=1,
padding=1,
),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
nn.Identity(),
)
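With the new head_features_1/head_features_2 arguments, the dpt_levit_224 branch shrinks the output head from the default 256/32 channels to 64/8; a small shape check of an equivalently sized head, using nn.Upsample in place of the module's Interpolate wrapper:

import torch
import torch.nn as nn

# head_features_1=64, head_features_2=8, non_negative=True
head = nn.Sequential(
    nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
    nn.Conv2d(32, 8, kernel_size=3, stride=1, padding=1),
    nn.ReLU(True),
    nn.Conv2d(8, 1, kernel_size=1, stride=1, padding=0),
    nn.ReLU(True),
)
print(head(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 1, 112, 112])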

@ -1,490 +0,0 @@
"""Classes and functions for Vision Transformer processing"""
import math
import types
import timm
import torch
import torch.nn.functional as F
from torch import nn
class Slice(nn.Module):
def __init__(self, start_index=1):
super().__init__()
self.start_index = start_index
def forward(self, x):
return x[:, self.start_index :]
class AddReadout(nn.Module):
def __init__(self, start_index=1):
super().__init__()
self.start_index = start_index
def forward(self, x):
readout = (x[:, 0] + x[:, 1]) / 2 if self.start_index == 2 else x[:, 0]
return x[:, self.start_index :] + readout.unsqueeze(1)
class ProjectReadout(nn.Module):
def __init__(self, in_features, start_index=1):
super().__init__()
self.start_index = start_index
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
def forward(self, x):
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
features = torch.cat((x[:, self.start_index :], readout), -1)
return self.project(features)
class Transpose(nn.Module):
def __init__(self, dim0, dim1):
super().__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
x = x.transpose(self.dim0, self.dim1)
return x
def forward_vit(pretrained, x):
b, c, h, w = x.shape
pretrained.model.forward_flex(x)
layer_1 = pretrained.activations["1"]
layer_2 = pretrained.activations["2"]
layer_3 = pretrained.activations["3"]
layer_4 = pretrained.activations["4"]
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
unflatten = nn.Sequential(
nn.Unflatten(
2,
torch.Size(
[
h // pretrained.model.patch_size[1],
w // pretrained.model.patch_size[0],
]
),
)
)
if layer_1.ndim == 3:
layer_1 = unflatten(layer_1)
if layer_2.ndim == 3:
layer_2 = unflatten(layer_2)
if layer_3.ndim == 3:
layer_3 = unflatten(layer_3)
if layer_4.ndim == 3:
layer_4 = unflatten(layer_4)
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
return layer_1, layer_2, layer_3, layer_4
def _resize_pos_embed(self, posemb, gs_h, gs_w):
posemb_tok, posemb_grid = (
posemb[:, : self.start_index],
posemb[0, self.start_index :],
)
gs_old = int(math.sqrt(len(posemb_grid)))
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
def forward_flex(self, x):
b, c, h, w = x.shape
pos_embed = self._resize_pos_embed(
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
)
B = x.shape[0]
if hasattr(self.patch_embed, "backbone"):
x = self.patch_embed.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
if getattr(self, "dist_token", None) is not None:
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
else:
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x
activations = {}
def get_activation(name):
def hook(model, input, output): # noqa
activations[name] = output
return hook
def get_readout_oper(vit_features, features, use_readout, start_index=1):
if use_readout == "ignore":
readout_oper = [Slice(start_index)] * len(features)
elif use_readout == "add":
readout_oper = [AddReadout(start_index)] * len(features)
elif use_readout == "project":
readout_oper = [
ProjectReadout(vit_features, start_index) for out_feat in features
]
else:
msg = "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
raise ValueError(msg)
return readout_oper
def _make_vit_b16_backbone(
model,
features=(96, 192, 384, 768),
size=(384, 384),
hooks=(2, 5, 8, 11),
vit_features=768,
use_readout="ignore",
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
# 32, 48, 136, 384
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
hooks = [5, 11, 17, 23] if hooks is None else hooks
return _make_vit_b16_backbone(
model,
features=[256, 512, 1024, 1024],
hooks=hooks,
vit_features=1024,
use_readout=use_readout,
)
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks is None else hooks
return _make_vit_b16_backbone(
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
)
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks is None else hooks
return _make_vit_b16_backbone(
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
)
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model(
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
)
hooks = [2, 5, 8, 11] if hooks is None else hooks
return _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
hooks=hooks,
use_readout=use_readout,
start_index=2,
)
def _make_vit_b_rn50_backbone(
model,
features=(256, 512, 768, 768),
size=(384, 384),
hooks=(0, 1, 8, 11),
vit_features=768,
use_vit_only=False,
use_readout="ignore",
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
if use_vit_only is True:
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
else:
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
get_activation("1")
)
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
get_activation("2")
)
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
if use_vit_only is True:
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
else:
pretrained.act_postprocess1 = nn.Sequential(
nn.Identity(), nn.Identity(), nn.Identity()
)
pretrained.act_postprocess2 = nn.Sequential(
nn.Identity(), nn.Identity(), nn.Identity()
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_pretrained_vitb_rn50_384(
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
hooks = [0, 1, 8, 11] if hooks is None else hooks
return _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)

@ -233,6 +233,8 @@ def get_diffusion_model_refiners(
dtype=None,
) -> LatentDiffusionModel:
"""Load a diffusion model."""
print(weights_config)
print(f"for inpainting: {for_inpainting}")
return _get_diffusion_model_refiners(
weights_location=weights_config.weights_location,
for_inpainting=for_inpainting,

@ -3,6 +3,7 @@ test_clip_masking
test_clip_text_comparison
test_cliptext_inpainting_pearl_doctor
test_colorize_cmd
test_compare_depth_maps
test_control_images[depth-create_depth_map]
test_control_images[hed-create_hed_edges]
test_control_images[normal-create_normal_map]
@ -18,6 +19,7 @@ test_controlnet[normal]
test_controlnet[openpose]
test_controlnet[qrcode]
test_controlnet[shuffle]
test_create_depth_map
test_describe_cmd
test_describe_picture
test_edit_cmd

test_cache_ordering
test_clip_text_comparison
test_cliptext_inpainting_pearl_doctor
test_colorize_cmd
test_compare_depth_maps
test_control_images[depth-create_depth_map]
test_control_images[hed-create_hed_edges]
test_control_images[normal-create_normal_map]
test_controlnet[openpose]
test_controlnet[qrcode]
test_controlnet[shuffle]
test_create_depth_map
test_describe_cmd
test_describe_picture
test_edit_cmd

@ -1,7 +1,10 @@
import itertools
import pytest
from lightning_fabric import seed_everything
from imaginairy.img_processors.control_modes import CONTROL_MODES
from imaginairy.img_processors.control_modes import CONTROL_MODES, create_depth_map
from imaginairy.modules.midas.api import ISL_PATHS
from imaginairy.schema import LazyLoadingImage
from imaginairy.utils.img_utils import (
pillow_img_to_torch_image,
@ -18,6 +21,31 @@ def control_img_to_pillow_img(img_t):
control_mode_params = list(CONTROL_MODES.items())
@pytest.mark.skip()
def test_compare_depth_maps(filename_base_for_outputs):
sizes = [384, 512, 768]
model_types = ISL_PATHS
img = LazyLoadingImage(
url="https://zhyever.github.io/patchfusion/images/interactive/case6.png"
)
for model_type, size in itertools.product(model_types.keys(), sizes):
if (
"dpt_swin" in model_type
or "next_vit" in model_type
or "levit" in model_type
):
continue
print(f"Testing {model_type} with size {size}")
img_t = pillow_img_to_torch_image(img)
depth_t = create_depth_map(img_t, model_type=model_type, max_size=size)
depth_img = control_img_to_pillow_img(depth_t)
img_path = f"{filename_base_for_outputs}_{model_type}_{size}.png"
depth_img.save(img_path)
@pytest.mark.parametrize(("control_name", "control_func"), control_mode_params)
def test_control_images(filename_base_for_outputs, control_func, control_name):
seed_everything(42)
