fix: move normal map code inline

Fixes conda package. Fixes #317
pull/321/head
Bryce 1 year ago committed by Bryce Drennan
parent 726ffe48c9
commit d5a276584b

@ -171,11 +171,12 @@ vendorize_controlnet_annotators:
vendorize_surface_normal_uncertainty:
make download_repo REPO=git@github.com:baegwangbin/surface_normal_uncertainty.git PKG=surface_normal_uncertainty COMMIT=fe2b9f1e8a4cac1c73475b023f6454bd23827a48
mkdir -p ./imaginairy/vendored/surface_normal_uncertainty
rm -rf ./imaginairy/vendored/surface_normal_uncertainty/*
cp -R ./downloads/surface_normal_uncertainty/* ./imaginairy/vendored/surface_normal_uncertainty/
vendorize_normal_map:
make download_repo REPO=git@github.com:brycedrennan/imaginairy-normal-map.git PKG=imaginairy_normal_map COMMIT=6b3b1692cbdc21d55c84a01e0b7875df030b6d79
mkdir -p ./imaginairy/vendored/imaginairy_normal_map
rm -rf ./imaginairy/vendored/imaginairy_normal_map/*
cp -R ./downloads/imaginairy_normal_map/imaginairy_normal_map/* ./imaginairy/vendored/imaginairy_normal_map/
make af

@ -85,7 +85,10 @@ def _create_depth_map_raw(img):
def create_normal_map(img):
import torch
from imaginairy_normal_map.model import create_normal_map_torch_img
from imaginairy.vendored.imaginairy_normal_map.model import (
create_normal_map_torch_img,
)
normal_img_t = create_normal_map_torch_img(img)
normal_img_t -= torch.min(normal_img_t)

@ -0,0 +1,149 @@
import torch
import torch.nn.functional as F
from torch import nn
from .submodules import UpSampleBN, norm_normalize
class Decoder(nn.Module):
def __init__(self, architecture="BN", sampling_ratio=0.4, importance_ratio=0.7):
super().__init__()
# hyper-parameter for sampling
self.sampling_ratio = sampling_ratio
self.importance_ratio = importance_ratio
# feature-map
self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
if architecture == "BN":
self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
else:
raise RuntimeError("invalid architecture")
# produces 1/8 res output
self.out_conv_res8 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
# produces 1/4 res output
self.out_conv_res4 = nn.Sequential(
nn.Conv1d(512 + 4, 128, kernel_size=1),
nn.ReLU(),
nn.Conv1d(128, 128, kernel_size=1),
nn.ReLU(),
nn.Conv1d(128, 128, kernel_size=1),
nn.ReLU(),
nn.Conv1d(128, 4, kernel_size=1),
)
# produces 1/2 res output
self.out_conv_res2 = nn.Sequential(
nn.Conv1d(256 + 4, 128, kernel_size=1),
nn.ReLU(),
nn.Conv1d(128, 128, kernel_size=1),
nn.ReLU(),
nn.Conv1d(128, 128, kernel_size=1),
nn.ReLU(),
nn.Conv1d(128, 4, kernel_size=1),
)
# produces 1/1 res output
self.out_conv_res1 = nn.Sequential(
nn.Conv1d(128 + 4, 128, kernel_size=1),
nn.ReLU(),
nn.Conv1d(128, 128, kernel_size=1),
nn.ReLU(),
nn.Conv1d(128, 128, kernel_size=1),
nn.ReLU(),
nn.Conv1d(128, 4, kernel_size=1),
)
def forward(self, features, gt_norm_mask=None, mode="test"):
x_block0, x_block1, x_block2, x_block3, x_block4 = (
features[4],
features[5],
features[6],
features[8],
features[11],
)
# generate feature-map
x_d0 = self.conv2(x_block4) # x_d0 : [2, 2048, 15, 20] 1/32 res
x_d1 = self.up1(x_d0, x_block3) # x_d1 : [2, 1024, 30, 40] 1/16 res
x_d2 = self.up2(x_d1, x_block2) # x_d2 : [2, 512, 60, 80] 1/8 res
x_d3 = self.up3(x_d2, x_block1) # x_d3: [2, 256, 120, 160] 1/4 res
x_d4 = self.up4(x_d3, x_block0) # x_d4: [2, 128, 240, 320] 1/2 res
# 1/8 res output
out_res8 = self.out_conv_res8(
x_d2
) # out_res8: [2, 4, 60, 80] 1/8 res output
out_res8 = norm_normalize(
out_res8
) # out_res8: [2, 4, 60, 80] 1/8 res output
################################################################################################################
# out_res4
################################################################################################################
# grid_sample feature-map
feat_map = F.interpolate(
x_d2, scale_factor=2, mode="bilinear", align_corners=True
)
init_pred = F.interpolate(
out_res8, scale_factor=2, mode="bilinear", align_corners=True
)
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
B, _, H, W = feat_map.shape
# try all pixels
out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1)) # (B, 4, N)
out_res4 = norm_normalize(out_res4) # (B, 4, N) - normalized
out_res4 = out_res4.view(B, 4, H, W)
samples_pred_res4 = point_coords_res4 = None
################################################################################################################
# out_res2
################################################################################################################
# grid_sample feature-map
feat_map = F.interpolate(
x_d3, scale_factor=2, mode="bilinear", align_corners=True
)
init_pred = F.interpolate(
out_res4, scale_factor=2, mode="bilinear", align_corners=True
)
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
B, _, H, W = feat_map.shape
out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1)) # (B, 4, N)
out_res2 = norm_normalize(out_res2) # (B, 4, N) - normalized
out_res2 = out_res2.view(B, 4, H, W)
samples_pred_res2 = point_coords_res2 = None
################################################################################################################
# out_res1
################################################################################################################
# grid_sample feature-map
feat_map = F.interpolate(
x_d4, scale_factor=2, mode="bilinear", align_corners=True
)
init_pred = F.interpolate(
out_res2, scale_factor=2, mode="bilinear", align_corners=True
)
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
B, _, H, W = feat_map.shape
out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1)) # (B, 4, N)
out_res1 = norm_normalize(out_res1) # (B, 4, N) - normalized
out_res1 = out_res1.view(B, 4, H, W)
samples_pred_res1 = point_coords_res1 = None
return (
[out_res8, out_res4, out_res2, out_res1],
[out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1],
[None, point_coords_res4, point_coords_res2, point_coords_res1],
)

@ -0,0 +1,31 @@
import timm
from torch import nn
class Encoder(nn.Module):
def __init__(self):
super().__init__()
basemodel_name = "tf_efficientnet_b5_ap"
basemodel = timm.create_model(
basemodel_name, pretrained=True, num_classes=0, global_pool=""
)
basemodel.eval()
self.original_model = basemodel
def forward(self, x):
features = [x]
for k, v in self.original_model._modules.items(): # noqa
if k == "blocks":
for ki, vi in v._modules.items(): # noqa
features.append(vi(features[-1]))
else:
features.append(v(features[-1]))
# Decoder was made for handling output of NNET from rwightman/gen-efficientnet-pytorch
# that version outputs 16 features but since the decoder doesn't use the extra features
# we just placeholder None values
if len(features) == 14:
features.insert(2, None)
features.insert(12, None)
return features

@ -0,0 +1,102 @@
from functools import lru_cache
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision import transforms
from .decoder import Decoder
from .encoder import Encoder
from .utils import get_device
class NNET(nn.Module):
def __init__(self, architecture="BN", sampling_ratio=0.4, importance_ratio=0.7):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder(
architecture=architecture,
sampling_ratio=sampling_ratio,
importance_ratio=importance_ratio,
)
def forward(self, img, **kwargs):
return self.decoder(self.encoder(img), **kwargs)
def create_normal_map_pil_img(img, device=get_device()):
img_t = pillow_img_to_torch_normal_map_input(img).to(device)
pred_norm = create_normal_map_torch_img(img_t, device=device)
return torch_normal_map_to_pillow_img(pred_norm)
def create_normal_map_torch_img(img_t, device=get_device()):
with torch.no_grad():
model = load_model(device=device)
img_t = img_t.to(device)
norm_out_list, _, _ = model(img_t) # noqa
norm_out = norm_out_list[-1]
pred_norm_t = norm_out[:, :3, :, :]
return pred_norm_t
normalize_img = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
def pillow_img_to_torch_normal_map_input(img):
img = np.array(img).astype(np.float32) / 255.0
img = torch.from_numpy(img).permute(2, 0, 1)
img = normalize_img(img)
img = img.unsqueeze(0)
# Resize image to nearest multiple of 8 using interpolate()
h, w = img.size(2), img.size(3)
h_new = int(round(h / 8) * 8)
w_new = int(round(w / 8) * 8)
img = torch.nn.functional.interpolate(
img, size=(h_new, w_new), mode="bilinear", align_corners=False
)
return img
def torch_normal_map_to_pillow_img(norm_map_t):
norm_map_np = norm_map_t.detach().cpu().permute(0, 2, 3, 1).numpy() # (B, H, W, 3)
pred_norm_rgb = ((norm_map_np + 1) * 0.5) * 255
pred_norm_rgb = np.clip(pred_norm_rgb, a_min=0, a_max=255)
pred_norm_rgb = pred_norm_rgb.astype(np.uint8) # (B, H, W, 3)
return Image.fromarray(pred_norm_rgb[0])
def load_checkpoint(fpath, model):
ckpt = torch.load(fpath, map_location="cpu")["model"]
load_dict = {}
for k, v in ckpt.items():
load_dict[k] = v
model.load_state_dict(load_dict)
return model
@lru_cache(maxsize=1)
def load_model(device=None, sampling_ratio=0.4, importance_ratio=0.7) -> NNET:
device = device or get_device()
weights_path = hf_hub_download(
repo_id="imaginairy/imaginairy-normal-uncertainty-map", filename="scannet.pt"
)
architecture = "BN"
model = NNET(
architecture=architecture,
sampling_ratio=sampling_ratio,
importance_ratio=importance_ratio,
).to(device)
model = load_checkpoint(weights_path, model)
model.eval()
return model

@ -0,0 +1,39 @@
import torch
import torch.nn.functional as F
from torch import nn
# Upsample + BatchNorm
class UpSampleBN(nn.Module):
def __init__(self, skip_input, output_features):
super().__init__()
self._net = nn.Sequential(
nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(output_features),
nn.LeakyReLU(),
nn.Conv2d(
output_features, output_features, kernel_size=3, stride=1, padding=1
),
nn.BatchNorm2d(output_features),
nn.LeakyReLU(),
)
def forward(self, x, concat_with):
up_x = F.interpolate(
x,
size=[concat_with.size(2), concat_with.size(3)],
mode="bilinear",
align_corners=True,
)
f = torch.cat([up_x, concat_with], dim=1)
return self._net(f)
def norm_normalize(norm_out):
min_kappa = 0.01
norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
norm = torch.sqrt(norm_x**2.0 + norm_y**2.0 + norm_z**2.0) + 1e-10
kappa = F.elu(kappa) + 1.0 + min_kappa
final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
return final_out

@ -0,0 +1,15 @@
from functools import lru_cache
import torch
@lru_cache()
def get_device() -> str:
"""Return the best torch backend available."""
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps:0"
return "cpu"

@ -77,7 +77,6 @@ ftfy==6.1.1
huggingface-hub==0.14.1
# via
# diffusers
# imaginairy-normal-map
# open-clip-torch
# timm
# transformers
@ -87,8 +86,6 @@ idna==3.4
# yarl
imageio==2.28.1
# via imaginAIry (setup.py)
imaginairy-normal-map==0.0.3
# via imaginAIry (setup.py)
importlib-metadata==6.6.0
# via diffusers
iniconfig==2.0.0
@ -196,7 +193,7 @@ pyflakes==3.0.1
# via pylama
pylama==8.4.1
# via -r requirements-dev.in
pylint==2.17.3
pylint==2.17.4
# via -r requirements-dev.in
pyparsing==3.0.9
# via matplotlib
@ -239,7 +236,7 @@ requests==2.30.0
# transformers
responses==0.23.1
# via -r requirements-dev.in
ruff==0.0.264
ruff==0.0.265
# via -r requirements-dev.in
safetensors==0.3.1
# via imaginAIry (setup.py)
@ -259,7 +256,6 @@ termcolor==2.3.0
timm==0.6.13
# via
# imaginAIry (setup.py)
# imaginairy-normal-map
# open-clip-torch
tokenizers==0.13.3
# via transformers
@ -277,7 +273,6 @@ torch==1.13.1
# facexlib
# fairscale
# imaginAIry (setup.py)
# imaginairy-normal-map
# kornia
# open-clip-torch
# pytorch-lightning
@ -295,7 +290,6 @@ torchvision==0.14.1
# via
# facexlib
# imaginAIry (setup.py)
# imaginairy-normal-map
# open-clip-torch
# timm
tqdm==4.65.0

@ -75,7 +75,6 @@ setup(
"tqdm",
"diffusers",
"imageio>=2.9.0",
"imaginairy-normal-map",
"Pillow>=8.0.0",
"psutil",
# 2.0.0 need to fix `ImportError: cannot import name 'rank_zero_only' from 'pytorch_lightning.utilities.distributed' `

Loading…
Cancel
Save