diff --git a/Makefile b/Makefile
index ea76f0b..3413a78 100644
--- a/Makefile
+++ b/Makefile
@@ -171,11 +171,12 @@ vendorize_controlnet_annotators:
 
-vendorize_surface_normal_uncertainty:
-	make download_repo REPO=git@github.com:baegwangbin/surface_normal_uncertainty.git PKG=surface_normal_uncertainty COMMIT=fe2b9f1e8a4cac1c73475b023f6454bd23827a48
-	mkdir -p ./imaginairy/vendored/surface_normal_uncertainty
-	rm -rf ./imaginairy/vendored/surface_normal_uncertainty/*
-	cp -R ./downloads/surface_normal_uncertainty/* ./imaginairy/vendored/surface_normal_uncertainty/
+vendorize_normal_map:
+	make download_repo REPO=git@github.com:brycedrennan/imaginairy-normal-map.git PKG=imaginairy_normal_map COMMIT=6b3b1692cbdc21d55c84a01e0b7875df030b6d79
+	mkdir -p ./imaginairy/vendored/imaginairy_normal_map
+	rm -rf ./imaginairy/vendored/imaginairy_normal_map/*
+	cp -R ./downloads/imaginairy_normal_map/imaginairy_normal_map/* ./imaginairy/vendored/imaginairy_normal_map/
+	make af
 
diff --git a/imaginairy/img_processors/control_modes.py b/imaginairy/img_processors/control_modes.py
index 1bd4251..18b647c 100644
--- a/imaginairy/img_processors/control_modes.py
+++ b/imaginairy/img_processors/control_modes.py
@@ -85,7 +85,10 @@ def _create_depth_map_raw(img):
 
 def create_normal_map(img):
     import torch
-    from imaginairy_normal_map.model import create_normal_map_torch_img
+
+    from imaginairy.vendored.imaginairy_normal_map.model import (
+        create_normal_map_torch_img,
+    )
 
     normal_img_t = create_normal_map_torch_img(img)
     normal_img_t -= torch.min(normal_img_t)
diff --git a/imaginairy/vendored/imaginairy_normal_map/__init__.py b/imaginairy/vendored/imaginairy_normal_map/__init__.py
new file mode 100644
index 0000000..e69de29
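The control_modes.py hunk only swaps the import source; the call site is unchanged. For context, a minimal sketch of the normalization that create_normal_map() performs around the vendored entry point (the hunk's context ends at the min-subtraction, so the max-division below is an assumed completion; `img` here is a (1, 3, H, W) float tensor):

```python
import torch

from imaginairy.vendored.imaginairy_normal_map.model import (
    create_normal_map_torch_img,
)


def normal_map_control_image(img):
    # raw prediction is a unit-normal field, roughly in [-1, 1]
    normal_img_t = create_normal_map_torch_img(img)
    # min/max rescale into [0, 1] so it can serve as a control image
    # (the division is an assumed completion of the truncated hunk)
    normal_img_t -= torch.min(normal_img_t)
    normal_img_t /= torch.max(normal_img_t)
    return normal_img_t
```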
diff --git a/imaginairy/vendored/imaginairy_normal_map/decoder.py b/imaginairy/vendored/imaginairy_normal_map/decoder.py
new file mode 100644
index 0000000..b0a7501
--- /dev/null
+++ b/imaginairy/vendored/imaginairy_normal_map/decoder.py
@@ -0,0 +1,149 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .submodules import UpSampleBN, norm_normalize
+
+
+class Decoder(nn.Module):
+    def __init__(self, architecture="BN", sampling_ratio=0.4, importance_ratio=0.7):
+        super().__init__()
+
+        # hyper-parameters for sampling
+        self.sampling_ratio = sampling_ratio
+        self.importance_ratio = importance_ratio
+
+        # feature-map
+        self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
+        if architecture == "BN":
+            self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
+            self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
+            self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
+            self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
+        else:
+            raise RuntimeError("invalid architecture")
+
+        # produces 1/8 res output
+        self.out_conv_res8 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
+
+        # produces 1/4 res output
+        self.out_conv_res4 = nn.Sequential(
+            nn.Conv1d(512 + 4, 128, kernel_size=1),
+            nn.ReLU(),
+            nn.Conv1d(128, 128, kernel_size=1),
+            nn.ReLU(),
+            nn.Conv1d(128, 128, kernel_size=1),
+            nn.ReLU(),
+            nn.Conv1d(128, 4, kernel_size=1),
+        )
+
+        # produces 1/2 res output
+        self.out_conv_res2 = nn.Sequential(
+            nn.Conv1d(256 + 4, 128, kernel_size=1),
+            nn.ReLU(),
+            nn.Conv1d(128, 128, kernel_size=1),
+            nn.ReLU(),
+            nn.Conv1d(128, 128, kernel_size=1),
+            nn.ReLU(),
+            nn.Conv1d(128, 4, kernel_size=1),
+        )
+
+        # produces 1/1 res output
+        self.out_conv_res1 = nn.Sequential(
+            nn.Conv1d(128 + 4, 128, kernel_size=1),
+            nn.ReLU(),
+            nn.Conv1d(128, 128, kernel_size=1),
+            nn.ReLU(),
+            nn.Conv1d(128, 128, kernel_size=1),
+            nn.ReLU(),
+            nn.Conv1d(128, 4, kernel_size=1),
+        )
+
+    def forward(self, features, gt_norm_mask=None, mode="test"):
+        x_block0, x_block1, x_block2, x_block3, x_block4 = (
+            features[4],
+            features[5],
+            features[6],
+            features[8],
+            features[11],
+        )
+
+        # generate feature-map
+
+        x_d0 = self.conv2(x_block4)  # x_d0 : [2, 2048, 15, 20]      1/32 res
+        x_d1 = self.up1(x_d0, x_block3)  # x_d1 : [2, 1024, 30, 40]  1/16 res
+        x_d2 = self.up2(x_d1, x_block2)  # x_d2 : [2, 512, 60, 80]   1/8 res
+        x_d3 = self.up3(x_d2, x_block1)  # x_d3 : [2, 256, 120, 160] 1/4 res
+        x_d4 = self.up4(x_d3, x_block0)  # x_d4 : [2, 128, 240, 320] 1/2 res
+
+        # 1/8 res output
+        out_res8 = self.out_conv_res8(
+            x_d2
+        )  # out_res8: [2, 4, 60, 80]  1/8 res output
+        out_res8 = norm_normalize(
+            out_res8
+        )  # out_res8: [2, 4, 60, 80]  1/8 res output
+
+        ################################################################################################################
+        # out_res4
+        ################################################################################################################
+
+        # grid_sample feature-map
+        feat_map = F.interpolate(
+            x_d2, scale_factor=2, mode="bilinear", align_corners=True
+        )
+        init_pred = F.interpolate(
+            out_res8, scale_factor=2, mode="bilinear", align_corners=True
+        )
+        feat_map = torch.cat([feat_map, init_pred], dim=1)  # (B, 512+4, H, W)
+        B, _, H, W = feat_map.shape
+
+        # try all pixels
+        out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1))  # (B, 4, N)
+        out_res4 = norm_normalize(out_res4)  # (B, 4, N) - normalized
+        out_res4 = out_res4.view(B, 4, H, W)
+        samples_pred_res4 = point_coords_res4 = None
+
+        ################################################################################################################
+        # out_res2
+        ################################################################################################################
+
+        # grid_sample feature-map
+        feat_map = F.interpolate(
+            x_d3, scale_factor=2, mode="bilinear", align_corners=True
+        )
+        init_pred = F.interpolate(
+            out_res4, scale_factor=2, mode="bilinear", align_corners=True
+        )
+        feat_map = torch.cat([feat_map, init_pred], dim=1)  # (B, 256+4, H, W)
+        B, _, H, W = feat_map.shape
+
+        out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1))  # (B, 4, N)
+        out_res2 = norm_normalize(out_res2)  # (B, 4, N) - normalized
+        out_res2 = out_res2.view(B, 4, H, W)
+        samples_pred_res2 = point_coords_res2 = None
+
+        ################################################################################################################
+        # out_res1
+        ################################################################################################################
+
+        # grid_sample feature-map
+        feat_map = F.interpolate(
+            x_d4, scale_factor=2, mode="bilinear", align_corners=True
+        )
+        init_pred = F.interpolate(
+            out_res2, scale_factor=2, mode="bilinear", align_corners=True
+        )
+        feat_map = torch.cat([feat_map, init_pred], dim=1)  # (B, 128+4, H, W)
+        B, _, H, W = feat_map.shape
+
+        out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1))  # (B, 4, N)
+        out_res1 = norm_normalize(out_res1)  # (B, 4, N) - normalized
+        out_res1 = out_res1.view(B, 4, H, W)
+        samples_pred_res1 = point_coords_res1 = None
+
+        return (
+            [out_res8, out_res4, out_res2, out_res1],
+            [out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1],
+            [None, point_coords_res4, point_coords_res2, point_coords_res1],
+        )
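The decoder consumes five entries of the encoder's feature list (indices 4, 5, 6, 8, 11) and emits a normal-plus-kappa map at four scales. A sanity-check sketch, assuming timm can fetch the tf_efficientnet_b5_ap weights and an input whose sides are multiples of 32:

```python
import torch

from imaginairy.vendored.imaginairy_normal_map.decoder import Decoder
from imaginairy.vendored.imaginairy_normal_map.encoder import Encoder

encoder, decoder = Encoder().eval(), Decoder().eval()
with torch.no_grad():
    features = encoder(torch.rand(1, 3, 480, 640))
    norm_out_list, _, _ = decoder(features)

# four (1, 4, h, w) tensors at 1/8, 1/4, 1/2, and full resolution;
# channels 0-2 are the unit normal, channel 3 the kappa confidence
for out in norm_out_list:
    print(tuple(out.shape))
```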
diff --git a/imaginairy/vendored/imaginairy_normal_map/encoder.py b/imaginairy/vendored/imaginairy_normal_map/encoder.py
new file mode 100644
index 0000000..3b1c9ad
--- /dev/null
+++ b/imaginairy/vendored/imaginairy_normal_map/encoder.py
@@ -0,0 +1,31 @@
+import timm
+from torch import nn
+
+
+class Encoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+        basemodel_name = "tf_efficientnet_b5_ap"
+        basemodel = timm.create_model(
+            basemodel_name, pretrained=True, num_classes=0, global_pool=""
+        )
+        basemodel.eval()
+
+        self.original_model = basemodel
+
+    def forward(self, x):
+        features = [x]
+        for k, v in self.original_model._modules.items():  # noqa
+            if k == "blocks":
+                for ki, vi in v._modules.items():  # noqa
+                    features.append(vi(features[-1]))
+            else:
+                features.append(v(features[-1]))
+        # The decoder expects the 16 feature maps emitted by the NNET encoder from
+        # rwightman/gen-efficientnet-pytorch. This timm model emits 14, and the decoder
+        # never reads the two extra slots, so None placeholders keep the indices aligned.
+        if len(features) == 14:
+            features.insert(2, None)
+            features.insert(12, None)
+
+        return features
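The None-padding exists purely to keep the decoder's hard-coded feature indices valid against this timm checkpoint. A sketch that makes the alignment visible (shapes assume the same 480x640 dummy input and downloaded timm weights):

```python
import torch

from imaginairy.vendored.imaginairy_normal_map.encoder import Encoder

encoder = Encoder().eval()
with torch.no_grad():
    features = encoder(torch.rand(1, 3, 480, 640))

print(len(features))  # 16 once the two None placeholders are inserted
for i in (4, 5, 6, 8, 11):  # the only indices the decoder reads
    print(i, tuple(features[i].shape))
```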
diff --git a/imaginairy/vendored/imaginairy_normal_map/model.py b/imaginairy/vendored/imaginairy_normal_map/model.py
new file mode 100644
index 0000000..50a7496
--- /dev/null
+++ b/imaginairy/vendored/imaginairy_normal_map/model.py
@@ -0,0 +1,102 @@
+from functools import lru_cache
+
+import numpy as np
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from torch import nn
+from torchvision import transforms
+
+from .decoder import Decoder
+from .encoder import Encoder
+from .utils import get_device
+
+
+class NNET(nn.Module):
+    def __init__(self, architecture="BN", sampling_ratio=0.4, importance_ratio=0.7):
+        super().__init__()
+        self.encoder = Encoder()
+        self.decoder = Decoder(
+            architecture=architecture,
+            sampling_ratio=sampling_ratio,
+            importance_ratio=importance_ratio,
+        )
+
+    def forward(self, img, **kwargs):
+        return self.decoder(self.encoder(img), **kwargs)
+
+
+def create_normal_map_pil_img(img, device=get_device()):
+    img_t = pillow_img_to_torch_normal_map_input(img).to(device)
+    pred_norm = create_normal_map_torch_img(img_t, device=device)
+    return torch_normal_map_to_pillow_img(pred_norm)
+
+
+def create_normal_map_torch_img(img_t, device=get_device()):
+    with torch.no_grad():
+        model = load_model(device=device)
+        img_t = img_t.to(device)
+        norm_out_list, _, _ = model(img_t)  # noqa
+        norm_out = norm_out_list[-1]
+
+        pred_norm_t = norm_out[:, :3, :, :]
+        return pred_norm_t
+
+
+normalize_img = transforms.Normalize(
+    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+)
+
+
+def pillow_img_to_torch_normal_map_input(img):
+    img = np.array(img).astype(np.float32) / 255.0
+    img = torch.from_numpy(img).permute(2, 0, 1)
+    img = normalize_img(img)
+    img = img.unsqueeze(0)
+
+    # Resize image to nearest multiple of 8 using interpolate()
+    h, w = img.size(2), img.size(3)
+    h_new = int(round(h / 8) * 8)
+    w_new = int(round(w / 8) * 8)
+    img = torch.nn.functional.interpolate(
+        img, size=(h_new, w_new), mode="bilinear", align_corners=False
+    )
+
+    return img
+
+
+def torch_normal_map_to_pillow_img(norm_map_t):
+    norm_map_np = norm_map_t.detach().cpu().permute(0, 2, 3, 1).numpy()  # (B, H, W, 3)
+    pred_norm_rgb = ((norm_map_np + 1) * 0.5) * 255
+    pred_norm_rgb = np.clip(pred_norm_rgb, a_min=0, a_max=255)
+    pred_norm_rgb = pred_norm_rgb.astype(np.uint8)  # (B, H, W, 3)
+    return Image.fromarray(pred_norm_rgb[0])
+
+
+def load_checkpoint(fpath, model):
+    ckpt = torch.load(fpath, map_location="cpu")["model"]
+
+    load_dict = {}
+    for k, v in ckpt.items():
+        load_dict[k] = v
+
+    model.load_state_dict(load_dict)
+    return model
+
+
+@lru_cache(maxsize=1)
+def load_model(device=None, sampling_ratio=0.4, importance_ratio=0.7) -> NNET:
+    device = device or get_device()
+    weights_path = hf_hub_download(
+        repo_id="imaginairy/imaginairy-normal-uncertainty-map", filename="scannet.pt"
+    )
+    architecture = "BN"
+
+    model = NNET(
+        architecture=architecture,
+        sampling_ratio=sampling_ratio,
+        importance_ratio=importance_ratio,
+    ).to(device)
+    model = load_checkpoint(weights_path, model)
+    model.eval()
+    return model
diff --git a/imaginairy/vendored/imaginairy_normal_map/submodules.py b/imaginairy/vendored/imaginairy_normal_map/submodules.py
new file mode 100644
index 0000000..521c005
--- /dev/null
+++ b/imaginairy/vendored/imaginairy_normal_map/submodules.py
@@ -0,0 +1,39 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+# Upsample + BatchNorm
+class UpSampleBN(nn.Module):
+    def __init__(self, skip_input, output_features):
+        super().__init__()
+
+        self._net = nn.Sequential(
+            nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(output_features),
+            nn.LeakyReLU(),
+            nn.Conv2d(
+                output_features, output_features, kernel_size=3, stride=1, padding=1
+            ),
+            nn.BatchNorm2d(output_features),
+            nn.LeakyReLU(),
+        )
+
+    def forward(self, x, concat_with):
+        up_x = F.interpolate(
+            x,
+            size=[concat_with.size(2), concat_with.size(3)],
+            mode="bilinear",
+            align_corners=True,
+        )
+        f = torch.cat([up_x, concat_with], dim=1)
+        return self._net(f)
+
+
+def norm_normalize(norm_out):
+    min_kappa = 0.01
+    norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
+    norm = torch.sqrt(norm_x**2.0 + norm_y**2.0 + norm_z**2.0) + 1e-10
+    kappa = F.elu(kappa) + 1.0 + min_kappa
+    final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
+    return final_out
diff --git a/imaginairy/vendored/imaginairy_normal_map/utils.py b/imaginairy/vendored/imaginairy_normal_map/utils.py
new file mode 100644
index 0000000..6471fe0
--- /dev/null
+++ b/imaginairy/vendored/imaginairy_normal_map/utils.py
@@ -0,0 +1,15 @@
+from functools import lru_cache
+
+import torch
+
+
+@lru_cache()
+def get_device() -> str:
+    """Return the best torch backend available."""
+    if torch.cuda.is_available():
+        return "cuda"
+
+    if torch.backends.mps.is_available():
+        return "mps:0"
+
+    return "cpu"
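model.py is the public surface of the vendored package: load_model() pulls scannet.pt from the Hugging Face Hub and caches the result, and the PIL helpers wrap the tensor pipeline end to end. A usage sketch (the file names are hypothetical; the first call downloads the weights):

```python
from PIL import Image

from imaginairy.vendored.imaginairy_normal_map.model import (
    create_normal_map_pil_img,
)

img = Image.open("photo.jpg").convert("RGB")  # hypothetical input file
normal_img = create_normal_map_pil_img(img)  # RGB-encoded surface normals
normal_img.save("photo_normal.png")
```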
diff --git a/requirements-dev.txt b/requirements-dev.txt
index b4e0dbf..2ddf55d 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -77,7 +77,6 @@ ftfy==6.1.1
 huggingface-hub==0.14.1
     # via
     #   diffusers
-    #   imaginairy-normal-map
     #   open-clip-torch
     #   timm
     #   transformers
@@ -87,8 +86,6 @@ idna==3.4
     # via
     #   yarl
 imageio==2.28.1
     # via imaginAIry (setup.py)
-imaginairy-normal-map==0.0.3
-    # via imaginAIry (setup.py)
 importlib-metadata==6.6.0
     # via diffusers
 iniconfig==2.0.0
@@ -196,7 +193,7 @@ pyflakes==3.0.1
     # via pylama
 pylama==8.4.1
     # via -r requirements-dev.in
-pylint==2.17.3
+pylint==2.17.4
     # via -r requirements-dev.in
 pyparsing==3.0.9
     # via matplotlib
@@ -239,7 +236,7 @@ requests==2.30.0
     #   transformers
 responses==0.23.1
     # via -r requirements-dev.in
-ruff==0.0.264
+ruff==0.0.265
     # via -r requirements-dev.in
 safetensors==0.3.1
     # via imaginAIry (setup.py)
@@ -259,7 +256,6 @@ termcolor==2.3.0
 timm==0.6.13
     # via
    #   imaginAIry (setup.py)
-    #   imaginairy-normal-map
     #   open-clip-torch
 tokenizers==0.13.3
     # via transformers
@@ -277,7 +273,6 @@ torch==1.13.1
     #   facexlib
     #   fairscale
     #   imaginAIry (setup.py)
-    #   imaginairy-normal-map
     #   kornia
     #   open-clip-torch
     #   pytorch-lightning
@@ -295,7 +290,6 @@ torchvision==0.14.1
     # via
     #   facexlib
     #   imaginAIry (setup.py)
-    #   imaginairy-normal-map
     #   open-clip-torch
     #   timm
 tqdm==4.65.0
diff --git a/setup.py b/setup.py
index e388fac..c4fb8ad 100644
--- a/setup.py
+++ b/setup.py
@@ -75,7 +75,6 @@ setup(
         "tqdm",
         "diffusers",
         "imageio>=2.9.0",
-        "imaginairy-normal-map",
         "Pillow>=8.0.0",
         "psutil",
         # 2.0.0 need to fix `ImportError: cannot import name 'rank_zero_only' from 'pytorch_lightning.utilities.distributed' `
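With the imaginairy-normal-map requirement dropped from setup.py and the lockfile, the vendored copy is the only import path. A quick smoke check, assuming `make vendorize_normal_map` has been run:

```python
from imaginairy.vendored.imaginairy_normal_map.model import load_model
from imaginairy.vendored.imaginairy_normal_map.utils import get_device

print(get_device())  # "cuda", "mps:0", or "cpu"
model = load_model()  # cached; downloads scannet.pt on first use
```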