diff --git a/imaginairy/api/generate.py b/imaginairy/api/generate.py index eab401b..2f260a6 100755 --- a/imaginairy/api/generate.py +++ b/imaginairy/api/generate.py @@ -2,13 +2,9 @@ import logging import os -import re -from typing import TYPE_CHECKING, Any, Callable +from typing import Callable -from imaginairy.utils.named_resolutions import normalize_image_size - -if TYPE_CHECKING: - from imaginairy.schema import ImaginePrompt, LazyLoadingImage +from imaginairy.utils import prompt_normalized logger = logging.getLogger(__name__) @@ -160,7 +156,7 @@ def imagine( ): import torch.nn - from imaginairy.api.generate_refiners import _generate_single_image + from imaginairy.api.generate_refiners import generate_single_image from imaginairy.schema import ImaginePrompt from imaginairy.utils import ( check_torch_version, @@ -199,7 +195,7 @@ def imagine( for attempt in range(unsafe_retry_count + 1): if attempt > 0 and isinstance(prompt.seed, int): prompt.seed += 100_000_000 + attempt - result = _generate_single_image( + result = generate_single_image( prompt, debug_img_callback=debug_img_callback, progress_img_callback=progress_img_callback, @@ -215,596 +211,3 @@ def imagine( logger.info(" Image was unsafe, retrying with new seed...") yield result - - -def _generate_single_image_compvis( - prompt: "ImaginePrompt", - debug_img_callback=None, - progress_img_callback=None, - progress_img_interval_steps=3, - progress_img_interval_min_s=0.1, - half_mode=None, - add_caption=False, - # controlnet, finetune, naive, auto - inpaint_method="finetune", - return_latent=False, -): - import torch.nn - from PIL import Image, ImageOps - from pytorch_lightning import seed_everything - - from imaginairy.enhancers.clip_masking import get_img_mask - from imaginairy.enhancers.describe_image_blip import generate_caption - from imaginairy.enhancers.face_restoration_codeformer import enhance_faces - from imaginairy.enhancers.upscale_realesrgan import upscale_image - from imaginairy.modules.midas.api import torch_image_to_depth_map - from imaginairy.samplers import SOLVER_LOOKUP - from imaginairy.samplers.editing import CFGEditingDenoiser - from imaginairy.schema import ControlInput, ImagineResult, MaskMode - from imaginairy.utils import get_device, randn_seeded - from imaginairy.utils.img_utils import ( - add_caption_to_image, - pillow_fit_image_within, - pillow_img_to_torch_image, - pillow_mask_to_latent_mask, - torch_img_to_pillow_img, - ) - from imaginairy.utils.log_utils import ( - ImageLoggingContext, - log_conditioning, - log_img, - log_latent, - ) - from imaginairy.utils.model_manager import ( - get_diffusion_model, - get_model_default_image_size, - ) - from imaginairy.utils.outpaint import ( - outpaint_arg_str_parse, - prepare_image_for_outpaint, - ) - from imaginairy.utils.safety import create_safety_score - - latent_channels = 4 - downsampling_factor = 8 - batch_size = 1 - global _most_recent_result - # handle prompt pulling in previous values - # if isinstance(prompt.init_image, str) and prompt.init_image.startswith("*prev"): - # _, img_type = prompt.init_image.strip("*").split(".") - # prompt.init_image = _most_recent_result.images[img_type] - # if isinstance(prompt.mask_image, str) and prompt.mask_image.startswith("*prev"): - # _, img_type = prompt.mask_image.strip("*").split(".") - # prompt.mask_image = _most_recent_result.images[img_type] - prompt = prompt.make_concrete_copy() - - control_modes = [] - control_inputs = prompt.control_inputs or [] - control_inputs = control_inputs.copy() - for_inpainting = 
bool(prompt.mask_image or prompt.mask_prompt or prompt.outpaint) - - if control_inputs: - control_modes = [c.mode for c in prompt.control_inputs] - if inpaint_method == "auto": - if prompt.model_weights in {"SD-1.5"}: - inpaint_method = "finetune" - else: - inpaint_method = "controlnet" - - if for_inpainting and inpaint_method == "controlnet": - control_modes.append("inpaint") - model = get_diffusion_model( - weights_location=prompt.model_weights, - config_path=prompt.model_architecture, - control_weights_locations=control_modes, - half_mode=half_mode, - for_inpainting=for_inpainting and inpaint_method == "finetune", - ) - is_controlnet_model = hasattr(model, "control_key") - - progress_latents = [] - - def latent_logger(latents): - progress_latents.append(latents) - - with ImageLoggingContext( - prompt=prompt, - model=model, - debug_img_callback=debug_img_callback, - progress_img_callback=progress_img_callback, - progress_img_interval_steps=progress_img_interval_steps, - progress_img_interval_min_s=progress_img_interval_min_s, - progress_latent_callback=latent_logger - if prompt.collect_progress_latents - else None, - ) as lc: - seed_everything(prompt.seed) - - model.tile_mode(prompt.tile_mode) - with lc.timing("conditioning"): - # need to expand if doing batches - neutral_conditioning = _prompts_to_embeddings(prompt.negative_prompt, model) - _prompts_to_embeddings("", model) - log_conditioning(neutral_conditioning, "neutral conditioning") - if prompt.conditioning is not None: - positive_conditioning = prompt.conditioning - else: - positive_conditioning = _prompts_to_embeddings(prompt.prompts, model) - log_conditioning(positive_conditioning, "positive conditioning") - - shape = [ - batch_size, - latent_channels, - prompt.height // downsampling_factor, - prompt.width // downsampling_factor, - ] - SolverCls = SOLVER_LOOKUP[prompt.solver_type.lower()] - solver = SolverCls(model) - mask_image: Image.Image | LazyLoadingImage | None = None - mask_latent = mask_image_orig = mask_grayscale = None - init_latent: torch.Tensor | None = None - t_enc = None - starting_image = None - denoiser_cls = None - - c_cat = [] - c_cat_neutral = None - result_images: dict[str, torch.Tensor | Image.Image | None] = {} - assert prompt.seed is not None - seed_everything(prompt.seed) - noise = randn_seeded(seed=prompt.seed, size=shape).to(get_device()) - control_strengths = [] - - if prompt.init_image: - starting_image = prompt.init_image - assert prompt.init_image_strength is not None - generation_strength = 1 - prompt.init_image_strength - - if model.cond_stage_key == "edit" or generation_strength >= 1: - t_enc = None - else: - t_enc = int(prompt.steps * generation_strength) - - if prompt.mask_prompt: - mask_image, mask_grayscale = get_img_mask( - starting_image, prompt.mask_prompt, threshold=0.1 - ) - elif prompt.mask_image: - mask_image = prompt.mask_image.convert("L") - if prompt.outpaint: - outpaint_kwargs = outpaint_arg_str_parse(prompt.outpaint) - starting_image, mask_image = prepare_image_for_outpaint( - starting_image, mask_image, **outpaint_kwargs - ) - assert starting_image is not None - init_image = pillow_fit_image_within( - starting_image, - max_height=prompt.height, - max_width=prompt.width, - ) - init_image_t = pillow_img_to_torch_image(init_image).to(get_device()) - init_latent = model.get_first_stage_encoding( - model.encode_first_stage(init_image_t) - ) - assert init_latent is not None - shape = list(init_latent.shape) - - log_latent(init_latent, "init_latent") - - if mask_image is not None: - 
mask_image = pillow_fit_image_within( - mask_image, - max_height=prompt.height, - max_width=prompt.width, - convert="L", - ) - - log_img(mask_image, "init mask") - - if prompt.mask_mode == MaskMode.REPLACE: - mask_image = ImageOps.invert(mask_image) - - mask_image_orig = mask_image - log_img(mask_image, "latent_mask") - mask_latent = pillow_mask_to_latent_mask( - mask_image, downsampling_factor=downsampling_factor - ).to(get_device()) - if inpaint_method == "controlnet": - result_images["control-inpaint"] = mask_image - control_inputs.append( - ControlInput(mode="inpaint", image=mask_image) - ) - assert prompt.seed is not None - seed_everything(prompt.seed) - noise = randn_seeded(seed=prompt.seed, size=list(init_latent.shape)).to( - get_device() - ) - # noise = noise[:, :, : init_latent.shape[2], : init_latent.shape[3]] - - # schedule = NoiseSchedule( - # model_num_timesteps=model.num_timesteps, - # ddim_num_steps=prompt.steps, - # model_alphas_cumprod=model.alphas_cumprod, - # ddim_discretize="uniform", - # ) - # if generation_strength >= 1: - # # prompt strength gets converted to time encodings, - # # which means you can't get to true 0 without this hack - # # (or setting steps=1000) - # init_latent_noised = noise - # else: - # init_latent_noised = noise_an_image( - # init_latent, - # torch.tensor([t_enc - 1]).to(get_device()), - # schedule=schedule, - # noise=noise, - # ) - - if hasattr(model, "depth_stage_key"): - # depth model - depth_t = torch_image_to_depth_map(init_image_t) - depth_latent = torch.nn.functional.interpolate( - depth_t, - size=shape[2:], - mode="bicubic", - align_corners=False, - ) - result_images["depth_image"] = depth_t - c_cat.append(depth_latent) - - elif is_controlnet_model: - from imaginairy.img_processors.control_modes import CONTROL_MODES - - for control_input in control_inputs: - if control_input.image_raw is not None: - control_image = control_input.image_raw - elif control_input.image is not None: - control_image = control_input.image - else: - raise RuntimeError("Control image must be provided") - assert control_image is not None - control_image = control_image.convert("RGB") - log_img(control_image, "control_image_input") - assert control_image is not None - - control_image_input = pillow_fit_image_within( - control_image, - max_height=prompt.height, - max_width=prompt.width, - ) - control_image_input_t = pillow_img_to_torch_image(control_image_input) - control_image_input_t = control_image_input_t.to(get_device()) - - if control_input.image_raw is None: - control_prep_function = CONTROL_MODES[control_input.mode] - if control_input.mode == "inpaint": - control_image_t = control_prep_function( # type: ignore - control_image_input_t, init_image_t - ) - else: - control_image_t = control_prep_function(control_image_input_t) # type: ignore - else: - control_image_t = (control_image_input_t + 1) / 2 - - control_image_disp = control_image_t * 2 - 1 - result_images[f"control-{control_input.mode}"] = control_image_disp - log_img(control_image_disp, "control_image") - - if len(control_image_t.shape) == 3: - raise RuntimeError("Control image must be 4D") - - if control_image_t.shape[1] != 3: - raise RuntimeError("Control image must have 3 channels") - - if ( - control_input.mode != "inpaint" - and control_image_t.min() < 0 - or control_image_t.max() > 1 - ): - msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}" - raise RuntimeError(msg) - - if control_image_t.max() == control_image_t.min(): - msg = f"No 
control signal found in control image {control_input.mode}." - raise RuntimeError(msg) - - c_cat.append(control_image_t) - control_strengths.append(control_input.strength) - - elif hasattr(model, "masked_image_key"): - # inpainting model - assert mask_image_orig is not None - assert mask_latent is not None - mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to( - get_device() - ) - inverted_mask = 1 - mask_latent - masked_image_t = init_image_t * (mask_t < 0.5) - log_img(masked_image_t, "masked_image") - - inverted_mask_latent = torch.nn.functional.interpolate( - inverted_mask, size=shape[-2:] - ) - c_cat.append(inverted_mask_latent) - - masked_image_latent = model.get_first_stage_encoding( - model.encode_first_stage(masked_image_t) - ) - c_cat.append(masked_image_latent) - - elif model.cond_stage_key == "edit": - # pix2pix model - c_cat = [model.encode_first_stage(init_image_t)] - assert init_latent is not None - c_cat_neutral = [torch.zeros_like(init_latent)] - denoiser_cls = CFGEditingDenoiser - if c_cat: - c_cat = [torch.cat([c], dim=1) for c in c_cat] - - if c_cat_neutral is None: - c_cat_neutral = c_cat - - positive_conditioning_d: dict[str, Any] = { - "c_concat": c_cat, - "c_crossattn": [positive_conditioning], - } - neutral_conditioning_d: dict[str, Any] = { - "c_concat": c_cat_neutral, - "c_crossattn": [neutral_conditioning], - } - del neutral_conditioning - del positive_conditioning - - if control_strengths and is_controlnet_model: - positive_conditioning_d["control_strengths"] = torch.Tensor( - control_strengths - ) - neutral_conditioning_d["control_strengths"] = torch.Tensor( - control_strengths - ) - - if ( - prompt.allow_compose_phase - and not is_controlnet_model - and model.cond_stage_key != "edit" - ): - default_size = get_model_default_image_size( - prompt.model_weights.architecture - ) - if prompt.init_image: - comp_image = _generate_composition_image( - prompt=prompt, - target_height=init_image.height, - target_width=init_image.width, - cutoff=default_size, - ) - else: - comp_image = _generate_composition_image( - prompt=prompt, - target_height=prompt.height, - target_width=prompt.width, - cutoff=default_size, - ) - if comp_image is not None: - result_images["composition"] = comp_image - # noise = noise[:, :, : comp_image.height, : comp_image.shape[3]] - t_enc = int(prompt.steps * 0.65) - log_img(comp_image, "comp_image") - comp_image_t = pillow_img_to_torch_image(comp_image) - comp_image_t = comp_image_t.to(get_device()) - init_latent = model.get_first_stage_encoding( - model.encode_first_stage(comp_image_t) - ) - with lc.timing("sampling"): - samples = solver.sample( - num_steps=prompt.steps, - positive_conditioning=positive_conditioning_d, - neutral_conditioning=neutral_conditioning_d, - guidance_scale=prompt.prompt_strength, - t_start=t_enc, - mask=mask_latent, - orig_latent=init_latent, - shape=shape, - batch_size=1, - denoiser_cls=denoiser_cls, - noise=noise, - ) - if return_latent: - return samples - - with lc.timing("decoding"): - gen_imgs_t = model.decode_first_stage(samples) - gen_img = torch_img_to_pillow_img(gen_imgs_t) - - if mask_image_orig and init_image: - mask_final = mask_image_orig.copy() - log_img(mask_final, "reconstituting mask") - mask_final = ImageOps.invert(mask_final) - gen_img = Image.composite(gen_img, init_image, mask_final) - gen_img = combine_image( - original_img=init_image, - generated_img=gen_img, - mask_img=mask_image_orig, - ) - log_img(gen_img, "reconstituted image") - - upscaled_img = None - rebuilt_orig_img = 
None - - if add_caption: - caption = generate_caption(gen_img) - logger.info(f"Generated caption: {caption}") - - with lc.timing("safety-filter"): - safety_score = create_safety_score( - gen_img, - safety_mode=IMAGINAIRY_SAFETY_MODE, - ) - if safety_score.is_filtered: - progress_latents.clear() - if not safety_score.is_filtered: - if prompt.fix_faces: - logger.info("Fixing 😊 's in 🖼 using CodeFormer...") - with lc.timing("face enhancement"): - gen_img = enhance_faces(gen_img, fidelity=prompt.fix_faces_fidelity) - if prompt.upscale: - logger.info("Upscaling 🖼 using real-ESRGAN...") - with lc.timing("upscaling"): - upscaled_img = upscale_image(gen_img) - - # put the newly generated patch back into the original, full-size image - if prompt.mask_modify_original and mask_image_orig and starting_image: - img_to_add_back_to_original = upscaled_img if upscaled_img else gen_img - rebuilt_orig_img = combine_image( - original_img=starting_image, - generated_img=img_to_add_back_to_original, - mask_img=mask_image_orig, - ) - - if prompt.caption_text: - caption_text = prompt.caption_text.format(prompt=prompt.prompt_text) - add_caption_to_image(gen_img, caption_text) - - result_images["upscaled"] = upscaled_img - result_images["modified_original"] = rebuilt_orig_img - result_images["mask_binary"] = mask_image_orig - result_images["mask_grayscale"] = mask_grayscale - - result = ImagineResult( - img=gen_img, - prompt=prompt, - is_nsfw=safety_score.is_nsfw, - safety_score=safety_score, - result_images=result_images, - timings=lc.get_timings(), - progress_latents=progress_latents.copy(), - ) - - _most_recent_result = result - logger.info(f"Image Generated. Timings: {result.timings_str()}") - return result - - -def _prompts_to_embeddings(prompts, model): - total_weight = sum(wp.weight for wp in prompts) - conditioning = sum( - model.get_learned_conditioning(wp.text) * (wp.weight / total_weight) - for wp in prompts - ) - return conditioning - - -def calc_scale_to_fit_within(height: int, width: int, max_size) -> float: - max_width, max_height = normalize_image_size(max_size) - if width <= max_width and height <= max_height: - return 1 - - width_ratio = max_width / width - height_ratio = max_height / height - - return min(width_ratio, height_ratio) - - -def _scale_latent( - latent, - model, - h, - w, -): - from torch.nn import functional as F - - # convert to non-latent-space first - img = model.decode_first_stage(latent) - img = F.interpolate(img, size=(h, w), mode="bicubic", align_corners=False) - latent = model.get_first_stage_encoding(model.encode_first_stage(img)) - return latent - - -def _generate_composition_image( - prompt, - target_height, - target_width, - cutoff: tuple[int, int] = (512, 512), - dtype=None, -): - from PIL import Image - - from imaginairy.api.generate_refiners import _generate_single_image - from imaginairy.utils import default, get_default_dtype - - cutoff = normalize_image_size(cutoff) - if prompt.width <= cutoff[0] and prompt.height <= cutoff[1]: - return None, None - - dtype = default(dtype, get_default_dtype) - - shrink_scale = calc_scale_to_fit_within( - height=prompt.height, - width=prompt.width, - max_size=cutoff, - ) - - composition_prompt = prompt.full_copy( - deep=True, - update={ - "size": ( - int(prompt.width * shrink_scale), - int(prompt.height * shrink_scale), - ), - "steps": None, - "upscale": False, - "fix_faces": False, - "mask_modify_original": False, - "allow_compose_phase": False, - }, - ) - - result = _generate_single_image(composition_prompt, dtype=dtype) - img = 
result.images["generated"] - while img.width < target_width: - from imaginairy.enhancers.upscale_realesrgan import upscale_image - - img = upscale_image(img) - - # samples = _generate_single_image(composition_prompt, return_latent=True) - # while samples.shape[-1] * 8 < target_width: - # samples = upscale_latent(samples) - # - # img = model_latent_to_pillow_img(samples) - - img = img.resize( - (target_width, target_height), - resample=Image.Resampling.LANCZOS, - ) - - return img, result.images["generated"] - - -def prompt_normalized(prompt, length=130): - return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:length] - - -def combine_image(original_img, generated_img, mask_img): - """Combine the generated image with the original image using the mask image.""" - from PIL import Image - - from imaginairy.utils.log_utils import log_img - - generated_img = generated_img.resize( - original_img.size, - resample=Image.Resampling.LANCZOS, - ) - - mask_for_orig_size = mask_img.resize( - original_img.size, - resample=Image.Resampling.LANCZOS, - ) - log_img(mask_for_orig_size, "mask for original image size") - - rebuilt_orig_img = Image.composite( - original_img, - generated_img, - mask_for_orig_size, - ) - log_img(rebuilt_orig_img, "reconstituted original") - return rebuilt_orig_img diff --git a/imaginairy/api/generate_compvis.py b/imaginairy/api/generate_compvis.py new file mode 100644 index 0000000..70a03d4 --- /dev/null +++ b/imaginairy/api/generate_compvis.py @@ -0,0 +1,547 @@ +from typing import Any + +from imaginairy.api.generate import ( + IMAGINAIRY_SAFETY_MODE, + logger, +) +from imaginairy.api.generate_refiners import _generate_composition_image +from imaginairy.schema import ImaginePrompt, LazyLoadingImage +from imaginairy.utils.img_utils import calc_scale_to_fit_within, combine_image +from imaginairy.utils.named_resolutions import normalize_image_size + + +def _generate_single_image_compvis( + prompt: "ImaginePrompt", + debug_img_callback=None, + progress_img_callback=None, + progress_img_interval_steps=3, + progress_img_interval_min_s=0.1, + half_mode=None, + add_caption=False, + # controlnet, finetune, naive, auto + inpaint_method="finetune", + return_latent=False, +): + import torch.nn + from PIL import Image, ImageOps + from pytorch_lightning import seed_everything + + from imaginairy.enhancers.clip_masking import get_img_mask + from imaginairy.enhancers.describe_image_blip import generate_caption + from imaginairy.enhancers.face_restoration_codeformer import enhance_faces + from imaginairy.enhancers.upscale_realesrgan import upscale_image + from imaginairy.modules.midas.api import torch_image_to_depth_map + from imaginairy.samplers import SOLVER_LOOKUP + from imaginairy.samplers.editing import CFGEditingDenoiser + from imaginairy.schema import ControlInput, ImagineResult, MaskMode + from imaginairy.utils import get_device, randn_seeded + from imaginairy.utils.img_utils import ( + add_caption_to_image, + pillow_fit_image_within, + pillow_img_to_torch_image, + pillow_mask_to_latent_mask, + torch_img_to_pillow_img, + ) + from imaginairy.utils.log_utils import ( + ImageLoggingContext, + log_conditioning, + log_img, + log_latent, + ) + from imaginairy.utils.model_manager import ( + get_diffusion_model, + get_model_default_image_size, + ) + from imaginairy.utils.outpaint import ( + outpaint_arg_str_parse, + prepare_image_for_outpaint, + ) + from imaginairy.utils.safety import create_safety_score + + latent_channels = 4 + downsampling_factor = 8 + batch_size = 1 + global _most_recent_result 
+ # handle prompt pulling in previous values + # if isinstance(prompt.init_image, str) and prompt.init_image.startswith("*prev"): + # _, img_type = prompt.init_image.strip("*").split(".") + # prompt.init_image = _most_recent_result.images[img_type] + # if isinstance(prompt.mask_image, str) and prompt.mask_image.startswith("*prev"): + # _, img_type = prompt.mask_image.strip("*").split(".") + # prompt.mask_image = _most_recent_result.images[img_type] + prompt = prompt.make_concrete_copy() + + control_modes = [] + control_inputs = prompt.control_inputs or [] + control_inputs = control_inputs.copy() + for_inpainting = bool(prompt.mask_image or prompt.mask_prompt or prompt.outpaint) + + if control_inputs: + control_modes = [c.mode for c in prompt.control_inputs] + if inpaint_method == "auto": + if prompt.model_weights in {"SD-1.5"}: + inpaint_method = "finetune" + else: + inpaint_method = "controlnet" + + if for_inpainting and inpaint_method == "controlnet": + control_modes.append("inpaint") + model = get_diffusion_model( + weights_location=prompt.model_weights, + config_path=prompt.model_architecture, + control_weights_locations=control_modes, + half_mode=half_mode, + for_inpainting=for_inpainting and inpaint_method == "finetune", + ) + is_controlnet_model = hasattr(model, "control_key") + + progress_latents = [] + + def latent_logger(latents): + progress_latents.append(latents) + + with ImageLoggingContext( + prompt=prompt, + model=model, + debug_img_callback=debug_img_callback, + progress_img_callback=progress_img_callback, + progress_img_interval_steps=progress_img_interval_steps, + progress_img_interval_min_s=progress_img_interval_min_s, + progress_latent_callback=latent_logger + if prompt.collect_progress_latents + else None, + ) as lc: + seed_everything(prompt.seed) + + model.tile_mode(prompt.tile_mode) + with lc.timing("conditioning"): + # need to expand if doing batches + neutral_conditioning = _prompts_to_embeddings(prompt.negative_prompt, model) + _prompts_to_embeddings("", model) + log_conditioning(neutral_conditioning, "neutral conditioning") + if prompt.conditioning is not None: + positive_conditioning = prompt.conditioning + else: + positive_conditioning = _prompts_to_embeddings(prompt.prompts, model) + log_conditioning(positive_conditioning, "positive conditioning") + + shape = [ + batch_size, + latent_channels, + prompt.height // downsampling_factor, + prompt.width // downsampling_factor, + ] + SolverCls = SOLVER_LOOKUP[prompt.solver_type.lower()] + solver = SolverCls(model) + mask_image: Image.Image | LazyLoadingImage | None = None + mask_latent = mask_image_orig = mask_grayscale = None + init_latent: torch.Tensor | None = None + t_enc = None + starting_image = None + denoiser_cls = None + + c_cat = [] + c_cat_neutral = None + result_images: dict[str, torch.Tensor | Image.Image | None] = {} + assert prompt.seed is not None + seed_everything(prompt.seed) + noise = randn_seeded(seed=prompt.seed, size=shape).to(get_device()) + control_strengths = [] + + if prompt.init_image: + starting_image = prompt.init_image + assert prompt.init_image_strength is not None + generation_strength = 1 - prompt.init_image_strength + + if model.cond_stage_key == "edit" or generation_strength >= 1: + t_enc = None + else: + t_enc = int(prompt.steps * generation_strength) + + if prompt.mask_prompt: + mask_image, mask_grayscale = get_img_mask( + starting_image, prompt.mask_prompt, threshold=0.1 + ) + elif prompt.mask_image: + mask_image = prompt.mask_image.convert("L") + if prompt.outpaint: + 
outpaint_kwargs = outpaint_arg_str_parse(prompt.outpaint) + starting_image, mask_image = prepare_image_for_outpaint( + starting_image, mask_image, **outpaint_kwargs + ) + assert starting_image is not None + init_image = pillow_fit_image_within( + starting_image, + max_height=prompt.height, + max_width=prompt.width, + ) + init_image_t = pillow_img_to_torch_image(init_image).to(get_device()) + init_latent = model.get_first_stage_encoding( + model.encode_first_stage(init_image_t) + ) + assert init_latent is not None + shape = list(init_latent.shape) + + log_latent(init_latent, "init_latent") + + if mask_image is not None: + mask_image = pillow_fit_image_within( + mask_image, + max_height=prompt.height, + max_width=prompt.width, + convert="L", + ) + + log_img(mask_image, "init mask") + + if prompt.mask_mode == MaskMode.REPLACE: + mask_image = ImageOps.invert(mask_image) + + mask_image_orig = mask_image + log_img(mask_image, "latent_mask") + mask_latent = pillow_mask_to_latent_mask( + mask_image, downsampling_factor=downsampling_factor + ).to(get_device()) + if inpaint_method == "controlnet": + result_images["control-inpaint"] = mask_image + control_inputs.append( + ControlInput(mode="inpaint", image=mask_image) + ) + assert prompt.seed is not None + seed_everything(prompt.seed) + noise = randn_seeded(seed=prompt.seed, size=list(init_latent.shape)).to( + get_device() + ) + # noise = noise[:, :, : init_latent.shape[2], : init_latent.shape[3]] + + # schedule = NoiseSchedule( + # model_num_timesteps=model.num_timesteps, + # ddim_num_steps=prompt.steps, + # model_alphas_cumprod=model.alphas_cumprod, + # ddim_discretize="uniform", + # ) + # if generation_strength >= 1: + # # prompt strength gets converted to time encodings, + # # which means you can't get to true 0 without this hack + # # (or setting steps=1000) + # init_latent_noised = noise + # else: + # init_latent_noised = noise_an_image( + # init_latent, + # torch.tensor([t_enc - 1]).to(get_device()), + # schedule=schedule, + # noise=noise, + # ) + + if hasattr(model, "depth_stage_key"): + # depth model + depth_t = torch_image_to_depth_map(init_image_t) + depth_latent = torch.nn.functional.interpolate( + depth_t, + size=shape[2:], + mode="bicubic", + align_corners=False, + ) + result_images["depth_image"] = depth_t + c_cat.append(depth_latent) + + elif is_controlnet_model: + from imaginairy.img_processors.control_modes import CONTROL_MODES + + for control_input in control_inputs: + if control_input.image_raw is not None: + control_image = control_input.image_raw + elif control_input.image is not None: + control_image = control_input.image + else: + raise RuntimeError("Control image must be provided") + assert control_image is not None + control_image = control_image.convert("RGB") + log_img(control_image, "control_image_input") + assert control_image is not None + + control_image_input = pillow_fit_image_within( + control_image, + max_height=prompt.height, + max_width=prompt.width, + ) + control_image_input_t = pillow_img_to_torch_image(control_image_input) + control_image_input_t = control_image_input_t.to(get_device()) + + if control_input.image_raw is None: + control_prep_function = CONTROL_MODES[control_input.mode] + if control_input.mode == "inpaint": + control_image_t = control_prep_function( # type: ignore + control_image_input_t, init_image_t + ) + else: + control_image_t = control_prep_function(control_image_input_t) # type: ignore + else: + control_image_t = (control_image_input_t + 1) / 2 + + control_image_disp = control_image_t * 2 
- 1 + result_images[f"control-{control_input.mode}"] = control_image_disp + log_img(control_image_disp, "control_image") + + if len(control_image_t.shape) == 3: + raise RuntimeError("Control image must be 4D") + + if control_image_t.shape[1] != 3: + raise RuntimeError("Control image must have 3 channels") + + if ( + control_input.mode != "inpaint" + and control_image_t.min() < 0 + or control_image_t.max() > 1 + ): + msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}" + raise RuntimeError(msg) + + if control_image_t.max() == control_image_t.min(): + msg = f"No control signal found in control image {control_input.mode}." + raise RuntimeError(msg) + + c_cat.append(control_image_t) + control_strengths.append(control_input.strength) + + elif hasattr(model, "masked_image_key"): + # inpainting model + assert mask_image_orig is not None + assert mask_latent is not None + mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to( + get_device() + ) + inverted_mask = 1 - mask_latent + masked_image_t = init_image_t * (mask_t < 0.5) + log_img(masked_image_t, "masked_image") + + inverted_mask_latent = torch.nn.functional.interpolate( + inverted_mask, size=shape[-2:] + ) + c_cat.append(inverted_mask_latent) + + masked_image_latent = model.get_first_stage_encoding( + model.encode_first_stage(masked_image_t) + ) + c_cat.append(masked_image_latent) + + elif model.cond_stage_key == "edit": + # pix2pix model + c_cat = [model.encode_first_stage(init_image_t)] + assert init_latent is not None + c_cat_neutral = [torch.zeros_like(init_latent)] + denoiser_cls = CFGEditingDenoiser + if c_cat: + c_cat = [torch.cat([c], dim=1) for c in c_cat] + + if c_cat_neutral is None: + c_cat_neutral = c_cat + + positive_conditioning_d: dict[str, Any] = { + "c_concat": c_cat, + "c_crossattn": [positive_conditioning], + } + neutral_conditioning_d: dict[str, Any] = { + "c_concat": c_cat_neutral, + "c_crossattn": [neutral_conditioning], + } + del neutral_conditioning + del positive_conditioning + + if control_strengths and is_controlnet_model: + positive_conditioning_d["control_strengths"] = torch.Tensor( + control_strengths + ) + neutral_conditioning_d["control_strengths"] = torch.Tensor( + control_strengths + ) + + if ( + prompt.allow_compose_phase + and not is_controlnet_model + and model.cond_stage_key != "edit" + ): + default_size = get_model_default_image_size( + prompt.model_weights.architecture + ) + if prompt.init_image: + comp_image = _generate_composition_image( + prompt=prompt, + target_height=init_image.height, + target_width=init_image.width, + cutoff=default_size, + ) + else: + comp_image = _generate_composition_image( + prompt=prompt, + target_height=prompt.height, + target_width=prompt.width, + cutoff=default_size, + ) + if comp_image is not None: + result_images["composition"] = comp_image + # noise = noise[:, :, : comp_image.height, : comp_image.shape[3]] + t_enc = int(prompt.steps * 0.65) + log_img(comp_image, "comp_image") + comp_image_t = pillow_img_to_torch_image(comp_image) + comp_image_t = comp_image_t.to(get_device()) + init_latent = model.get_first_stage_encoding( + model.encode_first_stage(comp_image_t) + ) + with lc.timing("sampling"): + samples = solver.sample( + num_steps=prompt.steps, + positive_conditioning=positive_conditioning_d, + neutral_conditioning=neutral_conditioning_d, + guidance_scale=prompt.prompt_strength, + t_start=t_enc, + mask=mask_latent, + orig_latent=init_latent, + shape=shape, + batch_size=1, + 
denoiser_cls=denoiser_cls, + noise=noise, + ) + if return_latent: + return samples + + with lc.timing("decoding"): + gen_imgs_t = model.decode_first_stage(samples) + gen_img = torch_img_to_pillow_img(gen_imgs_t) + + if mask_image_orig and init_image: + mask_final = mask_image_orig.copy() + log_img(mask_final, "reconstituting mask") + mask_final = ImageOps.invert(mask_final) + gen_img = Image.composite(gen_img, init_image, mask_final) + gen_img = combine_image( + original_img=init_image, + generated_img=gen_img, + mask_img=mask_image_orig, + ) + log_img(gen_img, "reconstituted image") + + upscaled_img = None + rebuilt_orig_img = None + + if add_caption: + caption = generate_caption(gen_img) + logger.info(f"Generated caption: {caption}") + + with lc.timing("safety-filter"): + safety_score = create_safety_score( + gen_img, + safety_mode=IMAGINAIRY_SAFETY_MODE, + ) + if safety_score.is_filtered: + progress_latents.clear() + if not safety_score.is_filtered: + if prompt.fix_faces: + logger.info("Fixing 😊 's in 🖼 using CodeFormer...") + with lc.timing("face enhancement"): + gen_img = enhance_faces(gen_img, fidelity=prompt.fix_faces_fidelity) + if prompt.upscale: + logger.info("Upscaling 🖼 using real-ESRGAN...") + with lc.timing("upscaling"): + upscaled_img = upscale_image(gen_img) + + # put the newly generated patch back into the original, full-size image + if prompt.mask_modify_original and mask_image_orig and starting_image: + img_to_add_back_to_original = upscaled_img if upscaled_img else gen_img + rebuilt_orig_img = combine_image( + original_img=starting_image, + generated_img=img_to_add_back_to_original, + mask_img=mask_image_orig, + ) + + if prompt.caption_text: + caption_text = prompt.caption_text.format(prompt=prompt.prompt_text) + add_caption_to_image(gen_img, caption_text) + + result_images["upscaled"] = upscaled_img + result_images["modified_original"] = rebuilt_orig_img + result_images["mask_binary"] = mask_image_orig + result_images["mask_grayscale"] = mask_grayscale + + result = ImagineResult( + img=gen_img, + prompt=prompt, + is_nsfw=safety_score.is_nsfw, + safety_score=safety_score, + result_images=result_images, + timings=lc.get_timings(), + progress_latents=progress_latents.copy(), + ) + + _most_recent_result = result + logger.info(f"Image Generated. 
Timings: {result.timings_str()}") + return result + + +def _prompts_to_embeddings(prompts, model): + total_weight = sum(wp.weight for wp in prompts) + conditioning = sum( + model.get_learned_conditioning(wp.text) * (wp.weight / total_weight) + for wp in prompts + ) + return conditioning + + +def _generate_composition_image( + prompt, + target_height, + target_width, + cutoff: tuple[int, int] = (512, 512), + dtype=None, +): + from PIL import Image + + from imaginairy.api.generate_refiners import generate_single_image + from imaginairy.utils import default, get_default_dtype + + cutoff = normalize_image_size(cutoff) + if prompt.width <= cutoff[0] and prompt.height <= cutoff[1]: + return None, None + + dtype = default(dtype, get_default_dtype) + + shrink_scale = calc_scale_to_fit_within( + height=prompt.height, + width=prompt.width, + max_size=cutoff, + ) + + composition_prompt = prompt.full_copy( + deep=True, + update={ + "size": ( + int(prompt.width * shrink_scale), + int(prompt.height * shrink_scale), + ), + "steps": None, + "upscale": False, + "fix_faces": False, + "mask_modify_original": False, + "allow_compose_phase": False, + }, + ) + + result = generate_single_image(composition_prompt, dtype=dtype) + img = result.images["generated"] + while img.width < target_width: + from imaginairy.enhancers.upscale_realesrgan import upscale_image + + img = upscale_image(img) + + # samples = generate_single_image(composition_prompt, return_latent=True) + # while samples.shape[-1] * 8 < target_width: + # samples = upscale_latent(samples) + # + # img = model_latent_to_pillow_img(samples) + + img = img.resize( + (target_width, target_height), + resample=Image.Resampling.LANCZOS, + ) + + return img, result.images["generated"] diff --git a/imaginairy/api/generate_refiners.py b/imaginairy/api/generate_refiners.py index 0740cd8..e1e1949 100644 --- a/imaginairy/api/generate_refiners.py +++ b/imaginairy/api/generate_refiners.py @@ -5,11 +5,13 @@ from typing import List, Optional from imaginairy.config import CONTROL_CONFIG_SHORTCUTS from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode, WeightedPrompt +from imaginairy.utils.img_utils import calc_scale_to_fit_within +from imaginairy.utils.named_resolutions import normalize_image_size logger = logging.getLogger(__name__) -def _generate_single_image( +def generate_single_image( prompt: ImaginePrompt, debug_img_callback=None, progress_img_callback=None, @@ -28,8 +30,6 @@ def _generate_single_image( from imaginairy.api.generate import ( IMAGINAIRY_SAFETY_MODE, - _generate_composition_image, - combine_image, ) from imaginairy.enhancers.clip_masking import get_img_mask from imaginairy.enhancers.describe_image_blip import generate_caption @@ -40,6 +40,7 @@ def _generate_single_image( from imaginairy.utils import get_device, randn_seeded from imaginairy.utils.img_utils import ( add_caption_to_image, + combine_image, pillow_fit_image_within, pillow_img_to_torch_image, pillow_mask_to_latent_mask, @@ -523,3 +524,64 @@ def prep_control_input( ) controlnet.set_scale(control_input.strength) return controlnet, control_image_t, control_image_disp + + +def _generate_composition_image( + prompt, + target_height, + target_width, + cutoff: tuple[int, int] = (512, 512), + dtype=None, +): + from PIL import Image + + from imaginairy.api.generate_refiners import generate_single_image + from imaginairy.utils import default, get_default_dtype + + cutoff = normalize_image_size(cutoff) + if prompt.width <= cutoff[0] and prompt.height <= cutoff[1]: + return None, None + + 
+    dtype = default(dtype, get_default_dtype)
+
+    shrink_scale = calc_scale_to_fit_within(
+        height=prompt.height,
+        width=prompt.width,
+        max_size=cutoff,
+    )
+
+    composition_prompt = prompt.full_copy(
+        deep=True,
+        update={
+            "size": (
+                int(prompt.width * shrink_scale),
+                int(prompt.height * shrink_scale),
+            ),
+            "steps": None,
+            "upscale": False,
+            "fix_faces": False,
+            "mask_modify_original": False,
+            "allow_compose_phase": False,
+            "caption_text": None,
+        },
+    )
+
+    result = generate_single_image(composition_prompt, dtype=dtype)
+    img = result.images["generated"]
+    while img.width < target_width:
+        from imaginairy.enhancers.upscale_realesrgan import upscale_image
+
+        img = upscale_image(img)
+
+    # samples = generate_single_image(composition_prompt, return_latent=True)
+    # while samples.shape[-1] * 8 < target_width:
+    #     samples = upscale_latent(samples)
+    #
+    # img = model_latent_to_pillow_img(samples)
+
+    img = img.resize(
+        (target_width, target_height),
+        resample=Image.Resampling.LANCZOS,
+    )
+
+    return img, result.images["generated"]
diff --git a/imaginairy/utils/__init__.py b/imaginairy/utils/__init__.py
index feb125c..9207e65 100644
--- a/imaginairy/utils/__init__.py
+++ b/imaginairy/utils/__init__.py
@@ -1,6 +1,7 @@
 import importlib
 import logging
 import platform
+import re
 import time
 from contextlib import contextmanager, nullcontext
 from functools import lru_cache
@@ -315,3 +316,7 @@ def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
         current_attribute = getattr(current_attribute, attribute)

     return (current_attribute, current_key) if return_key else current_attribute
+
+
+def prompt_normalized(prompt, length=130):
+    return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:length]
diff --git a/imaginairy/utils/img_utils.py b/imaginairy/utils/img_utils.py
index d7f40e7..f6f62d2 100644
--- a/imaginairy/utils/img_utils.py
+++ b/imaginairy/utils/img_utils.py
@@ -19,6 +19,7 @@
 from PIL import Image, ImageDraw, ImageFont

 from imaginairy.schema import LazyLoadingImage
 from imaginairy.utils import get_device
+from imaginairy.utils.named_resolutions import normalize_image_size
 from imaginairy.utils.paths import PKG_ROOT

@@ -221,3 +222,39 @@ def create_halo_effect(
     new_canvas.paste(transparent_image, (0, 0), transparent_image)

     return new_canvas
+
+
+def combine_image(original_img, generated_img, mask_img):
+    """Combine the generated image with the original image using the mask image."""
+    from PIL import Image
+
+    from imaginairy.utils.log_utils import log_img
+
+    generated_img = generated_img.resize(
+        original_img.size,
+        resample=Image.Resampling.LANCZOS,
+    )
+
+    mask_for_orig_size = mask_img.resize(
+        original_img.size,
+        resample=Image.Resampling.LANCZOS,
+    )
+    log_img(mask_for_orig_size, "mask for original image size")
+
+    rebuilt_orig_img = Image.composite(
+        original_img,
+        generated_img,
+        mask_for_orig_size,
+    )
+    return rebuilt_orig_img
+
+
+def calc_scale_to_fit_within(height: int, width: int, max_size) -> float:
+    max_width, max_height = normalize_image_size(max_size)
+    if width <= max_width and height <= max_height:
+        return 1
+
+    width_ratio = max_width / width
+    height_ratio = max_height / height
+
+    return min(width_ratio, height_ratio)
diff --git a/imaginairy/utils/named_resolutions.py b/imaginairy/utils/named_resolutions.py
index 42e3e65..0402d83 100644
--- a/imaginairy/utils/named_resolutions.py
+++ b/imaginairy/utils/named_resolutions.py
@@ -2,7 +2,7 @@

 import contextlib

-_NAMED_RESOLUTIONS = {
+NAMED_RESOLUTIONS = {
     "HD": (1280, 720),
     "FHD": (1920, 1080),
     "HALF-FHD": (960, 540),
@@ -46,7 +46,7 @@ _NAMED_RESOLUTIONS = {
     "SVD": (1024, 576),  # stable video diffusion
 }

-_NAMED_RESOLUTIONS = {k.upper(): v for k, v in _NAMED_RESOLUTIONS.items()}
+NAMED_RESOLUTIONS = {k.upper(): v for k, v in NAMED_RESOLUTIONS.items()}


 def normalize_image_size(resolution: str | int | tuple[int, int]) -> tuple[int, int]:
@@ -66,8 +66,8 @@ def _normalize_image_size(resolution: str | int | tuple[int, int]) -> tuple[int
         case str():
             resolution = resolution.strip().upper()
             resolution = resolution.replace(" ", "").replace("X", ",").replace("*", ",")
-            if resolution.upper() in _NAMED_RESOLUTIONS:
-                return _NAMED_RESOLUTIONS[resolution.upper()]
+            if resolution.upper() in NAMED_RESOLUTIONS:
+                return NAMED_RESOLUTIONS[resolution.upper()]

             # is it WIDTH,HEIGHT format?
             try: