fix: use model appropriate composition size

This commit is contained in:
Bryce 2023-02-28 23:02:04 -08:00 committed by Bryce Drennan
parent 3b777b98d8
commit 52044c1073
2 changed files with 12 additions and 94 deletions

View File

@ -1,5 +1,4 @@
import logging
import math
import os
import re
@ -219,7 +218,10 @@ def _generate_single_image(
log_img,
log_latent,
)
from imaginairy.model_manager import get_diffusion_model
from imaginairy.model_manager import (
get_diffusion_model,
get_model_default_image_size,
)
from imaginairy.modules.midas.api import torch_image_to_depth_map
from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint
from imaginairy.safety import create_safety_score
@ -483,12 +485,14 @@ def _generate_single_image(
prompt=prompt,
target_height=init_image.height,
target_width=init_image.width,
cutoff=get_model_default_image_size(prompt.model),
)
else:
comp_image = _generate_composition_image(
prompt=prompt,
target_height=prompt.height,
target_width=prompt.width,
cutoff=get_model_default_image_size(prompt.model),
)
if comp_image is not None:
result_images["composition"] = comp_image
@ -626,12 +630,11 @@ def _scale_latent(
return latent
def _generate_composition_image(prompt, target_height, target_width):
def _generate_composition_image(prompt, target_height, target_width, cutoff=512):
from copy import copy
from PIL import Image
cutoff = 512
if prompt.width <= cutoff and prompt.height <= cutoff:
return None
@ -672,95 +675,6 @@ def _generate_composition_image(prompt, target_height, target_width):
return img
def _generate_composition_latentz(
sampler,
sampler_kwargs,
):
from copy import deepcopy
from torch.nn import functional as F
from imaginairy.enhancers.upscale_riverwing import upscale_latent
from imaginairy.log_utils import log_img, log_latent
b, c, h, w = sampler_kwargs["shape"]
max_compose_gen_size = 768
shrink_scale = calc_scale_to_fit_within(
height=h,
width=w,
max_size=int(math.ceil(max_compose_gen_size / 8)),
)
if shrink_scale >= 1:
return None
new_kwargs = deepcopy(sampler_kwargs)
# shrink everything
new_shape = b, c, int(round(h * shrink_scale)), int(round(w * shrink_scale))
noise = new_kwargs["noise"]
if noise is not None:
noise = F.interpolate(noise, size=new_shape[2:], mode="nearest-exact")
for cond in [
new_kwargs["positive_conditioning"],
new_kwargs["neutral_conditioning"],
]:
print(cond["c_concat"])
for c in cond["c_concat"]:
print(f"downscaling {c.shape} ")
cond["c_concat"] = [
_scale_latent(
latent=c, model=sampler.model, h=new_shape[2] * 8, w=new_shape[3] * 8
)
for c in cond["c_concat"]
]
print(cond["c_concat"])
mask_latent = new_kwargs["mask"]
if mask_latent is not None:
mask_latent = F.interpolate(mask_latent, size=new_shape[2:], mode="area")
orig_latent = new_kwargs["orig_latent"]
if orig_latent is not None:
orig_latent = _scale_latent(
latent=orig_latent,
model=sampler.model,
h=new_shape[2] * 8,
w=new_shape[3] * 8,
)
t_start = new_kwargs["t_start"]
if t_start is not None:
gen_strength = new_kwargs["t_start"] / new_kwargs["num_steps"]
t_start = int(round(15 * gen_strength))
new_kwargs.update(
{
"num_steps": 15,
"noise": noise,
"t_start": t_start,
"mask": mask_latent,
"orig_latent": orig_latent,
"shape": new_shape,
}
)
samples = sampler.sample(**new_kwargs)
# while samples.shape[2] < h:
logger.info("Upscaling latent...")
samples = upscale_latent(samples)
log_latent(samples, "upscaled")
img_t = sampler.model.decode_first_stage(samples)
img_t = F.interpolate(img_t, size=(h * 8, w * 8), mode="bicubic")
log_img(img_t, "upscaled interpolated")
samples = sampler.model.get_first_stage_encoding(
sampler.model.encode_first_stage(img_t)
)
log_latent(samples, "upscaled interpolated latent")
return samples
def prompt_normalized(prompt):
return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:130]

View File

@ -8,7 +8,11 @@ from imaginairy.cli.shared import (
)
@click.command(context_settings={"max_content_width": 140}, cls=ImagineColorsCommand, name="imagine")
@click.command(
context_settings={"max_content_width": 140},
cls=ImagineColorsCommand,
name="imagine",
)
@click.argument("prompt_texts", nargs=-1)
@add_options(common_options)
@click.option(