mirror of https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
parent 726ffe48c9
commit d5a276584b
11  Makefile
@@ -171,11 +171,12 @@ vendorize_controlnet_annotators:
vendorize_surface_normal_uncertainty:
	make download_repo REPO=git@github.com:baegwangbin/surface_normal_uncertainty.git PKG=surface_normal_uncertainty COMMIT=fe2b9f1e8a4cac1c73475b023f6454bd23827a48
	mkdir -p ./imaginairy/vendored/surface_normal_uncertainty
	rm -rf ./imaginairy/vendored/surface_normal_uncertainty/*
	cp -R ./downloads/surface_normal_uncertainty/* ./imaginairy/vendored/surface_normal_uncertainty/

vendorize_normal_map:
	make download_repo REPO=git@github.com:brycedrennan/imaginairy-normal-map.git PKG=imaginairy_normal_map COMMIT=6b3b1692cbdc21d55c84a01e0b7875df030b6d79
	mkdir -p ./imaginairy/vendored/imaginairy_normal_map
	rm -rf ./imaginairy/vendored/imaginairy_normal_map/*
	cp -R ./downloads/imaginairy_normal_map/imaginairy_normal_map/* ./imaginairy/vendored/imaginairy_normal_map/
	make af
@@ -85,7 +85,10 @@ def _create_depth_map_raw(img):
 def create_normal_map(img):
     import torch

-    from imaginairy_normal_map.model import create_normal_map_torch_img
+    from imaginairy.vendored.imaginairy_normal_map.model import (
+        create_normal_map_torch_img,
+    )

     normal_img_t = create_normal_map_torch_img(img)
     normal_img_t -= torch.min(normal_img_t)
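For reference, a minimal usage sketch of the vendored module this import change points at. This is not part of the commit, and "photo.jpg" is a placeholder path; model weights are fetched from the Hugging Face Hub on first call.

from PIL import Image

from imaginairy.vendored.imaginairy_normal_map.model import create_normal_map_pil_img

# any RGB image works; the helper handles normalization and resizing
img = Image.open("photo.jpg").convert("RGB")
normal_img = create_normal_map_pil_img(img)
normal_img.save("normal_map.png")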
149  imaginairy/vendored/imaginairy_normal_map/decoder.py  Normal file
@@ -0,0 +1,149 @@
import torch
import torch.nn.functional as F
from torch import nn

from .submodules import UpSampleBN, norm_normalize


class Decoder(nn.Module):
    def __init__(self, architecture="BN", sampling_ratio=0.4, importance_ratio=0.7):
        super().__init__()

        # hyper-parameters for sampling
        self.sampling_ratio = sampling_ratio
        self.importance_ratio = importance_ratio

        # feature-map
        self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
        if architecture == "BN":
            self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
            self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
            self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
            self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
        else:
            raise RuntimeError("invalid architecture")

        # produces 1/8 res output
        self.out_conv_res8 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)

        # produces 1/4 res output
        self.out_conv_res4 = nn.Sequential(
            nn.Conv1d(512 + 4, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(128, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(128, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(128, 4, kernel_size=1),
        )

        # produces 1/2 res output
        self.out_conv_res2 = nn.Sequential(
            nn.Conv1d(256 + 4, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(128, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(128, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(128, 4, kernel_size=1),
        )

        # produces 1/1 res output
        self.out_conv_res1 = nn.Sequential(
            nn.Conv1d(128 + 4, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(128, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(128, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(128, 4, kernel_size=1),
        )

    def forward(self, features, gt_norm_mask=None, mode="test"):
        x_block0, x_block1, x_block2, x_block3, x_block4 = (
            features[4],
            features[5],
            features[6],
            features[8],
            features[11],
        )

        # generate feature-map
        x_d0 = self.conv2(x_block4)  # x_d0: [2, 2048, 15, 20]  1/32 res
        x_d1 = self.up1(x_d0, x_block3)  # x_d1: [2, 1024, 30, 40]  1/16 res
        x_d2 = self.up2(x_d1, x_block2)  # x_d2: [2, 512, 60, 80]  1/8 res
        x_d3 = self.up3(x_d2, x_block1)  # x_d3: [2, 256, 120, 160]  1/4 res
        x_d4 = self.up4(x_d3, x_block0)  # x_d4: [2, 128, 240, 320]  1/2 res

        # 1/8 res output
        out_res8 = self.out_conv_res8(x_d2)  # out_res8: [2, 4, 60, 80]  1/8 res output
        out_res8 = norm_normalize(out_res8)  # out_res8: [2, 4, 60, 80]  1/8 res output

        ########################################################################
        # out_res4
        ########################################################################

        # grid_sample feature-map
        feat_map = F.interpolate(
            x_d2, scale_factor=2, mode="bilinear", align_corners=True
        )
        init_pred = F.interpolate(
            out_res8, scale_factor=2, mode="bilinear", align_corners=True
        )
        feat_map = torch.cat([feat_map, init_pred], dim=1)  # (B, 512+4, H, W)
        B, _, H, W = feat_map.shape

        # try all pixels
        out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1))  # (B, 4, N)
        out_res4 = norm_normalize(out_res4)  # (B, 4, N) - normalized
        out_res4 = out_res4.view(B, 4, H, W)
        samples_pred_res4 = point_coords_res4 = None

        ########################################################################
        # out_res2
        ########################################################################

        # grid_sample feature-map
        feat_map = F.interpolate(
            x_d3, scale_factor=2, mode="bilinear", align_corners=True
        )
        init_pred = F.interpolate(
            out_res4, scale_factor=2, mode="bilinear", align_corners=True
        )
        feat_map = torch.cat([feat_map, init_pred], dim=1)  # (B, 256+4, H, W)
        B, _, H, W = feat_map.shape

        out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1))  # (B, 4, N)
        out_res2 = norm_normalize(out_res2)  # (B, 4, N) - normalized
        out_res2 = out_res2.view(B, 4, H, W)
        samples_pred_res2 = point_coords_res2 = None

        ########################################################################
        # out_res1
        ########################################################################

        # grid_sample feature-map
        feat_map = F.interpolate(
            x_d4, scale_factor=2, mode="bilinear", align_corners=True
        )
        init_pred = F.interpolate(
            out_res2, scale_factor=2, mode="bilinear", align_corners=True
        )
        feat_map = torch.cat([feat_map, init_pred], dim=1)  # (B, 128+4, H, W)
        B, _, H, W = feat_map.shape

        out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1))  # (B, 4, N)
        out_res1 = norm_normalize(out_res1)  # (B, 4, N) - normalized
        out_res1 = out_res1.view(B, 4, H, W)
        samples_pred_res1 = point_coords_res1 = None

        return (
            [out_res8, out_res4, out_res2, out_res1],
            [out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1],
            [None, point_coords_res4, point_coords_res2, point_coords_res1],
        )
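A shape-check sketch (assumed, not part of the commit): in the default mode="test" path the Decoder only reads feature indices 4, 5, 6, 8, and 11, so dummy EfficientNet-style maps sized for a 480x640 input are enough to see the coarse-to-fine outputs at 1/8, 1/4, 1/2, and full resolution, each with 3 normal channels plus 1 kappa channel.

import torch

from imaginairy.vendored.imaginairy_normal_map.decoder import Decoder

features = [None] * 12  # unused indices can stay None
features[4] = torch.randn(1, 24, 240, 320)  # 1/2 res skip
features[5] = torch.randn(1, 40, 120, 160)  # 1/4 res skip
features[6] = torch.randn(1, 64, 60, 80)  # 1/8 res skip
features[8] = torch.randn(1, 176, 30, 40)  # 1/16 res skip
features[11] = torch.randn(1, 2048, 15, 20)  # 1/32 res bottleneck

decoder = Decoder().eval()  # eval() so the BatchNorm layers use running stats
with torch.no_grad():
    preds, _, _ = decoder(features)
print([tuple(p.shape) for p in preds])
# [(1, 4, 60, 80), (1, 4, 120, 160), (1, 4, 240, 320), (1, 4, 480, 640)]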
31  imaginairy/vendored/imaginairy_normal_map/encoder.py  Normal file
@@ -0,0 +1,31 @@
import timm
from torch import nn


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        basemodel_name = "tf_efficientnet_b5_ap"
        basemodel = timm.create_model(
            basemodel_name, pretrained=True, num_classes=0, global_pool=""
        )
        basemodel.eval()

        self.original_model = basemodel

    def forward(self, x):
        features = [x]
        for k, v in self.original_model._modules.items():  # noqa
            if k == "blocks":
                for ki, vi in v._modules.items():  # noqa
                    features.append(vi(features[-1]))
            else:
                features.append(v(features[-1]))
        # The Decoder was written to handle the output of NNET from
        # rwightman/gen-efficientnet-pytorch. That version outputs 16 features,
        # but since the decoder doesn't use the extra features we just insert
        # placeholder None values.
        if len(features) == 14:
            features.insert(2, None)
            features.insert(12, None)

        return features
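A quick sanity sketch (assumed, not from the commit; pretrained=True means timm downloads the backbone weights on first use): after the placeholder insertion the feature list should have 16 entries, so the indices the Decoder reads line up.

import torch

from imaginairy.vendored.imaginairy_normal_map.encoder import Encoder

encoder = Encoder()
with torch.no_grad():
    features = encoder(torch.randn(1, 3, 480, 640))
print(len(features))  # expected: 16 (including the two None placeholders)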
102  imaginairy/vendored/imaginairy_normal_map/model.py  Normal file
@@ -0,0 +1,102 @@
from functools import lru_cache

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision import transforms

from .decoder import Decoder
from .encoder import Encoder
from .utils import get_device


class NNET(nn.Module):
    def __init__(self, architecture="BN", sampling_ratio=0.4, importance_ratio=0.7):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(
            architecture=architecture,
            sampling_ratio=sampling_ratio,
            importance_ratio=importance_ratio,
        )

    def forward(self, img, **kwargs):
        return self.decoder(self.encoder(img), **kwargs)


def create_normal_map_pil_img(img, device=get_device()):
    img_t = pillow_img_to_torch_normal_map_input(img).to(device)
    pred_norm = create_normal_map_torch_img(img_t, device=device)
    return torch_normal_map_to_pillow_img(pred_norm)


def create_normal_map_torch_img(img_t, device=get_device()):
    with torch.no_grad():
        model = load_model(device=device)
        img_t = img_t.to(device)
        norm_out_list, _, _ = model(img_t)  # noqa
        norm_out = norm_out_list[-1]

        pred_norm_t = norm_out[:, :3, :, :]
        return pred_norm_t


normalize_img = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)


def pillow_img_to_torch_normal_map_input(img):
    img = np.array(img).astype(np.float32) / 255.0
    img = torch.from_numpy(img).permute(2, 0, 1)
    img = normalize_img(img)
    img = img.unsqueeze(0)

    # Resize image to nearest multiple of 8 using interpolate()
    h, w = img.size(2), img.size(3)
    h_new = int(round(h / 8) * 8)
    w_new = int(round(w / 8) * 8)
    img = torch.nn.functional.interpolate(
        img, size=(h_new, w_new), mode="bilinear", align_corners=False
    )

    return img


def torch_normal_map_to_pillow_img(norm_map_t):
    norm_map_np = norm_map_t.detach().cpu().permute(0, 2, 3, 1).numpy()  # (B, H, W, 3)
    pred_norm_rgb = ((norm_map_np + 1) * 0.5) * 255
    pred_norm_rgb = np.clip(pred_norm_rgb, a_min=0, a_max=255)
    pred_norm_rgb = pred_norm_rgb.astype(np.uint8)  # (B, H, W, 3)
    return Image.fromarray(pred_norm_rgb[0])


def load_checkpoint(fpath, model):
    ckpt = torch.load(fpath, map_location="cpu")["model"]

    load_dict = {}
    for k, v in ckpt.items():
        load_dict[k] = v

    model.load_state_dict(load_dict)
    return model


@lru_cache(maxsize=1)
def load_model(device=None, sampling_ratio=0.4, importance_ratio=0.7) -> NNET:
    device = device or get_device()
    weights_path = hf_hub_download(
        repo_id="imaginairy/imaginairy-normal-uncertainty-map", filename="scannet.pt"
    )
    architecture = "BN"

    model = NNET(
        architecture=architecture,
        sampling_ratio=sampling_ratio,
        importance_ratio=importance_ratio,
    ).to(device)
    model = load_checkpoint(weights_path, model)
    model.eval()
    return model
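The tensor/PIL conversion helpers can be exercised without downloading any weights; a small sketch (not from the commit):

import torch

from imaginairy.vendored.imaginairy_normal_map.model import (
    torch_normal_map_to_pillow_img,
)

fake_normals = torch.rand(1, 3, 64, 64) * 2 - 1  # normal components in [-1, 1]
img = torch_normal_map_to_pillow_img(fake_normals)
print(img.size, img.mode)  # (64, 64) RGB

Note that the device=get_device() defaults on the create_normal_map_* functions are evaluated once at import time, which is why load_model re-resolves the device internally via "device or get_device()".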
39  imaginairy/vendored/imaginairy_normal_map/submodules.py  Normal file
@@ -0,0 +1,39 @@
import torch
import torch.nn.functional as F
from torch import nn


# Upsample + BatchNorm
class UpSampleBN(nn.Module):
    def __init__(self, skip_input, output_features):
        super().__init__()

        self._net = nn.Sequential(
            nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(output_features),
            nn.LeakyReLU(),
            nn.Conv2d(
                output_features, output_features, kernel_size=3, stride=1, padding=1
            ),
            nn.BatchNorm2d(output_features),
            nn.LeakyReLU(),
        )

    def forward(self, x, concat_with):
        up_x = F.interpolate(
            x,
            size=[concat_with.size(2), concat_with.size(3)],
            mode="bilinear",
            align_corners=True,
        )
        f = torch.cat([up_x, concat_with], dim=1)
        return self._net(f)


def norm_normalize(norm_out):
    min_kappa = 0.01
    norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
    norm = torch.sqrt(norm_x**2.0 + norm_y**2.0 + norm_z**2.0) + 1e-10
    kappa = F.elu(kappa) + 1.0 + min_kappa
    final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
    return final_out
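A small check (assumed, not part of the commit) of what norm_normalize guarantees: the first three channels come out unit-length, and the kappa (confidence) channel stays strictly positive because ELU + 1 + min_kappa is bounded below by min_kappa.

import torch

from imaginairy.vendored.imaginairy_normal_map.submodules import norm_normalize

out = norm_normalize(torch.randn(2, 4, 8, 8))
lengths = out[:, :3].pow(2).sum(dim=1).sqrt()
print(torch.allclose(lengths, torch.ones_like(lengths), atol=1e-4))  # True
print(bool((out[:, 3] > 0).all()))  # True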
15  imaginairy/vendored/imaginairy_normal_map/utils.py  Normal file
@@ -0,0 +1,15 @@
from functools import lru_cache

import torch


@lru_cache()
def get_device() -> str:
    """Return the best torch backend available."""
    if torch.cuda.is_available():
        return "cuda"

    if torch.backends.mps.is_available():
        return "mps:0"

    return "cpu"
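Usage is a one-liner; thanks to lru_cache the backend probe runs only once per process (sketch, not from the commit):

from imaginairy.vendored.imaginairy_normal_map.utils import get_device

print(get_device())  # one of "cuda", "mps:0", or "cpu"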
requirements-dev.txt
@@ -77,7 +77,6 @@ ftfy==6.1.1
 huggingface-hub==0.14.1
     # via
     #   diffusers
-    #   imaginairy-normal-map
     #   open-clip-torch
     #   timm
     #   transformers
@@ -87,8 +86,6 @@ idna==3.4
     #   yarl
 imageio==2.28.1
     # via imaginAIry (setup.py)
-imaginairy-normal-map==0.0.3
-    # via imaginAIry (setup.py)
 importlib-metadata==6.6.0
     # via diffusers
 iniconfig==2.0.0
@@ -196,7 +193,7 @@ pyflakes==3.0.1
     # via pylama
 pylama==8.4.1
     # via -r requirements-dev.in
-pylint==2.17.3
+pylint==2.17.4
     # via -r requirements-dev.in
 pyparsing==3.0.9
     # via matplotlib
@@ -239,7 +236,7 @@ requests==2.30.0
     #   transformers
 responses==0.23.1
     # via -r requirements-dev.in
-ruff==0.0.264
+ruff==0.0.265
     # via -r requirements-dev.in
 safetensors==0.3.1
     # via imaginAIry (setup.py)
@@ -259,7 +256,6 @@ termcolor==2.3.0
 timm==0.6.13
     # via
     #   imaginAIry (setup.py)
-    #   imaginairy-normal-map
     #   open-clip-torch
 tokenizers==0.13.3
     # via transformers
@@ -277,7 +273,6 @@ torch==1.13.1
     #   facexlib
     #   fairscale
     #   imaginAIry (setup.py)
-    #   imaginairy-normal-map
     #   kornia
     #   open-clip-torch
     #   pytorch-lightning
@@ -295,7 +290,6 @@ torchvision==0.14.1
     # via
     #   facexlib
     #   imaginAIry (setup.py)
-    #   imaginairy-normal-map
     #   open-clip-torch
     #   timm
 tqdm==4.65.0