mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
213 lines
7.2 KiB
Python
213 lines
7.2 KiB
Python
"""Classes and functions for edge detection in images"""
|
|
|
|
from functools import lru_cache
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
|
|
from imaginairy.utils import get_device
|
|
from imaginairy.utils.model_manager import get_cached_url_path
|
|
|
|
|
|
class Network(torch.nn.Module):
    """HED-style edge detector built on a VGG-16-shaped convolutional backbone.

    Five convolutional stages are applied in sequence. Each stage output is
    reduced to a single-channel score map by a 1x1 convolution, resized to the
    input resolution, and the five maps are fused by a final 1x1 convolution
    followed by a sigmoid.

    Attribute names and the layer order inside every ``Sequential`` determine
    the ``state_dict`` keys, so they must stay as they are for pretrained
    checkpoints to keep loading.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(n_in, n_out):
            # 3x3 convolution that keeps the spatial size (stride 1, padding 1).
            return torch.nn.Conv2d(
                in_channels=n_in,
                out_channels=n_out,
                kernel_size=3,
                stride=1,
                padding=1,
            )

        def vgg_stage(n_in, n_out, n_convs, downsample):
            # Optional 2x2 max-pool, then ``n_convs`` conv+ReLU pairs.
            layers = []
            if downsample:
                layers.append(torch.nn.MaxPool2d(kernel_size=2, stride=2))
            for idx in range(n_convs):
                layers.append(conv3x3(n_in if idx == 0 else n_out, n_out))
                layers.append(torch.nn.ReLU(inplace=False))
            return torch.nn.Sequential(*layers)

        def score_head(n_in):
            # 1x1 convolution collapsing a feature map to one score channel.
            return torch.nn.Conv2d(
                in_channels=n_in,
                out_channels=1,
                kernel_size=1,
                stride=1,
                padding=0,
            )

        # Backbone stages; every stage after the first halves the resolution.
        self.netVggOne = vgg_stage(3, 64, 2, downsample=False)
        self.netVggTwo = vgg_stage(64, 128, 2, downsample=True)
        self.netVggThr = vgg_stage(128, 256, 3, downsample=True)
        self.netVggFou = vgg_stage(256, 512, 3, downsample=True)
        self.netVggFiv = vgg_stage(512, 512, 3, downsample=True)

        # One side-output head per backbone stage.
        self.netScoreOne = score_head(64)
        self.netScoreTwo = score_head(128)
        self.netScoreThr = score_head(256)
        self.netScoreFou = score_head(512)
        self.netScoreFiv = score_head(512)

        # Fuses the five resized side outputs into a single map in [0, 1].
        self.netCombine = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=5,
                out_channels=1,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            torch.nn.Sigmoid(),
        )

    def forward(self, img_t):
        """Return the fused edge map for a batch of 3-channel images.

        ``img_t`` is indexed as (batch, channel, height, width). The input is
        rescaled with ``(img_t + 1) * 127.5`` -- presumably mapping values from
        [-1, 1] to [0, 255]; confirm against callers -- and fixed per-channel
        offsets are subtracted before the backbone runs. The result has one
        channel and the same height and width as the input.
        """
        scaled = (img_t + 1) * 127.5
        channel_offsets = torch.tensor(
            data=[104.00698793, 116.66876762, 122.67891434],
            dtype=scaled.dtype,
            device=scaled.device,
        ).view(1, 3, 1, 1)
        features = scaled - channel_offsets
        out_size = (features.shape[2], features.shape[3])

        branches = (
            (self.netVggOne, self.netScoreOne),
            (self.netVggTwo, self.netScoreTwo),
            (self.netVggThr, self.netScoreThr),
            (self.netVggFou, self.netScoreFou),
            (self.netVggFiv, self.netScoreFiv),
        )

        side_outputs = []
        for stage, score in branches:
            # Each stage consumes the previous stage's features.
            features = stage(features)
            side_outputs.append(
                torch.nn.functional.interpolate(
                    input=score(features),
                    size=out_size,
                    mode="bilinear",
                    align_corners=False,
                )
            )

        return self.netCombine(torch.cat(side_outputs, 1))
|
|
|
|
|
|
@lru_cache(maxsize=1)
def hed_model():
    """Build the edge-detection network with its pretrained weights loaded.

    The result is memoised by ``lru_cache``, so every caller receives the same
    module instance; the checkpoint is only read on the first call.
    """
    net = Network()
    net = net.to(get_device())
    net = net.eval()

    weights_path = get_cached_url_path(
        "https://huggingface.co/lllyasviel/ControlNet/resolve/38a62cbf79862c1bac73405ec8dc46133aee3e36/annotator/ckpts/network-bsds500.pth"
    )
    # NOTE(review): torch.load unpickles the file; the URL is pinned to a
    # specific revision hash, but the checkpoint is still external data.
    checkpoint = torch.load(weights_path, map_location="cpu")

    # Checkpoint keys use "module" where this class uses "net" in its
    # attribute names, so rewrite them before loading.
    renamed = {}
    for key, tensor in checkpoint.items():
        renamed[key.replace("module", "net")] = tensor

    net.load_state_dict(renamed)
    return net
|
|
|
|
|
|
def create_hed_map(img_t):
    """Run the cached HED model on ``img_t`` and return one edge map.

    Only the first item of the batch is used, and its leading (channel) axis
    is dropped from the returned tensor. Inference runs without gradient
    tracking.
    """
    if torch.cuda.is_available():
        target = "cuda"
    else:
        target = "cpu"

    # NOTE(review): ``Module.to`` moves the cached model in place, and this
    # cuda/cpu choice is independent of the ``get_device()`` call used when
    # the model is built -- confirm the difference is intentional.
    net = hed_model().to(target)

    with torch.no_grad():
        prediction = net(img_t.to(target))

    return prediction[0][0]
|
|
|
|
|
|
def nms(x, t, s):
    """Thin an edge-strength map into a binary uint8 mask of line pixels.

    ``x`` is blurred with a Gaussian of sigma ``s``. A pixel is kept when it
    equals the maximum of its three-pixel neighbourhood along at least one of
    four directions (horizontal, vertical, and the two diagonals). Kept pixels
    whose blurred value exceeds ``t`` are set to 255; all others are 0.
    """
    # Kernel size (0, 0) lets OpenCV derive the window from sigma.
    blurred = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    line_kernels = (
        np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8),  # horizontal
        np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8),  # vertical
        np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8),  # diagonal
        np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8),  # anti-diagonal
    )

    suppressed = np.zeros_like(blurred)
    for kernel in line_kernels:
        # Dilation yields the directional maximum; equality marks the pixels
        # that are themselves that maximum.
        is_peak = cv2.dilate(blurred, kernel=kernel) == blurred
        suppressed[is_peak] = blurred[is_peak]

    above_threshold = suppressed > t
    return above_threshold.astype(np.uint8) * np.uint8(255)
|