2022-09-13 07:27:53 +00:00
|
|
|
from functools import lru_cache
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
|
|
from PIL import Image
|
|
|
|
from realesrgan import RealESRGANer
|
|
|
|
|
2022-10-23 21:46:45 +00:00
|
|
|
from imaginairy.model_manager import get_cached_url_path
|
|
|
|
from imaginairy.utils import get_device
|
2022-09-13 07:27:53 +00:00
|
|
|
|
|
|
|
|
2023-01-02 04:14:22 +00:00
|
|
|
@lru_cache
def realesrgan_upsampler():
    """Construct the Real-ESRGAN x4plus upsampler once and reuse it.

    An RRDBNet backbone (3 in/out channels, 64 features, 23 blocks,
    growth 32, scale 4) is paired with the RealESRGAN_x4plus weights,
    which are resolved to a local path through ``get_cached_url_path``.
    The resulting ``RealESRGANer`` (tiling disabled via ``tile=0``) is then
    moved to the device reported by ``get_device()``; an "mps" device is
    replaced by "cpu".

    Because of ``lru_cache`` the same upsampler instance is returned on
    every call after the first.

    Returns:
        RealESRGANer: the configured upsampler.
    """
    backbone = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=4,
    )
    weights_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
    enhancer = RealESRGANer(
        scale=4,
        model_path=get_cached_url_path(weights_url),
        model=backbone,
        tile=0,
    )

    target_device = get_device()
    if "mps" in target_device:
        # NOTE(review): MPS is deliberately swapped for CPU; the reason is not
        # visible in this file (presumably an op-support issue) — confirm.
        target_device = "cpu"

    # Keep the upsampler's recorded device and its model's weights in sync.
    enhancer.device = torch.device(target_device)
    enhancer.model.to(target_device)
    return enhancer
|
|
|
|
|
|
|
|
|
|
|
|
def upscale_image(img):
    """Run a PIL image through the cached Real-ESRGAN upsampler.

    The image is converted to RGB and turned into a ``uint8`` array. Its
    channel axis is reversed before calling ``enhance`` and reversed again
    on the result, so the upsampler receives BGR-ordered pixels (presumably
    the OpenCV-style layout it expects — confirm against Real-ESRGAN docs).

    Args:
        img: a PIL ``Image`` in any mode that ``Image.convert("RGB")`` accepts.

    Returns:
        A new PIL ``Image`` created from the enhanced pixels, using the mode
        string that ``enhance`` reports.
    """
    rgb_pixels = np.array(img.convert("RGB"), dtype=np.uint8)
    bgr_pixels = rgb_pixels[:, :, ::-1]
    enhanced_bgr, output_mode = realesrgan_upsampler().enhance(bgr_pixels)
    enhanced_rgb = enhanced_bgr[:, :, ::-1]
    return Image.fromarray(enhanced_rgb, mode=output_mode)
|