feature: better upscaling

- use face enhancement in a smarter way that doesn't blur high-res images
- use a different upscale model for composition images

**Upscaling**
RealESRGAN is great but it blurs parts of images it doesn't understand

4xUltrasharp is a finetune of RealESRGan that isn't as good but doesn't have this blurry patch problem.  This makes it more suitable to use as part of the composition/upscale process.  We still use realesrgan for any last-step upscales since it does look better.

had to write a state dict translator to use the ultrasharp model

**Face Enhancement**

We no longer enhance faces that are larger than 512 pixels. They should already have enough details and the face enhancer doesn't produce faces at high enough resolution to look good at that size.
pull/427/head
Bryce 5 months ago committed by Bryce Drennan
parent 6ebd12abb1
commit 32b5175e0e

@ -79,7 +79,18 @@ def enhance_faces(img, fidelity=0):
face_helper.align_warp_face()
# face restoration for each cropped face
for cropped_face in face_helper.cropped_faces:
for face_box, cropped_face in zip(face_helper.det_faces, face_helper.cropped_faces):
x1, y1, x2, y2, scaling = face_box
face_width = x2 - x1
face_height = y2 - y1
logger.debug(f"Face detected. size: {face_width:1f}x{face_height:.1f}")
if face_width > 512 or face_height > 512:
logger.debug(
f"Face too large: ({face_width:.1f}x{face_height:.1f}). skipping enhancement"
)
face_helper.add_restored_face(cropped_face)
continue
# prepare data
cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)

@ -12,16 +12,25 @@ from imaginairy.vendored.realesrgan import RealESRGANer
@memory_managed_model("realesrgan_upsampler", memory_usage_mb=70)
def realesrgan_upsampler():
def realesrgan_upsampler(tile=1024, tile_pad=50, ultrasharp=False):
model = RRDBNet(
num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
)
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
if ultrasharp:
url = "https://huggingface.co/lokCX/4x-Ultrasharp/resolve/1856559b50de25116a7c07261177dd128f1f5664/4x-UltraSharp.pth"
else:
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
model_path = get_cached_url_path(url)
device = get_device()
upsampler = RealESRGANer(
scale=4, model_path=model_path, model=model, tile=512, device=device
scale=4,
model_path=model_path,
model=model,
tile=tile,
device=device,
tile_pad=tile_pad,
)
upsampler.device = torch.device(device)
@ -30,9 +39,11 @@ def realesrgan_upsampler():
return upsampler
def upscale_image(img):
def upscale_image(img, ultrasharp=False):
img = img.convert("RGB")
np_img = np.array(img, dtype=np.uint8)
upsampler_output, img_mode = realesrgan_upsampler().enhance(np_img[:, :, ::-1])
upsampler_output, img_mode = realesrgan_upsampler(ultrasharp=ultrasharp).enhance(
np_img[:, :, ::-1]
)
return Image.fromarray(upsampler_output[:, :, ::-1], mode=img_mode)

@ -290,7 +290,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
size: tuple[int, int] = Field(validate_default=True)
upscale: bool = False
fix_faces: bool = False
fix_faces_fidelity: float | None = Field(0.2, ge=0, le=1, validate_default=True)
fix_faces_fidelity: float | None = Field(0.5, ge=0, le=1, validate_default=True)
conditioning: str | None = None
tile_mode: str = ""
allow_compose_phase: bool = True
@ -531,7 +531,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
@field_validator("fix_faces_fidelity", mode="before")
def validate_fix_faces_fidelity(cls, v):
if v is None:
return 0.2
return 0.5
return v

@ -77,9 +77,13 @@ class RealESRGANer:
# prefer to use params_ema
if "params_ema" in loadnet:
keyname = "params_ema"
else:
loadnet = loadnet[keyname]
elif "params" in loadnet:
keyname = "params"
model.load_state_dict(loadnet[keyname], strict=True)
loadnet = loadnet[keyname]
else:
loadnet = convert_realesrgan_state_dict(loadnet)
model.load_state_dict(loadnet, strict=True)
model.eval()
self.model = model.to(self.device)
@ -347,3 +351,32 @@ class IOConsumer(threading.Thread):
save_path = msg["save_path"]
cv2.imwrite(save_path, output)
print(f"IO worker {self.qid} is done.")
def convert_realesrgan_state_dict(state_dict):
new_state_dict = {}
new_state_dict["conv_first.weight"] = state_dict.pop("model.0.weight")
new_state_dict["conv_first.bias"] = state_dict.pop("model.0.bias")
# "model.1.sub.21.RDB3.conv5.0.weight => body.21.rdb1.conv3.weight"
for k, v in list(state_dict.items()):
parts = k.split(".")
if len(parts) == 8 and parts[0] == "model":
new_parts = ["body", parts[3], parts[4].lower(), parts[5], parts[7]]
new_k = ".".join(new_parts)
new_state_dict[new_k] = state_dict.pop(k)
new_state_dict["conv_body.weight"] = state_dict.pop("model.1.sub.23.weight")
new_state_dict["conv_body.bias"] = state_dict.pop("model.1.sub.23.bias")
new_state_dict["conv_up1.weight"] = state_dict.pop("model.3.weight")
new_state_dict["conv_up1.bias"] = state_dict.pop("model.3.bias")
new_state_dict["conv_up2.weight"] = state_dict.pop("model.6.weight")
new_state_dict["conv_up2.bias"] = state_dict.pop("model.6.bias")
new_state_dict["conv_hr.weight"] = state_dict.pop("model.8.weight")
new_state_dict["conv_hr.bias"] = state_dict.pop("model.8.bias")
new_state_dict["conv_last.weight"] = state_dict.pop("model.10.weight")
new_state_dict["conv_last.bias"] = state_dict.pop("model.10.bias")
return new_state_dict

Binary file not shown.

After

Width:  |  Height:  |  Size: 61 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 826 KiB

@ -0,0 +1,21 @@
import pytest
from lightning_fabric import seed_everything
from PIL import Image
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.utils import get_device
from tests import TESTS_FOLDER
from tests.utils import assert_image_similar_to_expectation
@pytest.mark.skipif(
get_device() == "cpu", reason="TypeError: Got unsupported ScalarType BFloat16"
)
def test_fix_faces(filename_base_for_orig_outputs, filename_base_for_outputs):
distorted_img = Image.open(f"{TESTS_FOLDER}/data/distorted_face.png")
seed_everything(1)
img = enhance_faces(distorted_img)
distorted_img.save(f"{filename_base_for_orig_outputs}__orig.jpg")
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(img, img_path=img_path, threshold=2800)

@ -0,0 +1,13 @@
from PIL import Image
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from tests import TESTS_FOLDER
from tests.utils import assert_image_similar_to_expectation
def test_upscale_textured_image(filename_base_for_outputs):
img = Image.open(f"{TESTS_FOLDER}/data/sand_upscale_difficult.jpg")
upscaled_image = upscale_image(img, ultrasharp=True)
assert_image_similar_to_expectation(
upscaled_image, f"{filename_base_for_outputs}.jpg", threshold=25000
)
Loading…
Cancel
Save