imaginAIry/imaginairy/vendored/realesrgan.py
2024-01-14 16:50:17 -08:00

385 lines
14 KiB
Python

import logging
import math
import os
import queue
import threading
import cv2
import numpy as np
import torch
from torch.nn import functional as F
from imaginairy.utils.downloads import get_cached_url_path
# Absolute path of the parent of the directory that contains this file.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class RealESRGANer:
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is
            usually 2 or 4.
        model_path (str | list[str]): Path to the pretrained weights. A string
            starting with ``https://`` is first resolved to a local file via
            ``get_cached_url_path``. A list of two paths enables deep network
            interpolation between the two checkpoints (see ``dni_weight``).
        dni_weight (list[float] | None): Interpolation weights, one per entry
            of ``model_path``. Only used when ``model_path`` is a list.
            Default: None.
        model (nn.Module): The defined network. Required, even though the
            signature defaults it to None for backward compatibility.
        tile (int): Large images can exhaust GPU memory, so this option crops
            the input into tiles of this size, processes each of them and
            merges the results into one image. 0 disables tiling. Default: 0.
        tile_pad (int): The pad size for each tile, to remove border
            artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts.
            Default: 10.
        half (bool): Whether to use half precision during inference.
            Default: False.
        device (torch.device | None): Device to run on. When given it takes
            precedence over ``gpu_id``. Default: None.
        gpu_id (int | None): CUDA device index used when ``device`` is None
            and CUDA is available. Default: None (the default CUDA device).

    Raises:
        ValueError: If ``model`` is None.
    """

    def __init__(
        self,
        scale,
        model_path,
        dni_weight=None,
        model=None,
        tile=0,
        tile_pad=10,
        pre_pad=10,
        half=False,
        device=None,
        gpu_id=None,
    ):
        if model is None:
            # Fail early with a clear message instead of an AttributeError on
            # ``None`` after the (possibly large) checkpoint has been loaded.
            raise ValueError("model must be an nn.Module instance, got None.")

        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half
        self.device = self._resolve_device(device, gpu_id)

        # initialize model
        if isinstance(model_path, list):
            # dni
            assert len(model_path) == len(
                dni_weight
            ), "model_path and dni_weight should have the same length."
            loadnet = self.dni(model_path[0], model_path[1], dni_weight)
        else:
            # if the model_path starts with https, resolve it to a local file
            if model_path.startswith("https://"):
                model_path = get_cached_url_path(model_path)
            # NOTE(security): torch.load may unpickle arbitrary objects
            # (depending on the installed torch version's defaults); only load
            # checkpoints that come from a trusted source.
            loadnet = torch.load(model_path, map_location=torch.device("cpu"))

        # prefer to use params_ema
        if "params_ema" in loadnet:
            loadnet = loadnet["params_ema"]
        elif "params" in loadnet:
            loadnet = loadnet["params"]
        else:
            # No wrapper key: treat the checkpoint as a raw state dict whose
            # key names must be converted to the layout this model expects.
            loadnet = convert_realesrgan_state_dict(loadnet)

        model.load_state_dict(loadnet, strict=True)
        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    @staticmethod
    def _resolve_device(device, gpu_id):
        """Pick the torch device: explicit ``device`` > ``gpu_id`` > default."""
        if device is not None:
            return device
        if not torch.cuda.is_available():
            return torch.device("cpu")
        # ``is not None`` (rather than truthiness) so that gpu_id=0 is honored
        # as an explicit request for cuda:0.
        if gpu_id is not None:
            return torch.device(f"cuda:{gpu_id}")
        return torch.device("cuda")

    def dni(self, net_a, net_b, dni_weight, key="params", loc="cpu"):
        """Deep network interpolation.

        ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``

        Args:
            net_a (str): Path of the first checkpoint.
            net_b (str): Path of the second checkpoint.
            dni_weight (list[float]): Weights applied to ``net_a`` and
                ``net_b`` respectively.
            key (str): Checkpoint entry that holds the parameters.
            loc (str): ``map_location`` device name used while loading.

        Returns:
            dict: The loaded ``net_a`` checkpoint whose ``key`` entry has been
            replaced in place by the weighted sum of both networks.
        """
        net_a = torch.load(net_a, map_location=torch.device(loc))
        net_b = torch.load(net_b, map_location=torch.device(loc))
        for k, v_a in net_a[key].items():
            net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
        return net_a

    def pre_process(self, img):
        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible.

        Stores the resulting ``(1, C, H, W)`` tensor in ``self.img``.

        Args:
            img (np.ndarray): Image laid out as ``(H, W, C)``.
        """
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad (right and bottom only; removed again in post_process)
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")

        # mod pad for divisible borders
        # NOTE(review): scales 1 and 2 presumably need H/W divisible by 4 / 2
        # because of how the network handles them - confirm against the model.
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if h % self.mod_scale != 0:
                self.mod_pad_h = self.mod_scale - h % self.mod_scale
            if w % self.mod_scale != 0:
                self.mod_pad_w = self.mod_scale - w % self.mod_scale
            self.img = F.pad(
                self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect"
            )

    def process(self):
        """Run the model on the whole of ``self.img``; result goes to ``self.output``."""
        self.output = self.model(self.img)

    def tile_process(self):
        """It will first crop input images to tiles, and then process each tile.

        Finally, all the processed tiles are merged into one image, stored in
        ``self.output``.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_shape = (batch, channel, height * self.scale, width * self.scale)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)
        logger.debug(
            "Tiling with %sx%s (%s) tiles", tiles_x, tiles_y, tiles_x * tiles_y
        )

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size

                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                input_tile = self.img[
                    :,
                    :,
                    input_start_y_pad:input_end_y_pad,
                    input_start_x_pad:input_end_x_pad,
                ]

                # upscale tile
                with torch.no_grad():
                    output_tile = self.model(input_tile)

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[
                    :, :, output_start_y:output_end_y, output_start_x:output_end_x
                ] = output_tile[
                    :,
                    :,
                    output_start_y_tile:output_end_y_tile,
                    output_start_x_tile:output_end_x_tile,
                ]

    def post_process(self):
        """Crop the padding added by ``pre_process`` off ``self.output`` and return it."""
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[
                :,
                :,
                0 : h - self.mod_pad_h * self.scale,
                0 : w - self.mod_pad_w * self.scale,
            ]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[
                :,
                :,
                0 : h - self.pre_pad * self.scale,
                0 : w - self.pre_pad * self.scale,
            ]
        return self.output

    def _upscale_to_numpy(self, img):
        """Run one 3-channel ``(H, W, 3)`` array through the network.

        Returns the result clamped to [0, 1] as an ``(H', W', 3)`` float array
        with the channel order reversed relative to the input.
        """
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        upscaled = self.post_process()
        # squeeze(0) drops only the batch axis, so a spatial axis of size 1
        # can never be removed by accident.
        upscaled = upscaled.data.squeeze(0).float().cpu().clamp_(0, 1).numpy()
        return np.transpose(upscaled[[2, 1, 0], :, :], (1, 2, 0))

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler="realesrgan"):
        """Upscale an image given as a numpy array.

        Args:
            img (np.ndarray): ``(H, W)`` gray image, or ``(H, W, 3)`` /
                ``(H, W, 4)`` image whose first three channels are passed
                through ``cv2.COLOR_BGR2RGB`` (i.e. treated as BGR / BGRA).
                Values above 256 make the input count as 16-bit.
            outscale (float | None): Final scale relative to the input size.
                When it differs from the network scale the result is resized
                with Lanczos interpolation. Default: None (network scale).
            alpha_upsampler (str): ``"realesrgan"`` upsamples the alpha
                channel with the network; any other value uses a linear
                ``cv2.resize``.

        Returns:
            tuple[np.ndarray, str]: The upscaled image (``uint16`` for 16-bit
            input, otherwise ``uint8``) and the detected mode: ``"L"``,
            ``"RGB"`` or ``"RGBA"``.
        """
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            logger.debug("Input is a 16-bit image")
        else:
            max_range = 255
        img = img / max_range

        if len(img.shape) == 2:  # gray image
            img_mode = "L"
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = "RGBA"
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == "realesrgan":
                # the network expects 3 channels, so replicate alpha
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = "RGB"
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        output_img = self._upscale_to_numpy(img)
        if img_mode == "L":
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == "RGBA":
            if alpha_upsampler == "realesrgan":
                output_alpha = self._upscale_to_numpy(alpha)
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(
                    alpha,
                    (w * self.scale, h * self.scale),
                    interpolation=cv2.INTER_LINEAR,
                )
            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output,
                (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ),
                interpolation=cv2.INTER_LANCZOS4,
            )
        return output, img_mode
class PrefetchReader(threading.Thread):
    """Prefetch images on a background thread.

    Start the thread, then iterate over the instance to receive the decoded
    images in list order. Files that ``cv2.imread`` cannot read are skipped
    with a warning.

    Args:
        img_list (list[str]): A list of image paths to be read.
        num_prefetch_queue (int): Maximum number of images buffered ahead of
            the consumer (the queue's ``maxsize``).
    """

    # Unique end-of-stream marker. ``cv2.imread`` returns ``None`` for an
    # unreadable file, so ``None`` must not double as the terminator: a single
    # bad path would end iteration early and could leave this thread blocked
    # on a full queue.
    _SENTINEL = object()

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        """Read every path in ``img_list`` and enqueue the decoded images."""
        for img_path in self.img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            if img is None:
                logger.warning("Could not read image %s; skipping it.", img_path)
                continue
            self.que.put(img)
        self.que.put(self._SENTINEL)

    def __next__(self):
        """Block until the next item is available; stop at the sentinel."""
        next_item = self.que.get()
        if next_item is self._SENTINEL:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
class IOConsumer(threading.Thread):
    """Worker thread that writes finished images to disk.

    Messages are taken from ``que``. A message is either the string
    ``"quit"``, which stops the worker, or a mapping with the keys
    ``"output"`` (the image) and ``"save_path"`` (the destination file).

    Args:
        opt: Options object; stored on the instance but not used here.
        que: Queue the messages are read from.
        qid: Identifier of this worker, used in the completion message.
    """

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        """Write queued images until the ``"quit"`` message arrives."""
        take = self._queue.get
        job = take()
        while not (isinstance(job, str) and job == "quit"):
            image, destination = job["output"], job["save_path"]
            cv2.imwrite(destination, image)
            job = take()
        print(f"IO worker {self.qid} is done.")
def convert_realesrgan_state_dict(state_dict):
    """Rename the keys of a ``model.N...``-style state dict.

    The input mapping is left untouched; a new dict is returned.

    Key mapping:
        * ``model.0.*``         -> ``conv_first.*``
        * ``model.1.sub.<i>.<RDBk>.<convj>.0.<param>``
                                -> ``body.<i>.<rdbk>.<convj>.<param>``
          e.g. ``model.1.sub.21.RDB3.conv5.0.weight``
                                -> ``body.21.rdb3.conv5.weight``
        * ``model.1.sub.23.*``  -> ``conv_body.*``
        * ``model.3.*``         -> ``conv_up1.*``
        * ``model.6.*``         -> ``conv_up2.*``
        * ``model.8.*``         -> ``conv_hr.*``
        * ``model.10.*``        -> ``conv_last.*``

    Keys that match none of these patterns are dropped.

    Args:
        state_dict (dict): Mapping from old parameter names to values.

    Returns:
        dict: New mapping keyed by the converted names.

    Raises:
        KeyError: If one of the fixed ``model.N.weight`` / ``model.N.bias``
            entries listed above is missing.
    """
    params = ("weight", "bias")
    # (new prefix, old prefix) for the layers that follow the body blocks.
    tail_layers = (
        ("conv_body", "model.1.sub.23"),
        ("conv_up1", "model.3"),
        ("conv_up2", "model.6"),
        ("conv_hr", "model.8"),
        ("conv_last", "model.10"),
    )

    converted = {
        f"conv_first.{param}": state_dict[f"model.0.{param}"] for param in params
    }

    # Body blocks are the only keys with exactly eight dot-separated parts.
    for old_key, value in state_dict.items():
        parts = old_key.split(".")
        if len(parts) == 8 and parts[0] == "model":
            new_key = ".".join(
                ("body", parts[3], parts[4].lower(), parts[5], parts[7])
            )
            converted[new_key] = value

    for new_prefix, old_prefix in tail_layers:
        for param in params:
            converted[f"{new_prefix}.{param}"] = state_dict[f"{old_prefix}.{param}"]

    return converted