fix: allow use of upscaler on mps device

pull/333/head
Bryce 1 year ago committed by Bryce Drennan
parent 3b066f8e29
commit fb19e34acc

@ -9,18 +9,16 @@ from imaginairy.vendored.basicsr.rrdbnet_arch import RRDBNet
from imaginairy.vendored.realesrgan import RealESRGANer
@memory_managed_model("realesrgan_upsampler")
@memory_managed_model("realesrgan_upsampler", memory_usage_mb=70)
def realesrgan_upsampler():
model = RRDBNet(
num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
)
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
model_path = get_cached_url_path(url)
upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=512)
device = get_device()
if "mps" in device:
device = "cpu"
upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=512, device=device)
upsampler.device = torch.device(device)
upsampler.model.to(device)

Loading…
Cancel
Save