refactor: cleanup image generation code

This commit is contained in:
Bryce 2023-02-15 08:02:36 -08:00 committed by Bryce Drennan
parent 8a97213622
commit ea1d4baafe
6 changed files with 228 additions and 225 deletions

View File

@ -195,18 +195,20 @@ def _generate_single_image(
half_mode=None,
add_caption=False,
):
import numpy as np
import torch.nn
from einops import rearrange, repeat
from PIL import Image, ImageOps
from pytorch_lightning import seed_everything
from torch.cuda import OutOfMemoryError
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
from imaginairy.img_utils import (
pillow_fit_image_within,
pillow_img_to_torch_image,
pillow_mask_to_latent_mask,
torch_img_to_pillow_img,
)
from imaginairy.log_utils import (
ImageLoggingContext,
log_conditioning,
@ -214,7 +216,7 @@ def _generate_single_image(
log_latent,
)
from imaginairy.model_manager import get_diffusion_model
from imaginairy.modules.midas.utils import AddMiDaS
from imaginairy.modules.midas.api import torch_image_to_depth_map
from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint
from imaginairy.safety import create_safety_score
from imaginairy.samplers import SAMPLER_LOOKUP
@ -241,7 +243,6 @@ def _generate_single_image(
half_mode=half_mode,
for_inpainting=prompt.mask_image or prompt.mask_prompt or prompt.outpaint,
)
has_depth_channel = hasattr(model, "depth_stage_key")
progress_latents = []
def latent_logger(latents):
@ -279,9 +280,14 @@ def _generate_single_image(
]
SamplerCls = SAMPLER_LOOKUP[prompt.sampler_type.lower()]
sampler = SamplerCls(model)
mask = mask_image = mask_image_orig = mask_grayscale = None
mask_latent = mask_image = mask_image_orig = mask_grayscale = None
t_enc = init_latent = init_latent_noised = None
starting_image = None
denoiser_cls = None
c_cat = []
c_cat_neutral = None
result_images = {}
if prompt.init_image:
starting_image = prompt.init_image
generation_strength = 1 - prompt.init_image_strength
@ -321,25 +327,12 @@ def _generate_single_image(
if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
mask_image = ImageOps.invert(mask_image)
log_img(
Image.composite(init_image, mask_image, mask_image),
"mask overlay",
)
mask_image_orig = mask_image
mask_image = mask_image.resize(
(
mask_image.width // downsampling_factor,
mask_image.height // downsampling_factor,
),
resample=Image.Resampling.LANCZOS,
)
log_img(mask_image, "latent_mask")
mask_latent = pillow_mask_to_latent_mask(
mask_image, downsampling_factor=downsampling_factor
).to(get_device())
mask = np.array(mask_image)
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask = torch.from_numpy(mask)
mask = mask.to(get_device())
init_image_t = pillow_img_to_torch_image(init_image)
init_image_t = init_image_t.to(get_device())
init_latent = model.get_first_stage_encoding(
@ -348,7 +341,7 @@ def _generate_single_image(
shape = init_latent.shape
log_latent(init_latent, "init_latent")
# encode (scaled latent)
seed_everything(prompt.seed)
noise = randn_seeded(seed=prompt.seed, size=init_latent.size())
noise = noise.to(get_device())
@ -371,90 +364,45 @@ def _generate_single_image(
schedule=schedule,
noise=noise,
)
batch_size = 1
log_latent(init_latent_noised, "init_latent_noised")
batch = {
"txt": batch_size * [prompt.prompt_text],
}
c_cat = []
c_cat_neutral = None
depth_image_display = None
if has_depth_channel and starting_image:
midas_model = AddMiDaS()
_init_image_d = np.array(starting_image.convert("RGB"))
_init_image_d = (
torch.from_numpy(_init_image_d).to(dtype=torch.float32) / 127.5 - 1.0
)
depth_image = midas_model(_init_image_d)
depth_image = torch.from_numpy(depth_image[None, ...])
batch[model.depth_stage_key] = depth_image.to(device=get_device())
_init_image_d = rearrange(_init_image_d, "h w c -> 1 c h w")
batch["jpg"] = _init_image_d
for ck in model.concat_keys:
cc = batch[ck]
cc = model.depth_model(cc)
depth_min, depth_max = torch.amin(
cc, dim=[1, 2, 3], keepdim=True
), torch.amax(cc, dim=[1, 2, 3], keepdim=True)
display_depth = (cc - depth_min) / (depth_max - depth_min)
depth_image_display = Image.fromarray(
(display_depth[0, 0, ...].cpu().numpy() * 255.0).astype(np.uint8)
)
cc = torch.nn.functional.interpolate(
cc,
if hasattr(model, "depth_stage_key"):
# depth model
depth_t = torch_image_to_depth_map(init_image_t)
depth_latent = torch.nn.functional.interpolate(
depth_t,
size=shape[2:],
mode="bicubic",
align_corners=False,
)
depth_min, depth_max = torch.amin(
cc, dim=[1, 2, 3], keepdim=True
), torch.amax(cc, dim=[1, 2, 3], keepdim=True)
cc = 2.0 * (cc - depth_min) / (depth_max - depth_min) - 1.0
c_cat.append(cc)
c_cat = [torch.cat(c_cat, dim=1)]
result_images["depth_image"] = depth_t
c_cat.append(depth_latent)
if mask_image_orig and not has_depth_channel:
elif hasattr(model, "masked_image_key"):
# inpainting model
mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to(
get_device()
)
inverted_mask = 1 - mask
inverted_mask = 1 - mask_latent
masked_image_t = init_image_t * (mask_t < 0.5)
batch.update(
{
"image": repeat(
init_image_t.to(device=get_device()),
"1 ... -> n ...",
n=batch_size,
),
"txt": batch_size * [prompt.prompt_text],
"mask": repeat(
inverted_mask.to(device=get_device()),
"1 ... -> n ...",
n=batch_size,
),
"masked_image": repeat(
masked_image_t.to(device=get_device()),
"1 ... -> n ...",
n=batch_size,
),
}
)
log_img(masked_image_t, "masked_image")
for concat_key in getattr(model, "concat_keys", []):
cc = batch[concat_key].float()
if concat_key != model.masked_image_key:
bchw = [batch_size, 4, shape[2], shape[3]]
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
c_cat.append(cc)
if c_cat:
c_cat = [torch.cat(c_cat, dim=1)]
denoiser_cls = None
if model.cond_stage_key == "edit":
inverted_mask_latent = torch.nn.functional.interpolate(
inverted_mask, size=shape[-2:]
)
c_cat.append(inverted_mask_latent)
masked_image_latent = model.get_first_stage_encoding(
model.encode_first_stage(masked_image_t)
)
c_cat.append(masked_image_latent)
elif model.cond_stage_key == "edit":
# pix2pix model
c_cat = [model.encode_first_stage(init_image_t).mode()]
c_cat_neutral = [torch.zeros_like(init_latent)]
denoiser_cls = CFGEditingDenoiser
if c_cat:
c_cat = [torch.cat(c_cat, dim=1)]
if c_cat_neutral is None:
c_cat_neutral = c_cat
@ -467,6 +415,8 @@ def _generate_single_image(
"c_concat": c_cat_neutral,
"c_crossattn": [neutral_conditioning],
}
log_latent(init_latent_noised, "init_latent_noised")
with lc.timing("sampling"):
samples = sampler.sample(
num_steps=prompt.steps,
@ -475,51 +425,34 @@ def _generate_single_image(
neutral_conditioning=neutral_conditioning,
guidance_scale=prompt.prompt_strength,
t_start=t_enc,
mask=mask,
mask=mask_latent,
orig_latent=init_latent,
shape=shape,
batch_size=1,
denoiser_cls=denoiser_cls,
)
# from torch.nn.functional import interpolate
# samples = interpolate(samples, scale_factor=2, mode='nearest')
with lc.timing("decoding"):
try:
x_samples = model.decode_first_stage(samples)
except OutOfMemoryError:
model.cond_stage_model.to("cpu")
model.model.to("cpu")
x_samples = model.decode_first_stage(samples)
model.cond_stage_model.to(get_device())
model.model.to(get_device())
gen_imgs_t = model.decode_first_stage(samples)
gen_img = torch_img_to_pillow_img(gen_imgs_t)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples:
x_sample = x_sample.to(torch.float32)
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
x_sample_8_orig = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample_8_orig)
if mask_image_orig and init_image:
# mask_final = mask_image_orig.filter(
# ImageFilter.GaussianBlur(radius=3)
# )
mask_final = mask_image_orig.copy()
log_img(mask_final, "reconstituting mask")
mask_final = ImageOps.invert(mask_final)
img = Image.composite(img, init_image, mask_final)
log_img(img, "reconstituted image")
gen_img = Image.composite(gen_img, init_image, mask_final)
log_img(gen_img, "reconstituted image")
upscaled_img = None
rebuilt_orig_img = None
if add_caption:
caption = generate_caption(img)
caption = generate_caption(gen_img)
logger.info(f"Generated caption: {caption}")
with lc.timing("safety-filter"):
safety_score = create_safety_score(
img,
gen_img,
safety_mode=IMAGINAIRY_SAFETY_MODE,
)
if safety_score.is_filtered:
@ -528,38 +461,23 @@ def _generate_single_image(
if prompt.fix_faces:
logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
with lc.timing("face enhancement"):
img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity)
gen_img = enhance_faces(gen_img, fidelity=prompt.fix_faces_fidelity)
if prompt.upscale:
logger.info("Upscaling 🖼 using real-ESRGAN...")
with lc.timing("upscaling"):
upscaled_img = upscale_image(img)
upscaled_img = upscale_image(gen_img)
# put the newly generated patch back into the original, full size image
# put the newly generated patch back into the original, full-size image
if prompt.mask_modify_original and mask_image_orig and starting_image:
img_to_add_back_to_original = upscaled_img if upscaled_img else img
img_to_add_back_to_original = img_to_add_back_to_original.resize(
starting_image.size,
resample=Image.Resampling.LANCZOS,
img_to_add_back_to_original = upscaled_img if upscaled_img else gen_img
rebuilt_orig_img = combine_image(
original_img=starting_image,
generated_img=img_to_add_back_to_original,
mask_img=mask_image_orig,
)
mask_for_orig_size = mask_image_orig.resize(
starting_image.size,
resample=Image.Resampling.LANCZOS,
)
# mask_for_orig_size = mask_for_orig_size.filter(
# ImageFilter.GaussianBlur(radius=5)
# )
log_img(mask_for_orig_size, "mask for original image size")
rebuilt_orig_img = Image.composite(
starting_image,
img_to_add_back_to_original,
mask_for_orig_size,
)
log_img(rebuilt_orig_img, "reconstituted original")
result = ImagineResult(
img=img,
img=gen_img,
prompt=prompt,
upscaled_img=upscaled_img,
is_nsfw=safety_score.is_nsfw,
@ -567,10 +485,11 @@ def _generate_single_image(
modified_original=rebuilt_orig_img,
mask_binary=mask_image_orig,
mask_grayscale=mask_grayscale,
depth_image=depth_image_display,
result_images=result_images,
timings=lc.get_timings(),
progress_latents=progress_latents.copy(),
)
_most_recent_result = result
logger.info(f"Image Generated. Timings: {result.timings_str()}")
return result
@ -587,3 +506,29 @@ def _prompts_to_embeddings(prompts, model):
def prompt_normalized(prompt):
return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:130]
def combine_image(original_img, generated_img, mask_img):
"""Combine the generated image with the original image using the mask image."""
from PIL import Image
from imaginairy.log_utils import log_img
generated_img = generated_img.resize(
original_img.size,
resample=Image.Resampling.LANCZOS,
)
mask_for_orig_size = mask_img.resize(
original_img.size,
resample=Image.Resampling.LANCZOS,
)
log_img(mask_for_orig_size, "mask for original image size")
rebuilt_orig_img = Image.composite(
original_img,
generated_img,
mask_for_orig_size,
)
log_img(rebuilt_orig_img, "reconstituted original")
return rebuilt_orig_img

View File

@ -34,20 +34,28 @@ def pillow_fit_image_within(
return image
def pillow_img_to_torch_image(img: PIL.Image.Image):
img = img.convert("RGB")
def pillow_img_to_torch_image(img: PIL.Image.Image, convert="RGB"):
if convert:
img = img.convert(convert)
img = np.array(img).astype(np.float32) / 255.0
img = img[None].transpose(0, 3, 1, 2)
img = torch.from_numpy(img)
return 2.0 * img - 1.0
def torch_img_to_pillow_img(img: torch.Tensor):
img = rearrange(img, "b c h w -> b h w c")
img = torch.clamp((img + 1.0) / 2.0, min=0.0, max=1.0)
img = (255.0 * img).cpu().numpy().astype(np.uint8)
img = Image.fromarray(img[0])
return img
def pillow_mask_to_latent_mask(mask_img: PIL.Image.Image, downsampling_factor):
mask_img = mask_img.resize(
(
mask_img.width // downsampling_factor,
mask_img.height // downsampling_factor,
),
resample=Image.Resampling.LANCZOS,
)
mask = np.array(mask_img).astype(np.float32) / 255.0
mask = mask[None, None]
mask = torch.from_numpy(mask)
return mask
def pillow_img_to_opencv_img(img: PIL.Image.Image):
@ -57,22 +65,42 @@ def pillow_img_to_opencv_img(img: PIL.Image.Image):
return open_cv_image
def model_latents_to_pillow_imgs(latents: torch.Tensor) -> Sequence[PIL.Image.Image]:
def torch_img_to_pillow_img(img_t: torch.Tensor):
if len(img_t.shape) == 3:
img_t = img_t.unsqueeze(0)
if img_t.shape[0] != 1:
raise ValueError("Only batch size 1 supported")
if img_t.shape[1] == 1:
colorspace = "L"
elif img_t.shape[1] == 3:
colorspace = "RGB"
else:
raise ValueError("Unsupported colorspace")
img_t = rearrange(img_t, "b c h w -> b h w c")
img_t = torch.clamp((img_t + 1.0) / 2.0, min=0.0, max=1.0)
img_np = (255.0 * img_t).cpu().numpy().astype(np.uint8)[0]
if colorspace == "L":
img_np = img_np[:, :, 0]
return Image.fromarray(img_np, colorspace)
def model_latent_to_pillow_img(latent: torch.Tensor) -> PIL.Image.Image:
from imaginairy.model_manager import get_current_diffusion_model # noqa
if len(latent.shape) == 3:
latent = latent.unsqueeze(0)
if latent.shape[0] != 1:
raise ValueError("Only batch size 1 supported")
model = get_current_diffusion_model()
latents = model.decode_first_stage(latents)
latents = torch.clamp((latents + 1.0) / 2.0, min=0.0, max=1.0)
imgs = []
for latent in latents:
latent = 255.0 * rearrange(latent.cpu().numpy(), "c h w -> h w c")
img = Image.fromarray(latent.astype(np.uint8))
imgs.append(img)
return imgs
img_t = model.decode_first_stage(latent)
return torch_img_to_pillow_img(img_t)
def model_latents_to_pillow_imgs(latents: torch.Tensor) -> Sequence[PIL.Image.Image]:
return [model_latent_to_pillow_img(latent) for latent in latents]
def pillow_img_to_model_latent(model, img, batch_size=1, half=True):
# init_image = pil_img_to_torch(img, half=half).to(device)
init_image = pillow_img_to_torch_image(img).to(get_device())
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
if half:

View File

@ -1,7 +1,9 @@
# based on https://github.com/isl-org/MiDaS
from functools import lru_cache
import cv2
import torch
from einops import rearrange
from torch import nn
from torchvision.transforms import Compose
@ -13,6 +15,7 @@ from imaginairy.modules.midas.midas.transforms import (
PrepareForNet,
Resize,
)
from imaginairy.utils import get_device
ISL_PATHS = {
"dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
@ -151,6 +154,30 @@ def load_model(model_type):
return model.eval(), transform
@lru_cache()
def load_midas(model_type="dpt_hybrid"):
model = MiDaSInference(model_type)
model.to(get_device())
return model
def torch_image_to_depth_map(image_t: torch.Tensor, model_type="dpt_hybrid"):
model = load_midas(model_type)
transform = load_midas_transform(model_type)
image_t = rearrange(image_t, "b c h w -> b h w c")[0]
image_np = ((image_t + 1.0) * 0.5).detach().cpu().numpy()
image_np = transform({"image": image_np})["image"]
image_t = torch.from_numpy(image_np[None, ...])
image_t = image_t.to(device=get_device())
depth_t = model(image_t)
depth_min = torch.amin(depth_t, dim=[1, 2, 3], keepdim=True)
depth_max = torch.amax(depth_t, dim=[1, 2, 3], keepdim=True)
depth_t = (depth_t - depth_min) / (depth_max - depth_min)
return depth_t
class MiDaSInference(nn.Module):
MODEL_TYPES_TORCH_HUB = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"]
MODEL_TYPES_ISL = [

View File

@ -278,10 +278,13 @@ class ImagineResult:
modified_original=None,
mask_binary=None,
mask_grayscale=None,
depth_image=None,
result_images=None,
timings=None,
progress_latents=None,
):
import torch
from imaginairy.img_utils import torch_img_to_pillow_img
from imaginairy.utils import get_device, get_hardware_description
self.prompt = prompt
@ -300,8 +303,10 @@ class ImagineResult:
if mask_grayscale:
self.images["mask_grayscale"] = mask_grayscale
if depth_image is not None:
self.images["depth_image"] = depth_image
for img_type, r_img in result_images.items():
if isinstance(r_img, torch.Tensor):
r_img = torch_img_to_pillow_img(r_img)
self.images[img_type] = r_img
self.timings = timings
self.progress_latents = progress_latents

View File

@ -158,7 +158,7 @@ def test_img_to_img_fruit_2_gold(
"k_euler_a": 18000,
"k_dpm_adaptive": 13000,
}
threshold = threshold_lookup.get(sampler_type, 11000)
threshold = threshold_lookup.get(sampler_type, 14000)
pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}__orig.jpg")
img_path = f"{filename_base_for_outputs}.png"

View File

@ -58,9 +58,9 @@ def test_feather_tile_simple(img_ratio, tile_size, overlap_pct):
tile_size, overlap_pct, (img.size(2), img.size(3))
)
print(
f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}"
)
# print(
# f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}"
# )
rebuilt = rebuild_image(
tiles, base_img=img, tile_size=tile_size, overlap_percent=overlap_pct
@ -86,9 +86,9 @@ def test_feather_tile_brute():
tile_coords, tile_size, overlap = tile_setup(
tile_size, overlap_percent, (img.size(2), img.size(3))
)
print(
f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}"
)
# print(
# f"tile_coords={tile_coords}, tile_size={tile_size}, overlap={overlap}, img.shape={img.shape}"
# )
rebuilt = rebuild_image(
tiles, base_img=img, tile_size=tile_size, overlap_percent=overlap_percent
@ -100,13 +100,11 @@ def test_feather_tile_brute():
torch_img_to_pillow_img(rebuilt).show()
torch_img_to_pillow_img((rebuilt - img) * 20).show()
status = "🚫 FAILED"
else:
status = "✅ PASSED"
print(
f"{status}: img:{img.shape} tile_size={tile_size} overlap_percent={overlap_percent} diff={diff}"
)
pass
# print(
# f"{status}: img:{img.shape} tile_size={tile_size} overlap_percent={overlap_percent} diff={diff}"
# )
assert diff < 1
for tile_size_pct, overlap_percent, img_ratio, flip_ratio in itertools.product(
@ -123,9 +121,9 @@ def test_feather_tile_brute():
if overlap_percent >= 0.5:
continue
print(
f"img_ratio={img_ratio}, tile_size_pct={tile_size_pct}, overlap_percent={overlap_percent}, tile_size={tile_size} img.shape={img.shape}"
)
# print(
# f"img_ratio={img_ratio}, tile_size_pct={tile_size_pct}, overlap_percent={overlap_percent}, tile_size={tile_size} img.shape={img.shape}"
# )
tile_untile(img, tile_size=tile_size, overlap_percent=overlap_percent)
del img