fix: composition images were too blurry

This commit is contained in:
Bryce 2023-02-25 00:14:12 -08:00 committed by Bryce Drennan
parent be7f8f3c2a
commit d7e494241c
4 changed files with 21 additions and 7 deletions

View File

@ -3,7 +3,6 @@ import math
import os
import re
from imaginairy.img_utils import add_caption_to_image
from imaginairy.schema import SafetyMode
logger = logging.getLogger(__name__)
@ -208,6 +207,7 @@ def _generate_single_image(
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.img_utils import (
add_caption_to_image,
pillow_fit_image_within,
pillow_img_to_torch_image,
pillow_mask_to_latent_mask,
@ -626,6 +626,7 @@ def _generate_composition_latent(
from torch.nn import functional as F
from imaginairy.enhancers.upscale_riverwing import upscale_latent
from imaginairy.log_utils import log_img, log_latent
b, c, h, w = orig_shape = sampler_kwargs["shape"]
max_compose_gen_size = 768
@ -677,7 +678,16 @@ def _generate_composition_latent(
samples = sampler.sample(**new_kwargs)
samples = upscale_latent(samples)
samples = F.interpolate(samples, size=orig_shape[2:], mode="bilinear")
log_latent(samples, "upscaled")
img_t = sampler.model.decode_first_stage(samples)
img_t = F.interpolate(img_t, size=(h * 8, w * 8), mode="bicubic")
log_img(img_t, "upscaled interpolated")
samples = sampler.model.get_first_stage_encoding(
sampler.model.encode_first_stage(img_t)
)
log_latent(samples, "upscaled interpolated latent")
return samples

View File

@ -30,7 +30,7 @@ def colorize_img(img):
negative_prompt="black and white",
width=min(img.width, 1024),
height=min(img.height, 1024),
steps=30
steps=30,
)
result = list(imagine(prompt))[0]
colorized_img = replace_color(img, result.images["generated"])

View File

@ -3,9 +3,9 @@ from functools import lru_cache
import numpy as np
import torch
import torch.nn.functional as F
from pytorch_lightning import seed_everything
from torch import nn
from imaginairy.log_utils import log_latent
from imaginairy.model_manager import hf_hub_download
from imaginairy.utils import get_device, platform_appropriate_autocast
from imaginairy.vendored import k_diffusion as K
@ -211,19 +211,19 @@ def upscale_latent(
batch_size=1,
num_samples=1,
# Amount of noise to add per step (0.0=deterministic). Used in all samplers except `k_euler`.
eta=1.0,
eta=0.1,
device=get_device(),
):
# Add noise to the latent vectors before upscaling. This theoretically can make the model work better on out-of-distribution inputs, but mostly just seems to make it match the input less, so it's turned off by default.
noise_aug_level = 0 # @param {type: 'slider', min: 0.0, max: 0.6, step:0.025}
noise_aug_type = "gaussian" # @param ["gaussian", "fake"]
noise_aug_type = "fake" # @param ["gaussian", "fake"]
# @markdown Sampler settings. `k_dpm_adaptive` uses an adaptive solver with error tolerance `tol_scale`, all other use a fixed number of steps.
sampler = "k_dpm_2_ancestral" # @param ["k_euler", "k_euler_ancestral", "k_dpm_2_ancestral", "k_dpm_fast", "k_dpm_adaptive"]
tol_scale = 0.25 # @param {type: 'number'}
seed_everything(seed)
# seed_everything(seed)
# uc = condition_up(batch_size * ["blurry, low resolution, 720p, grainy"])
uc = condition_up(batch_size * [""])
@ -302,4 +302,5 @@ def upscale_latent(
extra_args = {"low_res": latent_noised, "low_res_sigma": low_res_sigma, "c": c}
noise = torch.randn(x_shape, device=device)
up_latents = do_sample(noise, extra_args)
log_latent(low_res_latent, "low_res_latent")
return up_latents

View File

@ -6,6 +6,8 @@ from torch import nn
from torchdiffeq import odeint
from tqdm.auto import tqdm, trange
from imaginairy.log_utils import log_latent
def append_zero(x):
    """Return ``x`` with a single zero element appended.

    The extra element is built with ``x.new_zeros([1])`` and joined to ``x``
    with ``torch.cat`` using its default dimension.

    NOTE(review): the appended piece is one-dimensional, so this presumably
    expects ``x`` to be a 1-D tensor -- confirm against callers.
    """
    return torch.cat([x, x.new_zeros([1])])
@ -945,4 +947,5 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
old_denoised = denoised
log_latent(x, "K_dpmpp_2m_x")
return x