refactor: move code to more intuitive places

Bryce 2023-12-20 12:32:29 -08:00 committed by Bryce Drennan
parent 8cfb46d6de
commit 6ebd12abb1
6 changed files with 662 additions and 608 deletions

View File

@@ -2,13 +2,9 @@
import logging
import os
-import re
-from typing import TYPE_CHECKING, Any, Callable
+from typing import Callable
-from imaginairy.utils.named_resolutions import normalize_image_size
+from imaginairy.utils import prompt_normalized
-if TYPE_CHECKING:
-from imaginairy.schema import ImaginePrompt, LazyLoadingImage
logger = logging.getLogger(__name__)
@@ -160,7 +156,7 @@ def imagine(
):
import torch.nn
-from imaginairy.api.generate_refiners import _generate_single_image
+from imaginairy.api.generate_refiners import generate_single_image
from imaginairy.schema import ImaginePrompt
from imaginairy.utils import (
check_torch_version,
@@ -199,7 +195,7 @@ def imagine(
for attempt in range(unsafe_retry_count + 1):
if attempt > 0 and isinstance(prompt.seed, int):
prompt.seed += 100_000_000 + attempt
-result = _generate_single_image(
+result = generate_single_image(
prompt,
debug_img_callback=debug_img_callback,
progress_img_callback=progress_img_callback,
@@ -215,596 +211,3 @@ def imagine(
logger.info(" Image was unsafe, retrying with new seed...")
yield result
def _generate_single_image_compvis(
prompt: "ImaginePrompt",
debug_img_callback=None,
progress_img_callback=None,
progress_img_interval_steps=3,
progress_img_interval_min_s=0.1,
half_mode=None,
add_caption=False,
# controlnet, finetune, naive, auto
inpaint_method="finetune",
return_latent=False,
):
import torch.nn
from PIL import Image, ImageOps
from pytorch_lightning import seed_everything
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.modules.midas.api import torch_image_to_depth_map
from imaginairy.samplers import SOLVER_LOOKUP
from imaginairy.samplers.editing import CFGEditingDenoiser
from imaginairy.schema import ControlInput, ImagineResult, MaskMode
from imaginairy.utils import get_device, randn_seeded
from imaginairy.utils.img_utils import (
add_caption_to_image,
pillow_fit_image_within,
pillow_img_to_torch_image,
pillow_mask_to_latent_mask,
torch_img_to_pillow_img,
)
from imaginairy.utils.log_utils import (
ImageLoggingContext,
log_conditioning,
log_img,
log_latent,
)
from imaginairy.utils.model_manager import (
get_diffusion_model,
get_model_default_image_size,
)
from imaginairy.utils.outpaint import (
outpaint_arg_str_parse,
prepare_image_for_outpaint,
)
from imaginairy.utils.safety import create_safety_score
latent_channels = 4
downsampling_factor = 8
batch_size = 1
global _most_recent_result
# handle prompt pulling in previous values
# if isinstance(prompt.init_image, str) and prompt.init_image.startswith("*prev"):
# _, img_type = prompt.init_image.strip("*").split(".")
# prompt.init_image = _most_recent_result.images[img_type]
# if isinstance(prompt.mask_image, str) and prompt.mask_image.startswith("*prev"):
# _, img_type = prompt.mask_image.strip("*").split(".")
# prompt.mask_image = _most_recent_result.images[img_type]
prompt = prompt.make_concrete_copy()
control_modes = []
control_inputs = prompt.control_inputs or []
control_inputs = control_inputs.copy()
for_inpainting = bool(prompt.mask_image or prompt.mask_prompt or prompt.outpaint)
if control_inputs:
control_modes = [c.mode for c in prompt.control_inputs]
if inpaint_method == "auto":
if prompt.model_weights in {"SD-1.5"}:
inpaint_method = "finetune"
else:
inpaint_method = "controlnet"
if for_inpainting and inpaint_method == "controlnet":
control_modes.append("inpaint")
model = get_diffusion_model(
weights_location=prompt.model_weights,
config_path=prompt.model_architecture,
control_weights_locations=control_modes,
half_mode=half_mode,
for_inpainting=for_inpainting and inpaint_method == "finetune",
)
is_controlnet_model = hasattr(model, "control_key")
progress_latents = []
def latent_logger(latents):
progress_latents.append(latents)
with ImageLoggingContext(
prompt=prompt,
model=model,
debug_img_callback=debug_img_callback,
progress_img_callback=progress_img_callback,
progress_img_interval_steps=progress_img_interval_steps,
progress_img_interval_min_s=progress_img_interval_min_s,
progress_latent_callback=latent_logger
if prompt.collect_progress_latents
else None,
) as lc:
seed_everything(prompt.seed)
model.tile_mode(prompt.tile_mode)
with lc.timing("conditioning"):
# need to expand if doing batches
neutral_conditioning = _prompts_to_embeddings(prompt.negative_prompt, model)
_prompts_to_embeddings("", model)
log_conditioning(neutral_conditioning, "neutral conditioning")
if prompt.conditioning is not None:
positive_conditioning = prompt.conditioning
else:
positive_conditioning = _prompts_to_embeddings(prompt.prompts, model)
log_conditioning(positive_conditioning, "positive conditioning")
shape = [
batch_size,
latent_channels,
prompt.height // downsampling_factor,
prompt.width // downsampling_factor,
]
SolverCls = SOLVER_LOOKUP[prompt.solver_type.lower()]
solver = SolverCls(model)
mask_image: Image.Image | LazyLoadingImage | None = None
mask_latent = mask_image_orig = mask_grayscale = None
init_latent: torch.Tensor | None = None
t_enc = None
starting_image = None
denoiser_cls = None
c_cat = []
c_cat_neutral = None
result_images: dict[str, torch.Tensor | Image.Image | None] = {}
assert prompt.seed is not None
seed_everything(prompt.seed)
noise = randn_seeded(seed=prompt.seed, size=shape).to(get_device())
control_strengths = []
if prompt.init_image:
starting_image = prompt.init_image
assert prompt.init_image_strength is not None
generation_strength = 1 - prompt.init_image_strength
if model.cond_stage_key == "edit" or generation_strength >= 1:
t_enc = None
else:
t_enc = int(prompt.steps * generation_strength)
if prompt.mask_prompt:
mask_image, mask_grayscale = get_img_mask(
starting_image, prompt.mask_prompt, threshold=0.1
)
elif prompt.mask_image:
mask_image = prompt.mask_image.convert("L")
if prompt.outpaint:
outpaint_kwargs = outpaint_arg_str_parse(prompt.outpaint)
starting_image, mask_image = prepare_image_for_outpaint(
starting_image, mask_image, **outpaint_kwargs
)
assert starting_image is not None
init_image = pillow_fit_image_within(
starting_image,
max_height=prompt.height,
max_width=prompt.width,
)
init_image_t = pillow_img_to_torch_image(init_image).to(get_device())
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(init_image_t)
)
assert init_latent is not None
shape = list(init_latent.shape)
log_latent(init_latent, "init_latent")
if mask_image is not None:
mask_image = pillow_fit_image_within(
mask_image,
max_height=prompt.height,
max_width=prompt.width,
convert="L",
)
log_img(mask_image, "init mask")
if prompt.mask_mode == MaskMode.REPLACE:
mask_image = ImageOps.invert(mask_image)
mask_image_orig = mask_image
log_img(mask_image, "latent_mask")
mask_latent = pillow_mask_to_latent_mask(
mask_image, downsampling_factor=downsampling_factor
).to(get_device())
if inpaint_method == "controlnet":
result_images["control-inpaint"] = mask_image
control_inputs.append(
ControlInput(mode="inpaint", image=mask_image)
)
assert prompt.seed is not None
seed_everything(prompt.seed)
noise = randn_seeded(seed=prompt.seed, size=list(init_latent.shape)).to(
get_device()
)
# noise = noise[:, :, : init_latent.shape[2], : init_latent.shape[3]]
# schedule = NoiseSchedule(
# model_num_timesteps=model.num_timesteps,
# ddim_num_steps=prompt.steps,
# model_alphas_cumprod=model.alphas_cumprod,
# ddim_discretize="uniform",
# )
# if generation_strength >= 1:
# # prompt strength gets converted to time encodings,
# # which means you can't get to true 0 without this hack
# # (or setting steps=1000)
# init_latent_noised = noise
# else:
# init_latent_noised = noise_an_image(
# init_latent,
# torch.tensor([t_enc - 1]).to(get_device()),
# schedule=schedule,
# noise=noise,
# )
if hasattr(model, "depth_stage_key"):
# depth model
depth_t = torch_image_to_depth_map(init_image_t)
depth_latent = torch.nn.functional.interpolate(
depth_t,
size=shape[2:],
mode="bicubic",
align_corners=False,
)
result_images["depth_image"] = depth_t
c_cat.append(depth_latent)
elif is_controlnet_model:
from imaginairy.img_processors.control_modes import CONTROL_MODES
for control_input in control_inputs:
if control_input.image_raw is not None:
control_image = control_input.image_raw
elif control_input.image is not None:
control_image = control_input.image
else:
raise RuntimeError("Control image must be provided")
assert control_image is not None
control_image = control_image.convert("RGB")
log_img(control_image, "control_image_input")
assert control_image is not None
control_image_input = pillow_fit_image_within(
control_image,
max_height=prompt.height,
max_width=prompt.width,
)
control_image_input_t = pillow_img_to_torch_image(control_image_input)
control_image_input_t = control_image_input_t.to(get_device())
if control_input.image_raw is None:
control_prep_function = CONTROL_MODES[control_input.mode]
if control_input.mode == "inpaint":
control_image_t = control_prep_function( # type: ignore
control_image_input_t, init_image_t
)
else:
control_image_t = control_prep_function(control_image_input_t) # type: ignore
else:
control_image_t = (control_image_input_t + 1) / 2
control_image_disp = control_image_t * 2 - 1
result_images[f"control-{control_input.mode}"] = control_image_disp
log_img(control_image_disp, "control_image")
if len(control_image_t.shape) == 3:
raise RuntimeError("Control image must be 4D")
if control_image_t.shape[1] != 3:
raise RuntimeError("Control image must have 3 channels")
if (
control_input.mode != "inpaint"
and control_image_t.min() < 0
or control_image_t.max() > 1
):
msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
raise RuntimeError(msg)
if control_image_t.max() == control_image_t.min():
msg = f"No control signal found in control image {control_input.mode}."
raise RuntimeError(msg)
c_cat.append(control_image_t)
control_strengths.append(control_input.strength)
elif hasattr(model, "masked_image_key"):
# inpainting model
assert mask_image_orig is not None
assert mask_latent is not None
mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to(
get_device()
)
inverted_mask = 1 - mask_latent
masked_image_t = init_image_t * (mask_t < 0.5)
log_img(masked_image_t, "masked_image")
inverted_mask_latent = torch.nn.functional.interpolate(
inverted_mask, size=shape[-2:]
)
c_cat.append(inverted_mask_latent)
masked_image_latent = model.get_first_stage_encoding(
model.encode_first_stage(masked_image_t)
)
c_cat.append(masked_image_latent)
elif model.cond_stage_key == "edit":
# pix2pix model
c_cat = [model.encode_first_stage(init_image_t)]
assert init_latent is not None
c_cat_neutral = [torch.zeros_like(init_latent)]
denoiser_cls = CFGEditingDenoiser
if c_cat:
c_cat = [torch.cat([c], dim=1) for c in c_cat]
if c_cat_neutral is None:
c_cat_neutral = c_cat
positive_conditioning_d: dict[str, Any] = {
"c_concat": c_cat,
"c_crossattn": [positive_conditioning],
}
neutral_conditioning_d: dict[str, Any] = {
"c_concat": c_cat_neutral,
"c_crossattn": [neutral_conditioning],
}
del neutral_conditioning
del positive_conditioning
if control_strengths and is_controlnet_model:
positive_conditioning_d["control_strengths"] = torch.Tensor(
control_strengths
)
neutral_conditioning_d["control_strengths"] = torch.Tensor(
control_strengths
)
if (
prompt.allow_compose_phase
and not is_controlnet_model
and model.cond_stage_key != "edit"
):
default_size = get_model_default_image_size(
prompt.model_weights.architecture
)
if prompt.init_image:
comp_image = _generate_composition_image(
prompt=prompt,
target_height=init_image.height,
target_width=init_image.width,
cutoff=default_size,
)
else:
comp_image = _generate_composition_image(
prompt=prompt,
target_height=prompt.height,
target_width=prompt.width,
cutoff=default_size,
)
if comp_image is not None:
result_images["composition"] = comp_image
# noise = noise[:, :, : comp_image.height, : comp_image.shape[3]]
t_enc = int(prompt.steps * 0.65)
log_img(comp_image, "comp_image")
comp_image_t = pillow_img_to_torch_image(comp_image)
comp_image_t = comp_image_t.to(get_device())
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(comp_image_t)
)
with lc.timing("sampling"):
samples = solver.sample(
num_steps=prompt.steps,
positive_conditioning=positive_conditioning_d,
neutral_conditioning=neutral_conditioning_d,
guidance_scale=prompt.prompt_strength,
t_start=t_enc,
mask=mask_latent,
orig_latent=init_latent,
shape=shape,
batch_size=1,
denoiser_cls=denoiser_cls,
noise=noise,
)
if return_latent:
return samples
with lc.timing("decoding"):
gen_imgs_t = model.decode_first_stage(samples)
gen_img = torch_img_to_pillow_img(gen_imgs_t)
if mask_image_orig and init_image:
mask_final = mask_image_orig.copy()
log_img(mask_final, "reconstituting mask")
mask_final = ImageOps.invert(mask_final)
gen_img = Image.composite(gen_img, init_image, mask_final)
gen_img = combine_image(
original_img=init_image,
generated_img=gen_img,
mask_img=mask_image_orig,
)
log_img(gen_img, "reconstituted image")
upscaled_img = None
rebuilt_orig_img = None
if add_caption:
caption = generate_caption(gen_img)
logger.info(f"Generated caption: {caption}")
with lc.timing("safety-filter"):
safety_score = create_safety_score(
gen_img,
safety_mode=IMAGINAIRY_SAFETY_MODE,
)
if safety_score.is_filtered:
progress_latents.clear()
if not safety_score.is_filtered:
if prompt.fix_faces:
logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
with lc.timing("face enhancement"):
gen_img = enhance_faces(gen_img, fidelity=prompt.fix_faces_fidelity)
if prompt.upscale:
logger.info("Upscaling 🖼 using real-ESRGAN...")
with lc.timing("upscaling"):
upscaled_img = upscale_image(gen_img)
# put the newly generated patch back into the original, full-size image
if prompt.mask_modify_original and mask_image_orig and starting_image:
img_to_add_back_to_original = upscaled_img if upscaled_img else gen_img
rebuilt_orig_img = combine_image(
original_img=starting_image,
generated_img=img_to_add_back_to_original,
mask_img=mask_image_orig,
)
if prompt.caption_text:
caption_text = prompt.caption_text.format(prompt=prompt.prompt_text)
add_caption_to_image(gen_img, caption_text)
result_images["upscaled"] = upscaled_img
result_images["modified_original"] = rebuilt_orig_img
result_images["mask_binary"] = mask_image_orig
result_images["mask_grayscale"] = mask_grayscale
result = ImagineResult(
img=gen_img,
prompt=prompt,
is_nsfw=safety_score.is_nsfw,
safety_score=safety_score,
result_images=result_images,
timings=lc.get_timings(),
progress_latents=progress_latents.copy(),
)
_most_recent_result = result
logger.info(f"Image Generated. Timings: {result.timings_str()}")
return result
def _prompts_to_embeddings(prompts, model):
total_weight = sum(wp.weight for wp in prompts)
conditioning = sum(
model.get_learned_conditioning(wp.text) * (wp.weight / total_weight)
for wp in prompts
)
return conditioning
def calc_scale_to_fit_within(height: int, width: int, max_size) -> float:
max_width, max_height = normalize_image_size(max_size)
if width <= max_width and height <= max_height:
return 1
width_ratio = max_width / width
height_ratio = max_height / height
return min(width_ratio, height_ratio)
def _scale_latent(
latent,
model,
h,
w,
):
from torch.nn import functional as F
# convert to non-latent-space first
img = model.decode_first_stage(latent)
img = F.interpolate(img, size=(h, w), mode="bicubic", align_corners=False)
latent = model.get_first_stage_encoding(model.encode_first_stage(img))
return latent
def _generate_composition_image(
prompt,
target_height,
target_width,
cutoff: tuple[int, int] = (512, 512),
dtype=None,
):
from PIL import Image
from imaginairy.api.generate_refiners import _generate_single_image
from imaginairy.utils import default, get_default_dtype
cutoff = normalize_image_size(cutoff)
if prompt.width <= cutoff[0] and prompt.height <= cutoff[1]:
return None, None
dtype = default(dtype, get_default_dtype)
shrink_scale = calc_scale_to_fit_within(
height=prompt.height,
width=prompt.width,
max_size=cutoff,
)
composition_prompt = prompt.full_copy(
deep=True,
update={
"size": (
int(prompt.width * shrink_scale),
int(prompt.height * shrink_scale),
),
"steps": None,
"upscale": False,
"fix_faces": False,
"mask_modify_original": False,
"allow_compose_phase": False,
},
)
result = _generate_single_image(composition_prompt, dtype=dtype)
img = result.images["generated"]
while img.width < target_width:
from imaginairy.enhancers.upscale_realesrgan import upscale_image
img = upscale_image(img)
# samples = _generate_single_image(composition_prompt, return_latent=True)
# while samples.shape[-1] * 8 < target_width:
# samples = upscale_latent(samples)
#
# img = model_latent_to_pillow_img(samples)
img = img.resize(
(target_width, target_height),
resample=Image.Resampling.LANCZOS,
)
return img, result.images["generated"]
def prompt_normalized(prompt, length=130):
return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:length]
def combine_image(original_img, generated_img, mask_img):
"""Combine the generated image with the original image using the mask image."""
from PIL import Image
from imaginairy.utils.log_utils import log_img
generated_img = generated_img.resize(
original_img.size,
resample=Image.Resampling.LANCZOS,
)
mask_for_orig_size = mask_img.resize(
original_img.size,
resample=Image.Resampling.LANCZOS,
)
log_img(mask_for_orig_size, "mask for original image size")
rebuilt_orig_img = Image.composite(
original_img,
generated_img,
mask_for_orig_size,
)
log_img(rebuilt_orig_img, "reconstituted original")
return rebuilt_orig_img
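
For orientation, here is a minimal sketch of how the `imagine()` generator edited in the hunks above is typically driven. The prompt text, seed, and output filename are placeholders, and the import paths are taken from the modules shown in this commit.

# Hypothetical driver for the imagine() generator; values below are illustrative.
from imaginairy.api.generate import imagine
from imaginairy.schema import ImaginePrompt

prompts = [ImaginePrompt("a scenic mountain lake at dawn", seed=1)]
for result in imagine(prompts):
    # result.images holds the generated image plus any auxiliary images collected above
    result.images["generated"].save("mountain_lake.png")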

View File

@@ -0,0 +1,547 @@
from typing import Any
from imaginairy.api.generate import (
IMAGINAIRY_SAFETY_MODE,
logger,
)
from imaginairy.api.generate_refiners import _generate_composition_image
from imaginairy.schema import ImaginePrompt, LazyLoadingImage
from imaginairy.utils.img_utils import calc_scale_to_fit_within, combine_image
from imaginairy.utils.named_resolutions import normalize_image_size
def _generate_single_image_compvis(
prompt: "ImaginePrompt",
debug_img_callback=None,
progress_img_callback=None,
progress_img_interval_steps=3,
progress_img_interval_min_s=0.1,
half_mode=None,
add_caption=False,
# controlnet, finetune, naive, auto
inpaint_method="finetune",
return_latent=False,
):
import torch.nn
from PIL import Image, ImageOps
from pytorch_lightning import seed_everything
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.modules.midas.api import torch_image_to_depth_map
from imaginairy.samplers import SOLVER_LOOKUP
from imaginairy.samplers.editing import CFGEditingDenoiser
from imaginairy.schema import ControlInput, ImagineResult, MaskMode
from imaginairy.utils import get_device, randn_seeded
from imaginairy.utils.img_utils import (
add_caption_to_image,
pillow_fit_image_within,
pillow_img_to_torch_image,
pillow_mask_to_latent_mask,
torch_img_to_pillow_img,
)
from imaginairy.utils.log_utils import (
ImageLoggingContext,
log_conditioning,
log_img,
log_latent,
)
from imaginairy.utils.model_manager import (
get_diffusion_model,
get_model_default_image_size,
)
from imaginairy.utils.outpaint import (
outpaint_arg_str_parse,
prepare_image_for_outpaint,
)
from imaginairy.utils.safety import create_safety_score
latent_channels = 4
downsampling_factor = 8
batch_size = 1
global _most_recent_result
# handle prompt pulling in previous values
# if isinstance(prompt.init_image, str) and prompt.init_image.startswith("*prev"):
# _, img_type = prompt.init_image.strip("*").split(".")
# prompt.init_image = _most_recent_result.images[img_type]
# if isinstance(prompt.mask_image, str) and prompt.mask_image.startswith("*prev"):
# _, img_type = prompt.mask_image.strip("*").split(".")
# prompt.mask_image = _most_recent_result.images[img_type]
prompt = prompt.make_concrete_copy()
control_modes = []
control_inputs = prompt.control_inputs or []
control_inputs = control_inputs.copy()
for_inpainting = bool(prompt.mask_image or prompt.mask_prompt or prompt.outpaint)
if control_inputs:
control_modes = [c.mode for c in prompt.control_inputs]
if inpaint_method == "auto":
if prompt.model_weights in {"SD-1.5"}:
inpaint_method = "finetune"
else:
inpaint_method = "controlnet"
if for_inpainting and inpaint_method == "controlnet":
control_modes.append("inpaint")
model = get_diffusion_model(
weights_location=prompt.model_weights,
config_path=prompt.model_architecture,
control_weights_locations=control_modes,
half_mode=half_mode,
for_inpainting=for_inpainting and inpaint_method == "finetune",
)
is_controlnet_model = hasattr(model, "control_key")
progress_latents = []
def latent_logger(latents):
progress_latents.append(latents)
with ImageLoggingContext(
prompt=prompt,
model=model,
debug_img_callback=debug_img_callback,
progress_img_callback=progress_img_callback,
progress_img_interval_steps=progress_img_interval_steps,
progress_img_interval_min_s=progress_img_interval_min_s,
progress_latent_callback=latent_logger
if prompt.collect_progress_latents
else None,
) as lc:
seed_everything(prompt.seed)
model.tile_mode(prompt.tile_mode)
with lc.timing("conditioning"):
# need to expand if doing batches
neutral_conditioning = _prompts_to_embeddings(prompt.negative_prompt, model)
_prompts_to_embeddings("", model)
log_conditioning(neutral_conditioning, "neutral conditioning")
if prompt.conditioning is not None:
positive_conditioning = prompt.conditioning
else:
positive_conditioning = _prompts_to_embeddings(prompt.prompts, model)
log_conditioning(positive_conditioning, "positive conditioning")
shape = [
batch_size,
latent_channels,
prompt.height // downsampling_factor,
prompt.width // downsampling_factor,
]
SolverCls = SOLVER_LOOKUP[prompt.solver_type.lower()]
solver = SolverCls(model)
mask_image: Image.Image | LazyLoadingImage | None = None
mask_latent = mask_image_orig = mask_grayscale = None
init_latent: torch.Tensor | None = None
t_enc = None
starting_image = None
denoiser_cls = None
c_cat = []
c_cat_neutral = None
result_images: dict[str, torch.Tensor | Image.Image | None] = {}
assert prompt.seed is not None
seed_everything(prompt.seed)
noise = randn_seeded(seed=prompt.seed, size=shape).to(get_device())
control_strengths = []
if prompt.init_image:
starting_image = prompt.init_image
assert prompt.init_image_strength is not None
generation_strength = 1 - prompt.init_image_strength
if model.cond_stage_key == "edit" or generation_strength >= 1:
t_enc = None
else:
t_enc = int(prompt.steps * generation_strength)
if prompt.mask_prompt:
mask_image, mask_grayscale = get_img_mask(
starting_image, prompt.mask_prompt, threshold=0.1
)
elif prompt.mask_image:
mask_image = prompt.mask_image.convert("L")
if prompt.outpaint:
outpaint_kwargs = outpaint_arg_str_parse(prompt.outpaint)
starting_image, mask_image = prepare_image_for_outpaint(
starting_image, mask_image, **outpaint_kwargs
)
assert starting_image is not None
init_image = pillow_fit_image_within(
starting_image,
max_height=prompt.height,
max_width=prompt.width,
)
init_image_t = pillow_img_to_torch_image(init_image).to(get_device())
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(init_image_t)
)
assert init_latent is not None
shape = list(init_latent.shape)
log_latent(init_latent, "init_latent")
if mask_image is not None:
mask_image = pillow_fit_image_within(
mask_image,
max_height=prompt.height,
max_width=prompt.width,
convert="L",
)
log_img(mask_image, "init mask")
if prompt.mask_mode == MaskMode.REPLACE:
mask_image = ImageOps.invert(mask_image)
mask_image_orig = mask_image
log_img(mask_image, "latent_mask")
mask_latent = pillow_mask_to_latent_mask(
mask_image, downsampling_factor=downsampling_factor
).to(get_device())
if inpaint_method == "controlnet":
result_images["control-inpaint"] = mask_image
control_inputs.append(
ControlInput(mode="inpaint", image=mask_image)
)
assert prompt.seed is not None
seed_everything(prompt.seed)
noise = randn_seeded(seed=prompt.seed, size=list(init_latent.shape)).to(
get_device()
)
# noise = noise[:, :, : init_latent.shape[2], : init_latent.shape[3]]
# schedule = NoiseSchedule(
# model_num_timesteps=model.num_timesteps,
# ddim_num_steps=prompt.steps,
# model_alphas_cumprod=model.alphas_cumprod,
# ddim_discretize="uniform",
# )
# if generation_strength >= 1:
# # prompt strength gets converted to time encodings,
# # which means you can't get to true 0 without this hack
# # (or setting steps=1000)
# init_latent_noised = noise
# else:
# init_latent_noised = noise_an_image(
# init_latent,
# torch.tensor([t_enc - 1]).to(get_device()),
# schedule=schedule,
# noise=noise,
# )
if hasattr(model, "depth_stage_key"):
# depth model
depth_t = torch_image_to_depth_map(init_image_t)
depth_latent = torch.nn.functional.interpolate(
depth_t,
size=shape[2:],
mode="bicubic",
align_corners=False,
)
result_images["depth_image"] = depth_t
c_cat.append(depth_latent)
elif is_controlnet_model:
from imaginairy.img_processors.control_modes import CONTROL_MODES
for control_input in control_inputs:
if control_input.image_raw is not None:
control_image = control_input.image_raw
elif control_input.image is not None:
control_image = control_input.image
else:
raise RuntimeError("Control image must be provided")
assert control_image is not None
control_image = control_image.convert("RGB")
log_img(control_image, "control_image_input")
assert control_image is not None
control_image_input = pillow_fit_image_within(
control_image,
max_height=prompt.height,
max_width=prompt.width,
)
control_image_input_t = pillow_img_to_torch_image(control_image_input)
control_image_input_t = control_image_input_t.to(get_device())
if control_input.image_raw is None:
control_prep_function = CONTROL_MODES[control_input.mode]
if control_input.mode == "inpaint":
control_image_t = control_prep_function( # type: ignore
control_image_input_t, init_image_t
)
else:
control_image_t = control_prep_function(control_image_input_t) # type: ignore
else:
control_image_t = (control_image_input_t + 1) / 2
control_image_disp = control_image_t * 2 - 1
result_images[f"control-{control_input.mode}"] = control_image_disp
log_img(control_image_disp, "control_image")
if len(control_image_t.shape) == 3:
raise RuntimeError("Control image must be 4D")
if control_image_t.shape[1] != 3:
raise RuntimeError("Control image must have 3 channels")
if (
control_input.mode != "inpaint"
and control_image_t.min() < 0
or control_image_t.max() > 1
):
msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
raise RuntimeError(msg)
if control_image_t.max() == control_image_t.min():
msg = f"No control signal found in control image {control_input.mode}."
raise RuntimeError(msg)
c_cat.append(control_image_t)
control_strengths.append(control_input.strength)
elif hasattr(model, "masked_image_key"):
# inpainting model
assert mask_image_orig is not None
assert mask_latent is not None
mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to(
get_device()
)
inverted_mask = 1 - mask_latent
masked_image_t = init_image_t * (mask_t < 0.5)
log_img(masked_image_t, "masked_image")
inverted_mask_latent = torch.nn.functional.interpolate(
inverted_mask, size=shape[-2:]
)
c_cat.append(inverted_mask_latent)
masked_image_latent = model.get_first_stage_encoding(
model.encode_first_stage(masked_image_t)
)
c_cat.append(masked_image_latent)
elif model.cond_stage_key == "edit":
# pix2pix model
c_cat = [model.encode_first_stage(init_image_t)]
assert init_latent is not None
c_cat_neutral = [torch.zeros_like(init_latent)]
denoiser_cls = CFGEditingDenoiser
if c_cat:
c_cat = [torch.cat([c], dim=1) for c in c_cat]
if c_cat_neutral is None:
c_cat_neutral = c_cat
positive_conditioning_d: dict[str, Any] = {
"c_concat": c_cat,
"c_crossattn": [positive_conditioning],
}
neutral_conditioning_d: dict[str, Any] = {
"c_concat": c_cat_neutral,
"c_crossattn": [neutral_conditioning],
}
del neutral_conditioning
del positive_conditioning
if control_strengths and is_controlnet_model:
positive_conditioning_d["control_strengths"] = torch.Tensor(
control_strengths
)
neutral_conditioning_d["control_strengths"] = torch.Tensor(
control_strengths
)
if (
prompt.allow_compose_phase
and not is_controlnet_model
and model.cond_stage_key != "edit"
):
default_size = get_model_default_image_size(
prompt.model_weights.architecture
)
if prompt.init_image:
comp_image = _generate_composition_image(
prompt=prompt,
target_height=init_image.height,
target_width=init_image.width,
cutoff=default_size,
)
else:
comp_image = _generate_composition_image(
prompt=prompt,
target_height=prompt.height,
target_width=prompt.width,
cutoff=default_size,
)
if comp_image is not None:
result_images["composition"] = comp_image
# noise = noise[:, :, : comp_image.height, : comp_image.shape[3]]
t_enc = int(prompt.steps * 0.65)
log_img(comp_image, "comp_image")
comp_image_t = pillow_img_to_torch_image(comp_image)
comp_image_t = comp_image_t.to(get_device())
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(comp_image_t)
)
with lc.timing("sampling"):
samples = solver.sample(
num_steps=prompt.steps,
positive_conditioning=positive_conditioning_d,
neutral_conditioning=neutral_conditioning_d,
guidance_scale=prompt.prompt_strength,
t_start=t_enc,
mask=mask_latent,
orig_latent=init_latent,
shape=shape,
batch_size=1,
denoiser_cls=denoiser_cls,
noise=noise,
)
if return_latent:
return samples
with lc.timing("decoding"):
gen_imgs_t = model.decode_first_stage(samples)
gen_img = torch_img_to_pillow_img(gen_imgs_t)
if mask_image_orig and init_image:
mask_final = mask_image_orig.copy()
log_img(mask_final, "reconstituting mask")
mask_final = ImageOps.invert(mask_final)
gen_img = Image.composite(gen_img, init_image, mask_final)
gen_img = combine_image(
original_img=init_image,
generated_img=gen_img,
mask_img=mask_image_orig,
)
log_img(gen_img, "reconstituted image")
upscaled_img = None
rebuilt_orig_img = None
if add_caption:
caption = generate_caption(gen_img)
logger.info(f"Generated caption: {caption}")
with lc.timing("safety-filter"):
safety_score = create_safety_score(
gen_img,
safety_mode=IMAGINAIRY_SAFETY_MODE,
)
if safety_score.is_filtered:
progress_latents.clear()
if not safety_score.is_filtered:
if prompt.fix_faces:
logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
with lc.timing("face enhancement"):
gen_img = enhance_faces(gen_img, fidelity=prompt.fix_faces_fidelity)
if prompt.upscale:
logger.info("Upscaling 🖼 using real-ESRGAN...")
with lc.timing("upscaling"):
upscaled_img = upscale_image(gen_img)
# put the newly generated patch back into the original, full-size image
if prompt.mask_modify_original and mask_image_orig and starting_image:
img_to_add_back_to_original = upscaled_img if upscaled_img else gen_img
rebuilt_orig_img = combine_image(
original_img=starting_image,
generated_img=img_to_add_back_to_original,
mask_img=mask_image_orig,
)
if prompt.caption_text:
caption_text = prompt.caption_text.format(prompt=prompt.prompt_text)
add_caption_to_image(gen_img, caption_text)
result_images["upscaled"] = upscaled_img
result_images["modified_original"] = rebuilt_orig_img
result_images["mask_binary"] = mask_image_orig
result_images["mask_grayscale"] = mask_grayscale
result = ImagineResult(
img=gen_img,
prompt=prompt,
is_nsfw=safety_score.is_nsfw,
safety_score=safety_score,
result_images=result_images,
timings=lc.get_timings(),
progress_latents=progress_latents.copy(),
)
_most_recent_result = result
logger.info(f"Image Generated. Timings: {result.timings_str()}")
return result
def _prompts_to_embeddings(prompts, model):
total_weight = sum(wp.weight for wp in prompts)
conditioning = sum(
model.get_learned_conditioning(wp.text) * (wp.weight / total_weight)
for wp in prompts
)
return conditioning
def _generate_composition_image(
prompt,
target_height,
target_width,
cutoff: tuple[int, int] = (512, 512),
dtype=None,
):
from PIL import Image
from imaginairy.api.generate_refiners import generate_single_image
from imaginairy.utils import default, get_default_dtype
cutoff = normalize_image_size(cutoff)
if prompt.width <= cutoff[0] and prompt.height <= cutoff[1]:
return None, None
dtype = default(dtype, get_default_dtype)
shrink_scale = calc_scale_to_fit_within(
height=prompt.height,
width=prompt.width,
max_size=cutoff,
)
composition_prompt = prompt.full_copy(
deep=True,
update={
"size": (
int(prompt.width * shrink_scale),
int(prompt.height * shrink_scale),
),
"steps": None,
"upscale": False,
"fix_faces": False,
"mask_modify_original": False,
"allow_compose_phase": False,
},
)
result = generate_single_image(composition_prompt, dtype=dtype)
img = result.images["generated"]
while img.width < target_width:
from imaginairy.enhancers.upscale_realesrgan import upscale_image
img = upscale_image(img)
# samples = generate_single_image(composition_prompt, return_latent=True)
# while samples.shape[-1] * 8 < target_width:
# samples = upscale_latent(samples)
#
# img = model_latent_to_pillow_img(samples)
img = img.resize(
(target_width, target_height),
resample=Image.Resampling.LANCZOS,
)
return img, result.images["generated"]
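
A note on the weighted-prompt blend in `_prompts_to_embeddings` above: each prompt's conditioning is scaled by its weight divided by the total weight, so the result is a weighted average of the embeddings. A tiny numeric sketch, with made-up scalars standing in for the conditioning tensors:

# Scalars stand in for conditioning tensors; prompts, weights, and values are illustrative.
weighted_prompts = [("a forest", 2.0), ("at night", 1.0)]
fake_conditioning = {"a forest": 0.9, "at night": 0.3}

total_weight = sum(weight for _, weight in weighted_prompts)  # 3.0
blended = sum(
    fake_conditioning[text] * (weight / total_weight)
    for text, weight in weighted_prompts
)
print(blended)  # 0.9 * (2/3) + 0.3 * (1/3) = 0.7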

View File

@@ -5,11 +5,13 @@ from typing import List, Optional
from imaginairy.config import CONTROL_CONFIG_SHORTCUTS
from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode, WeightedPrompt
+from imaginairy.utils.img_utils import calc_scale_to_fit_within
+from imaginairy.utils.named_resolutions import normalize_image_size
logger = logging.getLogger(__name__)
-def _generate_single_image(
+def generate_single_image(
prompt: ImaginePrompt,
debug_img_callback=None,
progress_img_callback=None,
@@ -28,8 +30,6 @@ def _generate_single_image(
from imaginairy.api.generate import (
IMAGINAIRY_SAFETY_MODE,
-_generate_composition_image,
-combine_image,
)
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption
@@ -40,6 +40,7 @@ def _generate_single_image(
from imaginairy.utils import get_device, randn_seeded
from imaginairy.utils.img_utils import (
add_caption_to_image,
+combine_image,
pillow_fit_image_within,
pillow_img_to_torch_image,
pillow_mask_to_latent_mask,
@@ -523,3 +524,64 @@ def prep_control_input(
)
controlnet.set_scale(control_input.strength)
return controlnet, control_image_t, control_image_disp
def _generate_composition_image(
prompt,
target_height,
target_width,
cutoff: tuple[int, int] = (512, 512),
dtype=None,
):
from PIL import Image
from imaginairy.api.generate_refiners import generate_single_image
from imaginairy.utils import default, get_default_dtype
cutoff = normalize_image_size(cutoff)
if prompt.width <= cutoff[0] and prompt.height <= cutoff[1]:
return None, None
dtype = default(dtype, get_default_dtype)
shrink_scale = calc_scale_to_fit_within(
height=prompt.height,
width=prompt.width,
max_size=cutoff,
)
composition_prompt = prompt.full_copy(
deep=True,
update={
"size": (
int(prompt.width * shrink_scale),
int(prompt.height * shrink_scale),
),
"steps": None,
"upscale": False,
"fix_faces": False,
"mask_modify_original": False,
"allow_compose_phase": False,
"caption_text": None,
},
)
result = generate_single_image(composition_prompt, dtype=dtype)
img = result.images["generated"]
while img.width < target_width:
from imaginairy.enhancers.upscale_realesrgan import upscale_image
img = upscale_image(img)
# samples = generate_single_image(composition_prompt, return_latent=True)
# while samples.shape[-1] * 8 < target_width:
# samples = upscale_latent(samples)
#
# img = model_latent_to_pillow_img(samples)
img = img.resize(
(target_width, target_height),
resample=Image.Resampling.LANCZOS,
)
return img, result.images["generated"]
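
To make the shrink step in `_generate_composition_image` concrete, here is a small sketch of the arithmetic under assumed numbers (a 1536x1024 prompt against the default (512, 512) cutoff). The composition render is then upscaled until it reaches the target width and finally Lanczos-resized to the exact target size, as in the function body above.

# Illustrative numbers only: a 1536x1024 prompt and the default (512, 512) cutoff.
prompt_width, prompt_height = 1536, 1024
max_width, max_height = 512, 512

shrink_scale = min(max_width / prompt_width, max_height / prompt_height)  # min(1/3, 1/2) = 1/3
comp_size = (int(prompt_width * shrink_scale), int(prompt_height * shrink_scale))
print(comp_size)  # (512, 341): the composition pass renders within the cutoff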

View File

@@ -1,6 +1,7 @@
import importlib
import logging
import platform
+import re
import time
from contextlib import contextmanager, nullcontext
from functools import lru_cache
@@ -315,3 +316,7 @@ def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
current_attribute = getattr(current_attribute, attribute)
return (current_attribute, current_key) if return_key else current_attribute
def prompt_normalized(prompt, length=130):
return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:length]
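
For context, the relocated `prompt_normalized` helper collapses every run of characters outside letters, digits, period, comma, square brackets, and hyphen into a single underscore, then truncates to 130 characters. A small illustrative call (the input string is made up):

import re

def prompt_normalized(prompt, length=130):
    return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:length]

print(prompt_normalized("a cozy cabin, snow falling!"))  # a_cozy_cabin,_snow_falling_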

View File

@@ -19,6 +19,7 @@ from PIL import Image, ImageDraw, ImageFont
from imaginairy.schema import LazyLoadingImage
from imaginairy.utils import get_device
+from imaginairy.utils.named_resolutions import normalize_image_size
from imaginairy.utils.paths import PKG_ROOT
@@ -221,3 +222,39 @@ def create_halo_effect(
new_canvas.paste(transparent_image, (0, 0), transparent_image)
return new_canvas
def combine_image(original_img, generated_img, mask_img):
"""Combine the generated image with the original image using the mask image."""
from PIL import Image
from imaginairy.utils.log_utils import log_img
generated_img = generated_img.resize(
original_img.size,
resample=Image.Resampling.LANCZOS,
)
mask_for_orig_size = mask_img.resize(
original_img.size,
resample=Image.Resampling.LANCZOS,
)
log_img(mask_for_orig_size, "mask for original image size")
rebuilt_orig_img = Image.composite(
original_img,
generated_img,
mask_for_orig_size,
)
return rebuilt_orig_img
def calc_scale_to_fit_within(height: int, width: int, max_size) -> float:
max_width, max_height = normalize_image_size(max_size)
if width <= max_width and height <= max_height:
return 1
width_ratio = max_width / width
height_ratio = max_height / height
return min(width_ratio, height_ratio)
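
A usage sketch of the relocated `combine_image` helper; the file paths are placeholders. Note that PIL's `Image.composite(original, generated, mask)` keeps the original where the mask is white and shows the generated image where it is black.

# File paths are placeholders; assumes a grayscale mask roughly matching the original image.
from PIL import Image
from imaginairy.utils.img_utils import combine_image

original = Image.open("photo.jpg")
generated = Image.open("generated_patch.png")
mask = Image.open("mask.png").convert("L")  # white = keep original pixels, black = show generated pixels

final = combine_image(original_img=original, generated_img=generated, mask_img=mask)
final.save("photo_with_patch.png")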

View File

@@ -2,7 +2,7 @@
import contextlib
-_NAMED_RESOLUTIONS = {
+NAMED_RESOLUTIONS = {
"HD": (1280, 720),
"FHD": (1920, 1080),
"HALF-FHD": (960, 540),
@@ -46,7 +46,7 @@ _NAMED_RESOLUTIONS = {
"SVD": (1024, 576), # stable video diffusion
}
-_NAMED_RESOLUTIONS = {k.upper(): v for k, v in _NAMED_RESOLUTIONS.items()}
+NAMED_RESOLUTIONS = {k.upper(): v for k, v in NAMED_RESOLUTIONS.items()}
def normalize_image_size(resolution: str | int | tuple[int, int]) -> tuple[int, int]:
@@ -66,8 +66,8 @@ def _normalize_image_size(resolution: str | int | tuple[int, int]) -> tuple[int, int]:
case str():
resolution = resolution.strip().upper()
resolution = resolution.replace(" ", "").replace("X", ",").replace("*", ",")
-if resolution.upper() in _NAMED_RESOLUTIONS:
+if resolution.upper() in NAMED_RESOLUTIONS:
-return _NAMED_RESOLUTIONS[resolution.upper()]
+return NAMED_RESOLUTIONS[resolution.upper()]
# is it WIDTH,HEIGHT format?
try: