build: vendorize realesrgan

Removes many dependencies, including the problematic `grpcio` and `tb-nightly`.
Bryce 2023-01-08 22:03:16 -08:00 committed by Bryce Drennan
parent 35ac8d64d7
commit 4bc78b9be5
6 changed files with 389 additions and 115 deletions

View File

@@ -17,6 +17,8 @@ init: require_pyenv ## Setup a dev environment for local development.
pip install --upgrade pip pip-tools
pip-sync requirements-dev.txt
pip install -e . --no-deps
# the compiled requirements don't include OS-specific subdependencies so we trigger those this way
pip install `pip freeze | grep "^torch=="`
@echo -e "\nEnvironment setup! ✨ 🍰 ✨ 🐍 \n\nCopy this path to tell PyCharm where your virtualenv is. You may have to click the refresh button in the pycharm file explorer.\n"
@echo -e "\033[0;32m"
@pyenv which python
@@ -86,7 +88,6 @@ vendor_openai_clip:
revendorize: vendorize_kdiffusion
make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip COMMIT=d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
make af
vendorize_clipseg:

View File

@@ -3,11 +3,11 @@ from functools import lru_cache
import numpy as np
import torch
from PIL import Image
from realesrgan import RealESRGANer
from imaginairy.model_manager import get_cached_url_path
from imaginairy.utils import get_device
from imaginairy.vendored.basicsr.rrdbnet_arch import RRDBNet
from imaginairy.vendored.realesrgan import RealESRGANer
@lru_cache

View File

@@ -0,0 +1,354 @@
import math
import os
import queue
import threading
import cv2
import numpy as np
import torch
from torch.nn import functional as F
from imaginairy.model_manager import get_cached_url_path
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class RealESRGANer:
"""A helper class for upsampling images with RealESRGAN.
Args:
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
model_path (str): The path to the pretrained model. It can also be a URL (the weights will be downloaded automatically).
model (nn.Module): The defined network. Default: None.
tile (int): Very large images can cause out-of-GPU-memory errors, so this option first crops the
input image into tiles and processes each one separately; the results are then merged back
into one image. 0 means tiling is disabled. Default: 0.
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
half (bool): Whether to use half precision during inference. Default: False.
"""
def __init__(
self,
scale,
model_path,
dni_weight=None,
model=None,
tile=0,
tile_pad=10,
pre_pad=10,
half=False,
device=None,
gpu_id=None,
):
self.scale = scale
self.tile_size = tile
self.tile_pad = tile_pad
self.pre_pad = pre_pad
self.mod_scale = None
self.half = half
# initialize model
if gpu_id:
self.device = (
torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
if device is None
else device
)
else:
self.device = (
torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device is None
else device
)
if isinstance(model_path, list):
# dni
assert len(model_path) == len(
dni_weight
), "model_path and dni_weight should have the save length."
loadnet = self.dni(model_path[0], model_path[1], dni_weight)
else:
# if model_path is a URL, download the weights to (or reuse them from) the local cache
if model_path.startswith("https://"):
model_path = get_cached_url_path(model_path)
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
# prefer to use params_ema
if "params_ema" in loadnet:
keyname = "params_ema"
else:
keyname = "params"
model.load_state_dict(loadnet[keyname], strict=True)
model.eval()
self.model = model.to(self.device)
if self.half:
self.model = self.model.half()
def dni(self, net_a, net_b, dni_weight, key="params", loc="cpu"):
"""Deep network interpolation.
``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
"""
net_a = torch.load(net_a, map_location=torch.device(loc))
net_b = torch.load(net_b, map_location=torch.device(loc))
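# linearly blend each parameter tensor: dni_weight[0] * net_a + dni_weight[1] * net_b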
for k, v_a in net_a[key].items():
net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
return net_a
def pre_process(self, img):
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible."""
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
self.img = img.unsqueeze(0).to(self.device)
if self.half:
self.img = self.img.half()
# pre_pad
if self.pre_pad != 0:
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
# mod pad for divisible borders
if self.scale == 2:
self.mod_scale = 2
elif self.scale == 1:
self.mod_scale = 4
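# for other scales (e.g. 4), mod_scale stays None and no extra mod padding is applied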
if self.mod_scale is not None:
self.mod_pad_h, self.mod_pad_w = 0, 0
_, _, h, w = self.img.size()
if h % self.mod_scale != 0:
self.mod_pad_h = self.mod_scale - h % self.mod_scale
if w % self.mod_scale != 0:
self.mod_pad_w = self.mod_scale - w % self.mod_scale
self.img = F.pad(
self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect"
)
def process(self):
# model inference
self.output = self.model(self.img)
def tile_process(self):
"""It will first crop input images to tiles, and then process each tile.
Finally, all the processed tiles are merged into one images.
Modified from: https://github.com/ata4/esrgan-launcher
"""
batch, channel, height, width = self.img.shape
output_height = height * self.scale
output_width = width * self.scale
output_shape = (batch, channel, output_height, output_width)
# start with black image
self.output = self.img.new_zeros(output_shape)
tiles_x = math.ceil(width / self.tile_size)
tiles_y = math.ceil(height / self.tile_size)
# loop over all tiles
for y in range(tiles_y):
for x in range(tiles_x):
# extract tile from input image
ofs_x = x * self.tile_size
ofs_y = y * self.tile_size
# input tile area on total image
input_start_x = ofs_x
input_end_x = min(ofs_x + self.tile_size, width)
input_start_y = ofs_y
input_end_y = min(ofs_y + self.tile_size, height)
# input tile area on total image with padding
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
input_end_x_pad = min(input_end_x + self.tile_pad, width)
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
input_end_y_pad = min(input_end_y + self.tile_pad, height)
# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
tile_idx = y * tiles_x + x + 1
input_tile = self.img[
:,
:,
input_start_y_pad:input_end_y_pad,
input_start_x_pad:input_end_x_pad,
]
# upscale tile
try:
with torch.no_grad():
output_tile = self.model(input_tile)
except RuntimeError as error:
print("Error", error)
print(f"\tTile {tile_idx}/{tiles_x * tiles_y}")
# output tile area on total image
output_start_x = input_start_x * self.scale
output_end_x = input_end_x * self.scale
output_start_y = input_start_y * self.scale
output_end_y = input_end_y * self.scale
# output tile area without padding
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
# put tile into output image
self.output[
:, :, output_start_y:output_end_y, output_start_x:output_end_x
] = output_tile[
:,
:,
output_start_y_tile:output_end_y_tile,
output_start_x_tile:output_end_x_tile,
]
def post_process(self):
# remove extra pad
if self.mod_scale is not None:
_, _, h, w = self.output.size()
self.output = self.output[
:,
:,
0 : h - self.mod_pad_h * self.scale,
0 : w - self.mod_pad_w * self.scale,
]
# remove prepad
if self.pre_pad != 0:
_, _, h, w = self.output.size()
self.output = self.output[
:,
:,
0 : h - self.pre_pad * self.scale,
0 : w - self.pre_pad * self.scale,
]
return self.output
@torch.no_grad()
def enhance(self, img, outscale=None, alpha_upsampler="realesrgan"):
h_input, w_input = img.shape[0:2]
# img: numpy
img = img.astype(np.float32)
if np.max(img) > 256: # 16-bit image
max_range = 65535
print("\tInput is a 16-bit image")
else:
max_range = 255
img = img / max_range
if len(img.shape) == 2: # gray image
img_mode = "L"
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
elif img.shape[2] == 4: # RGBA image with alpha channel
img_mode = "RGBA"
alpha = img[:, :, 3]
img = img[:, :, 0:3]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if alpha_upsampler == "realesrgan":
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
else:
img_mode = "RGB"
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# ------------------- process image (without the alpha channel) ------------------- #
self.pre_process(img)
if self.tile_size > 0:
self.tile_process()
else:
self.process()
output_img = self.post_process()
output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
if img_mode == "L":
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
# ------------------- process the alpha channel if necessary ------------------- #
if img_mode == "RGBA":
if alpha_upsampler == "realesrgan":
self.pre_process(alpha)
if self.tile_size > 0:
self.tile_process()
else:
self.process()
output_alpha = self.post_process()
output_alpha = (
output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
)
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
else: # use the cv2 resize for alpha channel
h, w = alpha.shape[0:2]
output_alpha = cv2.resize(
alpha,
(w * self.scale, h * self.scale),
interpolation=cv2.INTER_LINEAR,
)
# merge the alpha channel
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
output_img[:, :, 3] = output_alpha
# ------------------------------ return ------------------------------ #
if max_range == 65535: # 16-bit image
output = (output_img * 65535.0).round().astype(np.uint16)
else:
output = (output_img * 255.0).round().astype(np.uint8)
if outscale is not None and outscale != float(self.scale):
output = cv2.resize(
output,
(
int(w_input * outscale),
int(h_input * outscale),
),
interpolation=cv2.INTER_LANCZOS4,
)
return output, img_mode
class PrefetchReader(threading.Thread):
"""Prefetch images.
Args:
img_list (list[str]): A list of image paths to be read.
num_prefetch_queue (int): Maximum size of the prefetch queue.
"""
def __init__(self, img_list, num_prefetch_queue):
super().__init__()
self.que = queue.Queue(num_prefetch_queue)
self.img_list = img_list
def run(self):
for img_path in self.img_list:
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
self.que.put(img)
self.que.put(None)
def __next__(self):
next_item = self.que.get()
if next_item is None:
raise StopIteration
return next_item
def __iter__(self):
return self
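# Illustrative usage (a sketch; the paths and queue size below are hypothetical):
#   reader = PrefetchReader(["a.png", "b.png"], num_prefetch_queue=4)
#   reader.start()
#   for img in reader:  # yields decoded images until the None sentinel is reached
#       ...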
class IOConsumer(threading.Thread):
def __init__(self, opt, que, qid):
super().__init__()
self._queue = que
self.qid = qid
self.opt = opt
def run(self):
while True:
msg = self._queue.get()
if isinstance(msg, str) and msg == "quit":
break
output = msg["output"]
save_path = msg["save_path"]
cv2.imwrite(save_path, output)
print(f"IO worker {self.qid} is done.")

View File

@@ -2,9 +2,11 @@ black
coverage
isort
ruff
pycln
pylama
pylint
pytest
pytest-randomly
pytest-sugar
responses
wheel

View File

@@ -4,10 +4,6 @@
#
# pip-compile --output-file=requirements-dev.txt --resolver=backtracking requirements-dev.in setup.py
#
absl-py==1.3.0
# via tb-nightly
addict==2.4.0
# via basicsr
aiohttp==3.8.3
# via fsspec
aiosignal==1.3.1
@@ -22,14 +18,8 @@ attrs==22.2.0
# via
# aiohttp
# pytest
basicsr==1.4.2
# via
# gfpgan
# realesrgan
black==22.12.0
# via -r requirements-dev.in
cachetools==5.2.1
# via google-auth
certifi==2022.12.7
# via requests
charset-normalizer==2.1.1
@@ -41,6 +31,7 @@ click==8.1.3
# black
# click-shell
# imaginAIry (setup.py)
# typer
click-shell==2.1
# via imaginAIry (setup.py)
contourpy==1.0.6
@@ -58,9 +49,7 @@ einops==0.6.0
exceptiongroup==1.1.0
# via pytest
facexlib==0.2.5
# via
# gfpgan
# realesrgan
# via imaginAIry (setup.py)
fairscale==0.4.13
# via imaginAIry (setup.py)
filelock==3.9.0
@@ -82,18 +71,6 @@ ftfy==6.1.1
# via
# imaginAIry (setup.py)
# open-clip-torch
future==0.18.2
# via basicsr
gfpgan==1.3.8
# via realesrgan
google-auth==2.15.0
# via
# google-auth-oauthlib
# tb-nightly
google-auth-oauthlib==0.4.6
# via tb-nightly
grpcio==1.51.1
# via tb-nightly
huggingface-hub==0.11.1
# via
# diffusers
@@ -105,9 +82,7 @@ idna==3.4
# requests
# yarl
imageio==2.24.0
# via
# imaginAIry (setup.py)
# scikit-image
# via imaginAIry (setup.py)
importlib-metadata==6.0.0
# via diffusers
iniconfig==2.0.0
@@ -122,18 +97,12 @@ kornia==0.6.9
# via imaginAIry (setup.py)
lazy-object-proxy==1.9.0
# via astroid
libcst==0.4.9
# via pycln
lightning-utilities==0.5.0
# via pytorch-lightning
llvmlite==0.39.1
# via numba
lmdb==1.4.0
# via
# basicsr
# gfpgan
markdown==3.4.1
# via tb-nightly
markupsafe==2.1.1
# via werkzeug
matplotlib==3.6.2
# via filterpy
mccabe==0.7.0
@@ -145,48 +114,37 @@ multidict==6.0.4
# aiohttp
# yarl
mypy-extensions==0.4.3
# via black
networkx==3.0
# via scikit-image
# via
# black
# typing-inspect
numba==0.56.4
# via facexlib
numpy==1.23.5
# via
# basicsr
# contourpy
# diffusers
# facexlib
# fairscale
# filterpy
# gfpgan
# imageio
# imaginAIry (setup.py)
# matplotlib
# numba
# opencv-python
# pytorch-lightning
# pywavelets
# realesrgan
# scikit-image
# scipy
# tb-nightly
# tensorboardx
# tifffile
# torchmetrics
# torchvision
# transformers
oauthlib==3.2.2
# via requests-oauthlib
omegaconf==2.3.0
# via imaginAIry (setup.py)
open-clip-torch==2.9.2
# via imaginAIry (setup.py)
opencv-python==4.7.0.68
# via
# basicsr
# facexlib
# gfpgan
# realesrgan
# imaginAIry (setup.py)
packaging==23.0
# via
# huggingface-hub
@@ -196,21 +154,19 @@ packaging==23.0
# pytest
# pytest-sugar
# pytorch-lightning
# scikit-image
# torchmetrics
# transformers
pathspec==0.10.3
# via black
pathspec==0.9.0
# via
# black
# pycln
pillow==9.4.0
# via
# basicsr
# diffusers
# facexlib
# imageio
# imaginAIry (setup.py)
# matplotlib
# realesrgan
# scikit-image
# torchvision
platformdirs==2.6.2
# via
@@ -222,16 +178,11 @@ protobuf==3.20.1
# via
# imaginAIry (setup.py)
# open-clip-torch
# tb-nightly
# tensorboardx
psutil==5.9.4
# via imaginAIry (setup.py)
pyasn1==0.4.8
# via
# pyasn1-modules
# rsa
pyasn1-modules==0.2.8
# via google-auth
pycln==2.1.2
# via -r requirements-dev.in
pycodestyle==2.10.0
# via pylama
pydocstyle==6.2.3
@@ -257,19 +208,15 @@ python-dateutil==2.8.2
# via matplotlib
pytorch-lightning==1.8.6
# via imaginAIry (setup.py)
pywavelets==1.4.1
# via scikit-image
pyyaml==6.0
# via
# basicsr
# gfpgan
# huggingface-hub
# libcst
# omegaconf
# pycln
# pytorch-lightning
# timm
# transformers
realesrgan==0.3.0
# via imaginAIry (setup.py)
regex==2022.10.31
# via
# diffusers
@@ -277,58 +224,34 @@ regex==2022.10.31
# transformers
requests==2.28.1
# via
# basicsr
# diffusers
# fsspec
# huggingface-hub
# imaginAIry (setup.py)
# requests-oauthlib
# responses
# tb-nightly
# torchvision
# transformers
requests-oauthlib==1.3.1
# via google-auth-oauthlib
responses==0.22.0
# via -r requirements-dev.in
rsa==4.9
# via google-auth
ruff==0.0.215
# via -r requirements-dev.in
safetensors==0.2.7
# via imaginAIry (setup.py)
scikit-image==0.19.3
# via basicsr
scipy==1.10.0
# via
# basicsr
# facexlib
# filterpy
# gfpgan
# scikit-image
# torchdiffeq
sentencepiece==0.1.97
# via open-clip-torch
six==1.16.0
# via
# google-auth
# python-dateutil
# via python-dateutil
snowballstemmer==2.2.0
# via pydocstyle
tb-nightly==2.12.0a20230107
# via
# basicsr
# gfpgan
tensorboard-data-server==0.6.1
# via tb-nightly
tensorboard-plugin-wit==1.8.1
# via tb-nightly
tensorboardx==2.5.1
# via pytorch-lightning
termcolor==2.2.0
# via pytest-sugar
tifffile==2022.10.10
# via scikit-image
timm==0.6.12
# via imaginAIry (setup.py)
tokenizers==0.13.2
@@ -341,18 +264,17 @@ tomli==2.0.1
# pylint
# pytest
tomlkit==0.11.6
# via pylint
# via
# pycln
# pylint
torch==1.13.1
# via
# basicsr
# facexlib
# fairscale
# gfpgan
# imaginAIry (setup.py)
# kornia
# open-clip-torch
# pytorch-lightning
# realesrgan
# timm
# torchdiffeq
# torchmetrics
@@ -365,52 +287,46 @@ torchmetrics==0.11.0
# pytorch-lightning
torchvision==0.14.1
# via
# basicsr
# facexlib
# gfpgan
# imaginAIry (setup.py)
# open-clip-torch
# realesrgan
# timm
tqdm==4.64.1
# via
# basicsr
# facexlib
# gfpgan
# huggingface-hub
# imaginAIry (setup.py)
# open-clip-torch
# pytorch-lightning
# realesrgan
# transformers
transformers==4.25.1
# via imaginAIry (setup.py)
typer==0.7.0
# via pycln
types-toml==0.10.8.1
# via responses
typing-extensions==4.4.0
# via
# astroid
# huggingface-hub
# libcst
# lightning-utilities
# pytorch-lightning
# torch
# torchvision
# typing-inspect
typing-inspect==0.8.0
# via libcst
urllib3==1.26.13
# via
# requests
# responses
wcwidth==0.2.5
# via ftfy
werkzeug==2.2.2
# via tb-nightly
wheel==0.38.4
# via tb-nightly
# via -r requirements-dev.in
wrapt==1.14.1
# via astroid
yapf==0.32.0
# via
# basicsr
# gfpgan
yarl==1.8.2
# via aiohttp
zipp==3.11.0

View File

@@ -37,6 +37,7 @@ setup(
"click",
"click-shell",
"protobuf != 3.20.2, != 3.19.5",
"facexlib",
"fairscale>=0.4.4", # for vendored blip
"ftfy", # for vendored clip
"torch>=1.2.0",
@@ -49,6 +50,7 @@ setup(
"pytorch-lightning>=1.4.2",
"omegaconf>=2.1.1",
"open-clip-torch",
"opencv-python",
"requests",
"einops>=0.3.0",
"safetensors",
@@ -58,6 +60,5 @@ setup(
"torchmetrics>=0.6.0",
"torchvision>=0.13.1",
"kornia>=0.6",
"realesrgan",
],
)