imaginAIry/imaginairy/enhancers/upscale_realesrgan.py
Bryce 1c986d8644 fix: use py3.7 compat lru_cache
- disable lint fixer that updates to newer syntax
2023-01-22 18:24:57 -08:00

38 lines
1.1 KiB
Python

from functools import lru_cache
import numpy as np
import torch
from PIL import Image
from imaginairy.model_manager import get_cached_url_path
from imaginairy.utils import get_device
from imaginairy.vendored.basicsr.rrdbnet_arch import RRDBNet
from imaginairy.vendored.realesrgan import RealESRGANer
@lru_cache()
def realesrgan_upsampler():
    """Build the Real-ESRGAN x4plus upsampler and return it.

    Because the function takes no arguments and is wrapped in ``lru_cache``,
    the upsampler is constructed on the first call only; every later call
    returns that same instance.

    Construction steps:
      * an ``RRDBNet`` (3 in/out channels, 64 features, 23 blocks,
        32 growth channels, scale 4) is created;
      * the ``RealESRGAN_x4plus.pth`` weights are resolved to a local path
        via ``get_cached_url_path``;
      * a ``RealESRGANer`` is built with scale 4 and a tile size of 512;
      * the upsampler and its model are moved to the device reported by
        ``get_device()``, except that any device string containing ``"mps"``
        is replaced with ``"cpu"``.

    Returns:
        The configured ``RealESRGANer`` instance.
    """
    weights_url = (
        "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/"
        "RealESRGAN_x4plus.pth"
    )

    network = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=4,
    )
    enhancer = RealESRGANer(
        scale=4,
        model_path=get_cached_url_path(weights_url),
        model=network,
        tile=512,
    )

    # NOTE(review): "mps" devices are deliberately swapped for the CPU here,
    # presumably because this model does not run correctly on MPS — confirm.
    detected_device = get_device()
    target_device = "cpu" if "mps" in detected_device else detected_device

    enhancer.device = torch.device(target_device)
    enhancer.model.to(target_device)
    return enhancer
def upscale_image(img):
    """Upscale a PIL image with the cached Real-ESRGAN upsampler.

    The input is converted to RGB and turned into a ``uint8`` NumPy array.
    The array's last (channel) axis is reversed before it is handed to
    ``enhance``, and reversed again on the result before it is wrapped back
    into a PIL image.

    Args:
        img: A PIL ``Image``; it is converted to RGB before processing.

    Returns:
        A new PIL ``Image`` built from the upsampler output, using the mode
        string returned by ``enhance``.
    """
    rgb_pixels = np.array(img.convert("RGB"), dtype=np.uint8)
    enhancer = realesrgan_upsampler()

    # Channel order is flipped on the way in and back on the way out —
    # presumably the upsampler works in BGR (OpenCV convention); confirm.
    enhanced, reported_mode = enhancer.enhance(rgb_pixels[:, :, ::-1])
    restored = enhanced[:, :, ::-1]
    return Image.fromarray(restored, mode=reported_mode)