mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-11-19 03:25:41 +00:00
refactor: move code to more intuitive places
This commit is contained in:
parent
8cfb46d6de
commit
6ebd12abb1
@ -2,13 +2,9 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
from typing import Callable
|
||||||
from typing import TYPE_CHECKING, Any, Callable
|
|
||||||
|
|
||||||
from imaginairy.utils.named_resolutions import normalize_image_size
|
from imaginairy.utils import prompt_normalized
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from imaginairy.schema import ImaginePrompt, LazyLoadingImage
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -160,7 +156,7 @@ def imagine(
|
|||||||
):
|
):
|
||||||
import torch.nn
|
import torch.nn
|
||||||
|
|
||||||
from imaginairy.api.generate_refiners import _generate_single_image
|
from imaginairy.api.generate_refiners import generate_single_image
|
||||||
from imaginairy.schema import ImaginePrompt
|
from imaginairy.schema import ImaginePrompt
|
||||||
from imaginairy.utils import (
|
from imaginairy.utils import (
|
||||||
check_torch_version,
|
check_torch_version,
|
||||||
@ -199,7 +195,7 @@ def imagine(
|
|||||||
for attempt in range(unsafe_retry_count + 1):
|
for attempt in range(unsafe_retry_count + 1):
|
||||||
if attempt > 0 and isinstance(prompt.seed, int):
|
if attempt > 0 and isinstance(prompt.seed, int):
|
||||||
prompt.seed += 100_000_000 + attempt
|
prompt.seed += 100_000_000 + attempt
|
||||||
result = _generate_single_image(
|
result = generate_single_image(
|
||||||
prompt,
|
prompt,
|
||||||
debug_img_callback=debug_img_callback,
|
debug_img_callback=debug_img_callback,
|
||||||
progress_img_callback=progress_img_callback,
|
progress_img_callback=progress_img_callback,
|
||||||
@ -215,596 +211,3 @@ def imagine(
|
|||||||
logger.info(" Image was unsafe, retrying with new seed...")
|
logger.info(" Image was unsafe, retrying with new seed...")
|
||||||
|
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
|
|
||||||
def _generate_single_image_compvis(
|
|
||||||
prompt: "ImaginePrompt",
|
|
||||||
debug_img_callback=None,
|
|
||||||
progress_img_callback=None,
|
|
||||||
progress_img_interval_steps=3,
|
|
||||||
progress_img_interval_min_s=0.1,
|
|
||||||
half_mode=None,
|
|
||||||
add_caption=False,
|
|
||||||
# controlnet, finetune, naive, auto
|
|
||||||
inpaint_method="finetune",
|
|
||||||
return_latent=False,
|
|
||||||
):
|
|
||||||
import torch.nn
|
|
||||||
from PIL import Image, ImageOps
|
|
||||||
from pytorch_lightning import seed_everything
|
|
||||||
|
|
||||||
from imaginairy.enhancers.clip_masking import get_img_mask
|
|
||||||
from imaginairy.enhancers.describe_image_blip import generate_caption
|
|
||||||
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
|
|
||||||
from imaginairy.enhancers.upscale_realesrgan import upscale_image
|
|
||||||
from imaginairy.modules.midas.api import torch_image_to_depth_map
|
|
||||||
from imaginairy.samplers import SOLVER_LOOKUP
|
|
||||||
from imaginairy.samplers.editing import CFGEditingDenoiser
|
|
||||||
from imaginairy.schema import ControlInput, ImagineResult, MaskMode
|
|
||||||
from imaginairy.utils import get_device, randn_seeded
|
|
||||||
from imaginairy.utils.img_utils import (
|
|
||||||
add_caption_to_image,
|
|
||||||
pillow_fit_image_within,
|
|
||||||
pillow_img_to_torch_image,
|
|
||||||
pillow_mask_to_latent_mask,
|
|
||||||
torch_img_to_pillow_img,
|
|
||||||
)
|
|
||||||
from imaginairy.utils.log_utils import (
|
|
||||||
ImageLoggingContext,
|
|
||||||
log_conditioning,
|
|
||||||
log_img,
|
|
||||||
log_latent,
|
|
||||||
)
|
|
||||||
from imaginairy.utils.model_manager import (
|
|
||||||
get_diffusion_model,
|
|
||||||
get_model_default_image_size,
|
|
||||||
)
|
|
||||||
from imaginairy.utils.outpaint import (
|
|
||||||
outpaint_arg_str_parse,
|
|
||||||
prepare_image_for_outpaint,
|
|
||||||
)
|
|
||||||
from imaginairy.utils.safety import create_safety_score
|
|
||||||
|
|
||||||
latent_channels = 4
|
|
||||||
downsampling_factor = 8
|
|
||||||
batch_size = 1
|
|
||||||
global _most_recent_result
|
|
||||||
# handle prompt pulling in previous values
|
|
||||||
# if isinstance(prompt.init_image, str) and prompt.init_image.startswith("*prev"):
|
|
||||||
# _, img_type = prompt.init_image.strip("*").split(".")
|
|
||||||
# prompt.init_image = _most_recent_result.images[img_type]
|
|
||||||
# if isinstance(prompt.mask_image, str) and prompt.mask_image.startswith("*prev"):
|
|
||||||
# _, img_type = prompt.mask_image.strip("*").split(".")
|
|
||||||
# prompt.mask_image = _most_recent_result.images[img_type]
|
|
||||||
prompt = prompt.make_concrete_copy()
|
|
||||||
|
|
||||||
control_modes = []
|
|
||||||
control_inputs = prompt.control_inputs or []
|
|
||||||
control_inputs = control_inputs.copy()
|
|
||||||
for_inpainting = bool(prompt.mask_image or prompt.mask_prompt or prompt.outpaint)
|
|
||||||
|
|
||||||
if control_inputs:
|
|
||||||
control_modes = [c.mode for c in prompt.control_inputs]
|
|
||||||
if inpaint_method == "auto":
|
|
||||||
if prompt.model_weights in {"SD-1.5"}:
|
|
||||||
inpaint_method = "finetune"
|
|
||||||
else:
|
|
||||||
inpaint_method = "controlnet"
|
|
||||||
|
|
||||||
if for_inpainting and inpaint_method == "controlnet":
|
|
||||||
control_modes.append("inpaint")
|
|
||||||
model = get_diffusion_model(
|
|
||||||
weights_location=prompt.model_weights,
|
|
||||||
config_path=prompt.model_architecture,
|
|
||||||
control_weights_locations=control_modes,
|
|
||||||
half_mode=half_mode,
|
|
||||||
for_inpainting=for_inpainting and inpaint_method == "finetune",
|
|
||||||
)
|
|
||||||
is_controlnet_model = hasattr(model, "control_key")
|
|
||||||
|
|
||||||
progress_latents = []
|
|
||||||
|
|
||||||
def latent_logger(latents):
|
|
||||||
progress_latents.append(latents)
|
|
||||||
|
|
||||||
with ImageLoggingContext(
|
|
||||||
prompt=prompt,
|
|
||||||
model=model,
|
|
||||||
debug_img_callback=debug_img_callback,
|
|
||||||
progress_img_callback=progress_img_callback,
|
|
||||||
progress_img_interval_steps=progress_img_interval_steps,
|
|
||||||
progress_img_interval_min_s=progress_img_interval_min_s,
|
|
||||||
progress_latent_callback=latent_logger
|
|
||||||
if prompt.collect_progress_latents
|
|
||||||
else None,
|
|
||||||
) as lc:
|
|
||||||
seed_everything(prompt.seed)
|
|
||||||
|
|
||||||
model.tile_mode(prompt.tile_mode)
|
|
||||||
with lc.timing("conditioning"):
|
|
||||||
# need to expand if doing batches
|
|
||||||
neutral_conditioning = _prompts_to_embeddings(prompt.negative_prompt, model)
|
|
||||||
_prompts_to_embeddings("", model)
|
|
||||||
log_conditioning(neutral_conditioning, "neutral conditioning")
|
|
||||||
if prompt.conditioning is not None:
|
|
||||||
positive_conditioning = prompt.conditioning
|
|
||||||
else:
|
|
||||||
positive_conditioning = _prompts_to_embeddings(prompt.prompts, model)
|
|
||||||
log_conditioning(positive_conditioning, "positive conditioning")
|
|
||||||
|
|
||||||
shape = [
|
|
||||||
batch_size,
|
|
||||||
latent_channels,
|
|
||||||
prompt.height // downsampling_factor,
|
|
||||||
prompt.width // downsampling_factor,
|
|
||||||
]
|
|
||||||
SolverCls = SOLVER_LOOKUP[prompt.solver_type.lower()]
|
|
||||||
solver = SolverCls(model)
|
|
||||||
mask_image: Image.Image | LazyLoadingImage | None = None
|
|
||||||
mask_latent = mask_image_orig = mask_grayscale = None
|
|
||||||
init_latent: torch.Tensor | None = None
|
|
||||||
t_enc = None
|
|
||||||
starting_image = None
|
|
||||||
denoiser_cls = None
|
|
||||||
|
|
||||||
c_cat = []
|
|
||||||
c_cat_neutral = None
|
|
||||||
result_images: dict[str, torch.Tensor | Image.Image | None] = {}
|
|
||||||
assert prompt.seed is not None
|
|
||||||
seed_everything(prompt.seed)
|
|
||||||
noise = randn_seeded(seed=prompt.seed, size=shape).to(get_device())
|
|
||||||
control_strengths = []
|
|
||||||
|
|
||||||
if prompt.init_image:
|
|
||||||
starting_image = prompt.init_image
|
|
||||||
assert prompt.init_image_strength is not None
|
|
||||||
generation_strength = 1 - prompt.init_image_strength
|
|
||||||
|
|
||||||
if model.cond_stage_key == "edit" or generation_strength >= 1:
|
|
||||||
t_enc = None
|
|
||||||
else:
|
|
||||||
t_enc = int(prompt.steps * generation_strength)
|
|
||||||
|
|
||||||
if prompt.mask_prompt:
|
|
||||||
mask_image, mask_grayscale = get_img_mask(
|
|
||||||
starting_image, prompt.mask_prompt, threshold=0.1
|
|
||||||
)
|
|
||||||
elif prompt.mask_image:
|
|
||||||
mask_image = prompt.mask_image.convert("L")
|
|
||||||
if prompt.outpaint:
|
|
||||||
outpaint_kwargs = outpaint_arg_str_parse(prompt.outpaint)
|
|
||||||
starting_image, mask_image = prepare_image_for_outpaint(
|
|
||||||
starting_image, mask_image, **outpaint_kwargs
|
|
||||||
)
|
|
||||||
assert starting_image is not None
|
|
||||||
init_image = pillow_fit_image_within(
|
|
||||||
starting_image,
|
|
||||||
max_height=prompt.height,
|
|
||||||
max_width=prompt.width,
|
|
||||||
)
|
|
||||||
init_image_t = pillow_img_to_torch_image(init_image).to(get_device())
|
|
||||||
init_latent = model.get_first_stage_encoding(
|
|
||||||
model.encode_first_stage(init_image_t)
|
|
||||||
)
|
|
||||||
assert init_latent is not None
|
|
||||||
shape = list(init_latent.shape)
|
|
||||||
|
|
||||||
log_latent(init_latent, "init_latent")
|
|
||||||
|
|
||||||
if mask_image is not None:
|
|
||||||
mask_image = pillow_fit_image_within(
|
|
||||||
mask_image,
|
|
||||||
max_height=prompt.height,
|
|
||||||
max_width=prompt.width,
|
|
||||||
convert="L",
|
|
||||||
)
|
|
||||||
|
|
||||||
log_img(mask_image, "init mask")
|
|
||||||
|
|
||||||
if prompt.mask_mode == MaskMode.REPLACE:
|
|
||||||
mask_image = ImageOps.invert(mask_image)
|
|
||||||
|
|
||||||
mask_image_orig = mask_image
|
|
||||||
log_img(mask_image, "latent_mask")
|
|
||||||
mask_latent = pillow_mask_to_latent_mask(
|
|
||||||
mask_image, downsampling_factor=downsampling_factor
|
|
||||||
).to(get_device())
|
|
||||||
if inpaint_method == "controlnet":
|
|
||||||
result_images["control-inpaint"] = mask_image
|
|
||||||
control_inputs.append(
|
|
||||||
ControlInput(mode="inpaint", image=mask_image)
|
|
||||||
)
|
|
||||||
assert prompt.seed is not None
|
|
||||||
seed_everything(prompt.seed)
|
|
||||||
noise = randn_seeded(seed=prompt.seed, size=list(init_latent.shape)).to(
|
|
||||||
get_device()
|
|
||||||
)
|
|
||||||
# noise = noise[:, :, : init_latent.shape[2], : init_latent.shape[3]]
|
|
||||||
|
|
||||||
# schedule = NoiseSchedule(
|
|
||||||
# model_num_timesteps=model.num_timesteps,
|
|
||||||
# ddim_num_steps=prompt.steps,
|
|
||||||
# model_alphas_cumprod=model.alphas_cumprod,
|
|
||||||
# ddim_discretize="uniform",
|
|
||||||
# )
|
|
||||||
# if generation_strength >= 1:
|
|
||||||
# # prompt strength gets converted to time encodings,
|
|
||||||
# # which means you can't get to true 0 without this hack
|
|
||||||
# # (or setting steps=1000)
|
|
||||||
# init_latent_noised = noise
|
|
||||||
# else:
|
|
||||||
# init_latent_noised = noise_an_image(
|
|
||||||
# init_latent,
|
|
||||||
# torch.tensor([t_enc - 1]).to(get_device()),
|
|
||||||
# schedule=schedule,
|
|
||||||
# noise=noise,
|
|
||||||
# )
|
|
||||||
|
|
||||||
if hasattr(model, "depth_stage_key"):
|
|
||||||
# depth model
|
|
||||||
depth_t = torch_image_to_depth_map(init_image_t)
|
|
||||||
depth_latent = torch.nn.functional.interpolate(
|
|
||||||
depth_t,
|
|
||||||
size=shape[2:],
|
|
||||||
mode="bicubic",
|
|
||||||
align_corners=False,
|
|
||||||
)
|
|
||||||
result_images["depth_image"] = depth_t
|
|
||||||
c_cat.append(depth_latent)
|
|
||||||
|
|
||||||
elif is_controlnet_model:
|
|
||||||
from imaginairy.img_processors.control_modes import CONTROL_MODES
|
|
||||||
|
|
||||||
for control_input in control_inputs:
|
|
||||||
if control_input.image_raw is not None:
|
|
||||||
control_image = control_input.image_raw
|
|
||||||
elif control_input.image is not None:
|
|
||||||
control_image = control_input.image
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Control image must be provided")
|
|
||||||
assert control_image is not None
|
|
||||||
control_image = control_image.convert("RGB")
|
|
||||||
log_img(control_image, "control_image_input")
|
|
||||||
assert control_image is not None
|
|
||||||
|
|
||||||
control_image_input = pillow_fit_image_within(
|
|
||||||
control_image,
|
|
||||||
max_height=prompt.height,
|
|
||||||
max_width=prompt.width,
|
|
||||||
)
|
|
||||||
control_image_input_t = pillow_img_to_torch_image(control_image_input)
|
|
||||||
control_image_input_t = control_image_input_t.to(get_device())
|
|
||||||
|
|
||||||
if control_input.image_raw is None:
|
|
||||||
control_prep_function = CONTROL_MODES[control_input.mode]
|
|
||||||
if control_input.mode == "inpaint":
|
|
||||||
control_image_t = control_prep_function( # type: ignore
|
|
||||||
control_image_input_t, init_image_t
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
control_image_t = control_prep_function(control_image_input_t) # type: ignore
|
|
||||||
else:
|
|
||||||
control_image_t = (control_image_input_t + 1) / 2
|
|
||||||
|
|
||||||
control_image_disp = control_image_t * 2 - 1
|
|
||||||
result_images[f"control-{control_input.mode}"] = control_image_disp
|
|
||||||
log_img(control_image_disp, "control_image")
|
|
||||||
|
|
||||||
if len(control_image_t.shape) == 3:
|
|
||||||
raise RuntimeError("Control image must be 4D")
|
|
||||||
|
|
||||||
if control_image_t.shape[1] != 3:
|
|
||||||
raise RuntimeError("Control image must have 3 channels")
|
|
||||||
|
|
||||||
if (
|
|
||||||
control_input.mode != "inpaint"
|
|
||||||
and control_image_t.min() < 0
|
|
||||||
or control_image_t.max() > 1
|
|
||||||
):
|
|
||||||
msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
|
|
||||||
raise RuntimeError(msg)
|
|
||||||
|
|
||||||
if control_image_t.max() == control_image_t.min():
|
|
||||||
msg = f"No control signal found in control image {control_input.mode}."
|
|
||||||
raise RuntimeError(msg)
|
|
||||||
|
|
||||||
c_cat.append(control_image_t)
|
|
||||||
control_strengths.append(control_input.strength)
|
|
||||||
|
|
||||||
elif hasattr(model, "masked_image_key"):
|
|
||||||
# inpainting model
|
|
||||||
assert mask_image_orig is not None
|
|
||||||
assert mask_latent is not None
|
|
||||||
mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to(
|
|
||||||
get_device()
|
|
||||||
)
|
|
||||||
inverted_mask = 1 - mask_latent
|
|
||||||
masked_image_t = init_image_t * (mask_t < 0.5)
|
|
||||||
log_img(masked_image_t, "masked_image")
|
|
||||||
|
|
||||||
inverted_mask_latent = torch.nn.functional.interpolate(
|
|
||||||
inverted_mask, size=shape[-2:]
|
|
||||||
)
|
|
||||||
c_cat.append(inverted_mask_latent)
|
|
||||||
|
|
||||||
masked_image_latent = model.get_first_stage_encoding(
|
|
||||||
model.encode_first_stage(masked_image_t)
|
|
||||||
)
|
|
||||||
c_cat.append(masked_image_latent)
|
|
||||||
|
|
||||||
elif model.cond_stage_key == "edit":
|
|
||||||
# pix2pix model
|
|
||||||
c_cat = [model.encode_first_stage(init_image_t)]
|
|
||||||
assert init_latent is not None
|
|
||||||
c_cat_neutral = [torch.zeros_like(init_latent)]
|
|
||||||
denoiser_cls = CFGEditingDenoiser
|
|
||||||
if c_cat:
|
|
||||||
c_cat = [torch.cat([c], dim=1) for c in c_cat]
|
|
||||||
|
|
||||||
if c_cat_neutral is None:
|
|
||||||
c_cat_neutral = c_cat
|
|
||||||
|
|
||||||
positive_conditioning_d: dict[str, Any] = {
|
|
||||||
"c_concat": c_cat,
|
|
||||||
"c_crossattn": [positive_conditioning],
|
|
||||||
}
|
|
||||||
neutral_conditioning_d: dict[str, Any] = {
|
|
||||||
"c_concat": c_cat_neutral,
|
|
||||||
"c_crossattn": [neutral_conditioning],
|
|
||||||
}
|
|
||||||
del neutral_conditioning
|
|
||||||
del positive_conditioning
|
|
||||||
|
|
||||||
if control_strengths and is_controlnet_model:
|
|
||||||
positive_conditioning_d["control_strengths"] = torch.Tensor(
|
|
||||||
control_strengths
|
|
||||||
)
|
|
||||||
neutral_conditioning_d["control_strengths"] = torch.Tensor(
|
|
||||||
control_strengths
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
prompt.allow_compose_phase
|
|
||||||
and not is_controlnet_model
|
|
||||||
and model.cond_stage_key != "edit"
|
|
||||||
):
|
|
||||||
default_size = get_model_default_image_size(
|
|
||||||
prompt.model_weights.architecture
|
|
||||||
)
|
|
||||||
if prompt.init_image:
|
|
||||||
comp_image = _generate_composition_image(
|
|
||||||
prompt=prompt,
|
|
||||||
target_height=init_image.height,
|
|
||||||
target_width=init_image.width,
|
|
||||||
cutoff=default_size,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
comp_image = _generate_composition_image(
|
|
||||||
prompt=prompt,
|
|
||||||
target_height=prompt.height,
|
|
||||||
target_width=prompt.width,
|
|
||||||
cutoff=default_size,
|
|
||||||
)
|
|
||||||
if comp_image is not None:
|
|
||||||
result_images["composition"] = comp_image
|
|
||||||
# noise = noise[:, :, : comp_image.height, : comp_image.shape[3]]
|
|
||||||
t_enc = int(prompt.steps * 0.65)
|
|
||||||
log_img(comp_image, "comp_image")
|
|
||||||
comp_image_t = pillow_img_to_torch_image(comp_image)
|
|
||||||
comp_image_t = comp_image_t.to(get_device())
|
|
||||||
init_latent = model.get_first_stage_encoding(
|
|
||||||
model.encode_first_stage(comp_image_t)
|
|
||||||
)
|
|
||||||
with lc.timing("sampling"):
|
|
||||||
samples = solver.sample(
|
|
||||||
num_steps=prompt.steps,
|
|
||||||
positive_conditioning=positive_conditioning_d,
|
|
||||||
neutral_conditioning=neutral_conditioning_d,
|
|
||||||
guidance_scale=prompt.prompt_strength,
|
|
||||||
t_start=t_enc,
|
|
||||||
mask=mask_latent,
|
|
||||||
orig_latent=init_latent,
|
|
||||||
shape=shape,
|
|
||||||
batch_size=1,
|
|
||||||
denoiser_cls=denoiser_cls,
|
|
||||||
noise=noise,
|
|
||||||
)
|
|
||||||
if return_latent:
|
|
||||||
return samples
|
|
||||||
|
|
||||||
with lc.timing("decoding"):
|
|
||||||
gen_imgs_t = model.decode_first_stage(samples)
|
|
||||||
gen_img = torch_img_to_pillow_img(gen_imgs_t)
|
|
||||||
|
|
||||||
if mask_image_orig and init_image:
|
|
||||||
mask_final = mask_image_orig.copy()
|
|
||||||
log_img(mask_final, "reconstituting mask")
|
|
||||||
mask_final = ImageOps.invert(mask_final)
|
|
||||||
gen_img = Image.composite(gen_img, init_image, mask_final)
|
|
||||||
gen_img = combine_image(
|
|
||||||
original_img=init_image,
|
|
||||||
generated_img=gen_img,
|
|
||||||
mask_img=mask_image_orig,
|
|
||||||
)
|
|
||||||
log_img(gen_img, "reconstituted image")
|
|
||||||
|
|
||||||
upscaled_img = None
|
|
||||||
rebuilt_orig_img = None
|
|
||||||
|
|
||||||
if add_caption:
|
|
||||||
caption = generate_caption(gen_img)
|
|
||||||
logger.info(f"Generated caption: {caption}")
|
|
||||||
|
|
||||||
with lc.timing("safety-filter"):
|
|
||||||
safety_score = create_safety_score(
|
|
||||||
gen_img,
|
|
||||||
safety_mode=IMAGINAIRY_SAFETY_MODE,
|
|
||||||
)
|
|
||||||
if safety_score.is_filtered:
|
|
||||||
progress_latents.clear()
|
|
||||||
if not safety_score.is_filtered:
|
|
||||||
if prompt.fix_faces:
|
|
||||||
logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
|
|
||||||
with lc.timing("face enhancement"):
|
|
||||||
gen_img = enhance_faces(gen_img, fidelity=prompt.fix_faces_fidelity)
|
|
||||||
if prompt.upscale:
|
|
||||||
logger.info("Upscaling 🖼 using real-ESRGAN...")
|
|
||||||
with lc.timing("upscaling"):
|
|
||||||
upscaled_img = upscale_image(gen_img)
|
|
||||||
|
|
||||||
# put the newly generated patch back into the original, full-size image
|
|
||||||
if prompt.mask_modify_original and mask_image_orig and starting_image:
|
|
||||||
img_to_add_back_to_original = upscaled_img if upscaled_img else gen_img
|
|
||||||
rebuilt_orig_img = combine_image(
|
|
||||||
original_img=starting_image,
|
|
||||||
generated_img=img_to_add_back_to_original,
|
|
||||||
mask_img=mask_image_orig,
|
|
||||||
)
|
|
||||||
|
|
||||||
if prompt.caption_text:
|
|
||||||
caption_text = prompt.caption_text.format(prompt=prompt.prompt_text)
|
|
||||||
add_caption_to_image(gen_img, caption_text)
|
|
||||||
|
|
||||||
result_images["upscaled"] = upscaled_img
|
|
||||||
result_images["modified_original"] = rebuilt_orig_img
|
|
||||||
result_images["mask_binary"] = mask_image_orig
|
|
||||||
result_images["mask_grayscale"] = mask_grayscale
|
|
||||||
|
|
||||||
result = ImagineResult(
|
|
||||||
img=gen_img,
|
|
||||||
prompt=prompt,
|
|
||||||
is_nsfw=safety_score.is_nsfw,
|
|
||||||
safety_score=safety_score,
|
|
||||||
result_images=result_images,
|
|
||||||
timings=lc.get_timings(),
|
|
||||||
progress_latents=progress_latents.copy(),
|
|
||||||
)
|
|
||||||
|
|
||||||
_most_recent_result = result
|
|
||||||
logger.info(f"Image Generated. Timings: {result.timings_str()}")
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _prompts_to_embeddings(prompts, model):
|
|
||||||
total_weight = sum(wp.weight for wp in prompts)
|
|
||||||
conditioning = sum(
|
|
||||||
model.get_learned_conditioning(wp.text) * (wp.weight / total_weight)
|
|
||||||
for wp in prompts
|
|
||||||
)
|
|
||||||
return conditioning
|
|
||||||
|
|
||||||
|
|
||||||
def calc_scale_to_fit_within(height: int, width: int, max_size) -> float:
|
|
||||||
max_width, max_height = normalize_image_size(max_size)
|
|
||||||
if width <= max_width and height <= max_height:
|
|
||||||
return 1
|
|
||||||
|
|
||||||
width_ratio = max_width / width
|
|
||||||
height_ratio = max_height / height
|
|
||||||
|
|
||||||
return min(width_ratio, height_ratio)
|
|
||||||
|
|
||||||
|
|
||||||
def _scale_latent(
|
|
||||||
latent,
|
|
||||||
model,
|
|
||||||
h,
|
|
||||||
w,
|
|
||||||
):
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
# convert to non-latent-space first
|
|
||||||
img = model.decode_first_stage(latent)
|
|
||||||
img = F.interpolate(img, size=(h, w), mode="bicubic", align_corners=False)
|
|
||||||
latent = model.get_first_stage_encoding(model.encode_first_stage(img))
|
|
||||||
return latent
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_composition_image(
|
|
||||||
prompt,
|
|
||||||
target_height,
|
|
||||||
target_width,
|
|
||||||
cutoff: tuple[int, int] = (512, 512),
|
|
||||||
dtype=None,
|
|
||||||
):
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from imaginairy.api.generate_refiners import _generate_single_image
|
|
||||||
from imaginairy.utils import default, get_default_dtype
|
|
||||||
|
|
||||||
cutoff = normalize_image_size(cutoff)
|
|
||||||
if prompt.width <= cutoff[0] and prompt.height <= cutoff[1]:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
dtype = default(dtype, get_default_dtype)
|
|
||||||
|
|
||||||
shrink_scale = calc_scale_to_fit_within(
|
|
||||||
height=prompt.height,
|
|
||||||
width=prompt.width,
|
|
||||||
max_size=cutoff,
|
|
||||||
)
|
|
||||||
|
|
||||||
composition_prompt = prompt.full_copy(
|
|
||||||
deep=True,
|
|
||||||
update={
|
|
||||||
"size": (
|
|
||||||
int(prompt.width * shrink_scale),
|
|
||||||
int(prompt.height * shrink_scale),
|
|
||||||
),
|
|
||||||
"steps": None,
|
|
||||||
"upscale": False,
|
|
||||||
"fix_faces": False,
|
|
||||||
"mask_modify_original": False,
|
|
||||||
"allow_compose_phase": False,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
result = _generate_single_image(composition_prompt, dtype=dtype)
|
|
||||||
img = result.images["generated"]
|
|
||||||
while img.width < target_width:
|
|
||||||
from imaginairy.enhancers.upscale_realesrgan import upscale_image
|
|
||||||
|
|
||||||
img = upscale_image(img)
|
|
||||||
|
|
||||||
# samples = _generate_single_image(composition_prompt, return_latent=True)
|
|
||||||
# while samples.shape[-1] * 8 < target_width:
|
|
||||||
# samples = upscale_latent(samples)
|
|
||||||
#
|
|
||||||
# img = model_latent_to_pillow_img(samples)
|
|
||||||
|
|
||||||
img = img.resize(
|
|
||||||
(target_width, target_height),
|
|
||||||
resample=Image.Resampling.LANCZOS,
|
|
||||||
)
|
|
||||||
|
|
||||||
return img, result.images["generated"]
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_normalized(prompt, length=130):
|
|
||||||
return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:length]
|
|
||||||
|
|
||||||
|
|
||||||
def combine_image(original_img, generated_img, mask_img):
|
|
||||||
"""Combine the generated image with the original image using the mask image."""
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from imaginairy.utils.log_utils import log_img
|
|
||||||
|
|
||||||
generated_img = generated_img.resize(
|
|
||||||
original_img.size,
|
|
||||||
resample=Image.Resampling.LANCZOS,
|
|
||||||
)
|
|
||||||
|
|
||||||
mask_for_orig_size = mask_img.resize(
|
|
||||||
original_img.size,
|
|
||||||
resample=Image.Resampling.LANCZOS,
|
|
||||||
)
|
|
||||||
log_img(mask_for_orig_size, "mask for original image size")
|
|
||||||
|
|
||||||
rebuilt_orig_img = Image.composite(
|
|
||||||
original_img,
|
|
||||||
generated_img,
|
|
||||||
mask_for_orig_size,
|
|
||||||
)
|
|
||||||
log_img(rebuilt_orig_img, "reconstituted original")
|
|
||||||
return rebuilt_orig_img
|
|
||||||
|
547
imaginairy/api/generate_compvis.py
Normal file
547
imaginairy/api/generate_compvis.py
Normal file
@ -0,0 +1,547 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from imaginairy.api.generate import (
|
||||||
|
IMAGINAIRY_SAFETY_MODE,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
from imaginairy.api.generate_refiners import _generate_composition_image
|
||||||
|
from imaginairy.schema import ImaginePrompt, LazyLoadingImage
|
||||||
|
from imaginairy.utils.img_utils import calc_scale_to_fit_within, combine_image
|
||||||
|
from imaginairy.utils.named_resolutions import normalize_image_size
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_single_image_compvis(
|
||||||
|
prompt: "ImaginePrompt",
|
||||||
|
debug_img_callback=None,
|
||||||
|
progress_img_callback=None,
|
||||||
|
progress_img_interval_steps=3,
|
||||||
|
progress_img_interval_min_s=0.1,
|
||||||
|
half_mode=None,
|
||||||
|
add_caption=False,
|
||||||
|
# controlnet, finetune, naive, auto
|
||||||
|
inpaint_method="finetune",
|
||||||
|
return_latent=False,
|
||||||
|
):
|
||||||
|
import torch.nn
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
from pytorch_lightning import seed_everything
|
||||||
|
|
||||||
|
from imaginairy.enhancers.clip_masking import get_img_mask
|
||||||
|
from imaginairy.enhancers.describe_image_blip import generate_caption
|
||||||
|
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
|
||||||
|
from imaginairy.enhancers.upscale_realesrgan import upscale_image
|
||||||
|
from imaginairy.modules.midas.api import torch_image_to_depth_map
|
||||||
|
from imaginairy.samplers import SOLVER_LOOKUP
|
||||||
|
from imaginairy.samplers.editing import CFGEditingDenoiser
|
||||||
|
from imaginairy.schema import ControlInput, ImagineResult, MaskMode
|
||||||
|
from imaginairy.utils import get_device, randn_seeded
|
||||||
|
from imaginairy.utils.img_utils import (
|
||||||
|
add_caption_to_image,
|
||||||
|
pillow_fit_image_within,
|
||||||
|
pillow_img_to_torch_image,
|
||||||
|
pillow_mask_to_latent_mask,
|
||||||
|
torch_img_to_pillow_img,
|
||||||
|
)
|
||||||
|
from imaginairy.utils.log_utils import (
|
||||||
|
ImageLoggingContext,
|
||||||
|
log_conditioning,
|
||||||
|
log_img,
|
||||||
|
log_latent,
|
||||||
|
)
|
||||||
|
from imaginairy.utils.model_manager import (
|
||||||
|
get_diffusion_model,
|
||||||
|
get_model_default_image_size,
|
||||||
|
)
|
||||||
|
from imaginairy.utils.outpaint import (
|
||||||
|
outpaint_arg_str_parse,
|
||||||
|
prepare_image_for_outpaint,
|
||||||
|
)
|
||||||
|
from imaginairy.utils.safety import create_safety_score
|
||||||
|
|
||||||
|
latent_channels = 4
|
||||||
|
downsampling_factor = 8
|
||||||
|
batch_size = 1
|
||||||
|
global _most_recent_result
|
||||||
|
# handle prompt pulling in previous values
|
||||||
|
# if isinstance(prompt.init_image, str) and prompt.init_image.startswith("*prev"):
|
||||||
|
# _, img_type = prompt.init_image.strip("*").split(".")
|
||||||
|
# prompt.init_image = _most_recent_result.images[img_type]
|
||||||
|
# if isinstance(prompt.mask_image, str) and prompt.mask_image.startswith("*prev"):
|
||||||
|
# _, img_type = prompt.mask_image.strip("*").split(".")
|
||||||
|
# prompt.mask_image = _most_recent_result.images[img_type]
|
||||||
|
prompt = prompt.make_concrete_copy()
|
||||||
|
|
||||||
|
control_modes = []
|
||||||
|
control_inputs = prompt.control_inputs or []
|
||||||
|
control_inputs = control_inputs.copy()
|
||||||
|
for_inpainting = bool(prompt.mask_image or prompt.mask_prompt or prompt.outpaint)
|
||||||
|
|
||||||
|
if control_inputs:
|
||||||
|
control_modes = [c.mode for c in prompt.control_inputs]
|
||||||
|
if inpaint_method == "auto":
|
||||||
|
if prompt.model_weights in {"SD-1.5"}:
|
||||||
|
inpaint_method = "finetune"
|
||||||
|
else:
|
||||||
|
inpaint_method = "controlnet"
|
||||||
|
|
||||||
|
if for_inpainting and inpaint_method == "controlnet":
|
||||||
|
control_modes.append("inpaint")
|
||||||
|
model = get_diffusion_model(
|
||||||
|
weights_location=prompt.model_weights,
|
||||||
|
config_path=prompt.model_architecture,
|
||||||
|
control_weights_locations=control_modes,
|
||||||
|
half_mode=half_mode,
|
||||||
|
for_inpainting=for_inpainting and inpaint_method == "finetune",
|
||||||
|
)
|
||||||
|
is_controlnet_model = hasattr(model, "control_key")
|
||||||
|
|
||||||
|
progress_latents = []
|
||||||
|
|
||||||
|
def latent_logger(latents):
|
||||||
|
progress_latents.append(latents)
|
||||||
|
|
||||||
|
with ImageLoggingContext(
|
||||||
|
prompt=prompt,
|
||||||
|
model=model,
|
||||||
|
debug_img_callback=debug_img_callback,
|
||||||
|
progress_img_callback=progress_img_callback,
|
||||||
|
progress_img_interval_steps=progress_img_interval_steps,
|
||||||
|
progress_img_interval_min_s=progress_img_interval_min_s,
|
||||||
|
progress_latent_callback=latent_logger
|
||||||
|
if prompt.collect_progress_latents
|
||||||
|
else None,
|
||||||
|
) as lc:
|
||||||
|
seed_everything(prompt.seed)
|
||||||
|
|
||||||
|
model.tile_mode(prompt.tile_mode)
|
||||||
|
with lc.timing("conditioning"):
|
||||||
|
# need to expand if doing batches
|
||||||
|
neutral_conditioning = _prompts_to_embeddings(prompt.negative_prompt, model)
|
||||||
|
_prompts_to_embeddings("", model)
|
||||||
|
log_conditioning(neutral_conditioning, "neutral conditioning")
|
||||||
|
if prompt.conditioning is not None:
|
||||||
|
positive_conditioning = prompt.conditioning
|
||||||
|
else:
|
||||||
|
positive_conditioning = _prompts_to_embeddings(prompt.prompts, model)
|
||||||
|
log_conditioning(positive_conditioning, "positive conditioning")
|
||||||
|
|
||||||
|
shape = [
|
||||||
|
batch_size,
|
||||||
|
latent_channels,
|
||||||
|
prompt.height // downsampling_factor,
|
||||||
|
prompt.width // downsampling_factor,
|
||||||
|
]
|
||||||
|
SolverCls = SOLVER_LOOKUP[prompt.solver_type.lower()]
|
||||||
|
solver = SolverCls(model)
|
||||||
|
mask_image: Image.Image | LazyLoadingImage | None = None
|
||||||
|
mask_latent = mask_image_orig = mask_grayscale = None
|
||||||
|
init_latent: torch.Tensor | None = None
|
||||||
|
t_enc = None
|
||||||
|
starting_image = None
|
||||||
|
denoiser_cls = None
|
||||||
|
|
||||||
|
c_cat = []
|
||||||
|
c_cat_neutral = None
|
||||||
|
result_images: dict[str, torch.Tensor | Image.Image | None] = {}
|
||||||
|
assert prompt.seed is not None
|
||||||
|
seed_everything(prompt.seed)
|
||||||
|
noise = randn_seeded(seed=prompt.seed, size=shape).to(get_device())
|
||||||
|
control_strengths = []
|
||||||
|
|
||||||
|
if prompt.init_image:
|
||||||
|
starting_image = prompt.init_image
|
||||||
|
assert prompt.init_image_strength is not None
|
||||||
|
generation_strength = 1 - prompt.init_image_strength
|
||||||
|
|
||||||
|
if model.cond_stage_key == "edit" or generation_strength >= 1:
|
||||||
|
t_enc = None
|
||||||
|
else:
|
||||||
|
t_enc = int(prompt.steps * generation_strength)
|
||||||
|
|
||||||
|
if prompt.mask_prompt:
|
||||||
|
mask_image, mask_grayscale = get_img_mask(
|
||||||
|
starting_image, prompt.mask_prompt, threshold=0.1
|
||||||
|
)
|
||||||
|
elif prompt.mask_image:
|
||||||
|
mask_image = prompt.mask_image.convert("L")
|
||||||
|
if prompt.outpaint:
|
||||||
|
outpaint_kwargs = outpaint_arg_str_parse(prompt.outpaint)
|
||||||
|
starting_image, mask_image = prepare_image_for_outpaint(
|
||||||
|
starting_image, mask_image, **outpaint_kwargs
|
||||||
|
)
|
||||||
|
assert starting_image is not None
|
||||||
|
init_image = pillow_fit_image_within(
|
||||||
|
starting_image,
|
||||||
|
max_height=prompt.height,
|
||||||
|
max_width=prompt.width,
|
||||||
|
)
|
||||||
|
init_image_t = pillow_img_to_torch_image(init_image).to(get_device())
|
||||||
|
init_latent = model.get_first_stage_encoding(
|
||||||
|
model.encode_first_stage(init_image_t)
|
||||||
|
)
|
||||||
|
assert init_latent is not None
|
||||||
|
shape = list(init_latent.shape)
|
||||||
|
|
||||||
|
log_latent(init_latent, "init_latent")
|
||||||
|
|
||||||
|
if mask_image is not None:
|
||||||
|
mask_image = pillow_fit_image_within(
|
||||||
|
mask_image,
|
||||||
|
max_height=prompt.height,
|
||||||
|
max_width=prompt.width,
|
||||||
|
convert="L",
|
||||||
|
)
|
||||||
|
|
||||||
|
log_img(mask_image, "init mask")
|
||||||
|
|
||||||
|
if prompt.mask_mode == MaskMode.REPLACE:
|
||||||
|
mask_image = ImageOps.invert(mask_image)
|
||||||
|
|
||||||
|
mask_image_orig = mask_image
|
||||||
|
log_img(mask_image, "latent_mask")
|
||||||
|
mask_latent = pillow_mask_to_latent_mask(
|
||||||
|
mask_image, downsampling_factor=downsampling_factor
|
||||||
|
).to(get_device())
|
||||||
|
if inpaint_method == "controlnet":
|
||||||
|
result_images["control-inpaint"] = mask_image
|
||||||
|
control_inputs.append(
|
||||||
|
ControlInput(mode="inpaint", image=mask_image)
|
||||||
|
)
|
||||||
|
assert prompt.seed is not None
|
||||||
|
seed_everything(prompt.seed)
|
||||||
|
noise = randn_seeded(seed=prompt.seed, size=list(init_latent.shape)).to(
|
||||||
|
get_device()
|
||||||
|
)
|
||||||
|
# noise = noise[:, :, : init_latent.shape[2], : init_latent.shape[3]]
|
||||||
|
|
||||||
|
# schedule = NoiseSchedule(
|
||||||
|
# model_num_timesteps=model.num_timesteps,
|
||||||
|
# ddim_num_steps=prompt.steps,
|
||||||
|
# model_alphas_cumprod=model.alphas_cumprod,
|
||||||
|
# ddim_discretize="uniform",
|
||||||
|
# )
|
||||||
|
# if generation_strength >= 1:
|
||||||
|
# # prompt strength gets converted to time encodings,
|
||||||
|
# # which means you can't get to true 0 without this hack
|
||||||
|
# # (or setting steps=1000)
|
||||||
|
# init_latent_noised = noise
|
||||||
|
# else:
|
||||||
|
# init_latent_noised = noise_an_image(
|
||||||
|
# init_latent,
|
||||||
|
# torch.tensor([t_enc - 1]).to(get_device()),
|
||||||
|
# schedule=schedule,
|
||||||
|
# noise=noise,
|
||||||
|
# )
|
||||||
|
|
||||||
|
if hasattr(model, "depth_stage_key"):
|
||||||
|
# depth model
|
||||||
|
depth_t = torch_image_to_depth_map(init_image_t)
|
||||||
|
depth_latent = torch.nn.functional.interpolate(
|
||||||
|
depth_t,
|
||||||
|
size=shape[2:],
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
result_images["depth_image"] = depth_t
|
||||||
|
c_cat.append(depth_latent)
|
||||||
|
|
||||||
|
elif is_controlnet_model:
|
||||||
|
from imaginairy.img_processors.control_modes import CONTROL_MODES
|
||||||
|
|
||||||
|
for control_input in control_inputs:
|
||||||
|
if control_input.image_raw is not None:
|
||||||
|
control_image = control_input.image_raw
|
||||||
|
elif control_input.image is not None:
|
||||||
|
control_image = control_input.image
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Control image must be provided")
|
||||||
|
assert control_image is not None
|
||||||
|
control_image = control_image.convert("RGB")
|
||||||
|
log_img(control_image, "control_image_input")
|
||||||
|
assert control_image is not None
|
||||||
|
|
||||||
|
control_image_input = pillow_fit_image_within(
|
||||||
|
control_image,
|
||||||
|
max_height=prompt.height,
|
||||||
|
max_width=prompt.width,
|
||||||
|
)
|
||||||
|
control_image_input_t = pillow_img_to_torch_image(control_image_input)
|
||||||
|
control_image_input_t = control_image_input_t.to(get_device())
|
||||||
|
|
||||||
|
if control_input.image_raw is None:
|
||||||
|
control_prep_function = CONTROL_MODES[control_input.mode]
|
||||||
|
if control_input.mode == "inpaint":
|
||||||
|
control_image_t = control_prep_function( # type: ignore
|
||||||
|
control_image_input_t, init_image_t
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
control_image_t = control_prep_function(control_image_input_t) # type: ignore
|
||||||
|
else:
|
||||||
|
control_image_t = (control_image_input_t + 1) / 2
|
||||||
|
|
||||||
|
control_image_disp = control_image_t * 2 - 1
|
||||||
|
result_images[f"control-{control_input.mode}"] = control_image_disp
|
||||||
|
log_img(control_image_disp, "control_image")
|
||||||
|
|
||||||
|
if len(control_image_t.shape) == 3:
|
||||||
|
raise RuntimeError("Control image must be 4D")
|
||||||
|
|
||||||
|
if control_image_t.shape[1] != 3:
|
||||||
|
raise RuntimeError("Control image must have 3 channels")
|
||||||
|
|
||||||
|
if (
|
||||||
|
control_input.mode != "inpaint"
|
||||||
|
and control_image_t.min() < 0
|
||||||
|
or control_image_t.max() > 1
|
||||||
|
):
|
||||||
|
msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
|
if control_image_t.max() == control_image_t.min():
|
||||||
|
msg = f"No control signal found in control image {control_input.mode}."
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
|
c_cat.append(control_image_t)
|
||||||
|
control_strengths.append(control_input.strength)
|
||||||
|
|
||||||
|
elif hasattr(model, "masked_image_key"):
|
||||||
|
# inpainting model
|
||||||
|
assert mask_image_orig is not None
|
||||||
|
assert mask_latent is not None
|
||||||
|
mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to(
|
||||||
|
get_device()
|
||||||
|
)
|
||||||
|
inverted_mask = 1 - mask_latent
|
||||||
|
masked_image_t = init_image_t * (mask_t < 0.5)
|
||||||
|
log_img(masked_image_t, "masked_image")
|
||||||
|
|
||||||
|
inverted_mask_latent = torch.nn.functional.interpolate(
|
||||||
|
inverted_mask, size=shape[-2:]
|
||||||
|
)
|
||||||
|
c_cat.append(inverted_mask_latent)
|
||||||
|
|
||||||
|
masked_image_latent = model.get_first_stage_encoding(
|
||||||
|
model.encode_first_stage(masked_image_t)
|
||||||
|
)
|
||||||
|
c_cat.append(masked_image_latent)
|
||||||
|
|
||||||
|
elif model.cond_stage_key == "edit":
|
||||||
|
# pix2pix model
|
||||||
|
c_cat = [model.encode_first_stage(init_image_t)]
|
||||||
|
assert init_latent is not None
|
||||||
|
c_cat_neutral = [torch.zeros_like(init_latent)]
|
||||||
|
denoiser_cls = CFGEditingDenoiser
|
||||||
|
if c_cat:
|
||||||
|
c_cat = [torch.cat([c], dim=1) for c in c_cat]
|
||||||
|
|
||||||
|
if c_cat_neutral is None:
|
||||||
|
c_cat_neutral = c_cat
|
||||||
|
|
||||||
|
positive_conditioning_d: dict[str, Any] = {
|
||||||
|
"c_concat": c_cat,
|
||||||
|
"c_crossattn": [positive_conditioning],
|
||||||
|
}
|
||||||
|
neutral_conditioning_d: dict[str, Any] = {
|
||||||
|
"c_concat": c_cat_neutral,
|
||||||
|
"c_crossattn": [neutral_conditioning],
|
||||||
|
}
|
||||||
|
del neutral_conditioning
|
||||||
|
del positive_conditioning
|
||||||
|
|
||||||
|
if control_strengths and is_controlnet_model:
|
||||||
|
positive_conditioning_d["control_strengths"] = torch.Tensor(
|
||||||
|
control_strengths
|
||||||
|
)
|
||||||
|
neutral_conditioning_d["control_strengths"] = torch.Tensor(
|
||||||
|
control_strengths
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
prompt.allow_compose_phase
|
||||||
|
and not is_controlnet_model
|
||||||
|
and model.cond_stage_key != "edit"
|
||||||
|
):
|
||||||
|
default_size = get_model_default_image_size(
|
||||||
|
prompt.model_weights.architecture
|
||||||
|
)
|
||||||
|
if prompt.init_image:
|
||||||
|
comp_image = _generate_composition_image(
|
||||||
|
prompt=prompt,
|
||||||
|
target_height=init_image.height,
|
||||||
|
target_width=init_image.width,
|
||||||
|
cutoff=default_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
comp_image = _generate_composition_image(
|
||||||
|
prompt=prompt,
|
||||||
|
target_height=prompt.height,
|
||||||
|
target_width=prompt.width,
|
||||||
|
cutoff=default_size,
|
||||||
|
)
|
||||||
|
if comp_image is not None:
|
||||||
|
result_images["composition"] = comp_image
|
||||||
|
# noise = noise[:, :, : comp_image.height, : comp_image.shape[3]]
|
||||||
|
t_enc = int(prompt.steps * 0.65)
|
||||||
|
log_img(comp_image, "comp_image")
|
||||||
|
comp_image_t = pillow_img_to_torch_image(comp_image)
|
||||||
|
comp_image_t = comp_image_t.to(get_device())
|
||||||
|
init_latent = model.get_first_stage_encoding(
|
||||||
|
model.encode_first_stage(comp_image_t)
|
||||||
|
)
|
||||||
|
with lc.timing("sampling"):
|
||||||
|
samples = solver.sample(
|
||||||
|
num_steps=prompt.steps,
|
||||||
|
positive_conditioning=positive_conditioning_d,
|
||||||
|
neutral_conditioning=neutral_conditioning_d,
|
||||||
|
guidance_scale=prompt.prompt_strength,
|
||||||
|
t_start=t_enc,
|
||||||
|
mask=mask_latent,
|
||||||
|
orig_latent=init_latent,
|
||||||
|
shape=shape,
|
||||||
|
batch_size=1,
|
||||||
|
denoiser_cls=denoiser_cls,
|
||||||
|
noise=noise,
|
||||||
|
)
|
||||||
|
if return_latent:
|
||||||
|
return samples
|
||||||
|
|
||||||
|
with lc.timing("decoding"):
|
||||||
|
gen_imgs_t = model.decode_first_stage(samples)
|
||||||
|
gen_img = torch_img_to_pillow_img(gen_imgs_t)
|
||||||
|
|
||||||
|
if mask_image_orig and init_image:
|
||||||
|
mask_final = mask_image_orig.copy()
|
||||||
|
log_img(mask_final, "reconstituting mask")
|
||||||
|
mask_final = ImageOps.invert(mask_final)
|
||||||
|
gen_img = Image.composite(gen_img, init_image, mask_final)
|
||||||
|
gen_img = combine_image(
|
||||||
|
original_img=init_image,
|
||||||
|
generated_img=gen_img,
|
||||||
|
mask_img=mask_image_orig,
|
||||||
|
)
|
||||||
|
log_img(gen_img, "reconstituted image")
|
||||||
|
|
||||||
|
upscaled_img = None
|
||||||
|
rebuilt_orig_img = None
|
||||||
|
|
||||||
|
if add_caption:
|
||||||
|
caption = generate_caption(gen_img)
|
||||||
|
logger.info(f"Generated caption: {caption}")
|
||||||
|
|
||||||
|
with lc.timing("safety-filter"):
|
||||||
|
safety_score = create_safety_score(
|
||||||
|
gen_img,
|
||||||
|
safety_mode=IMAGINAIRY_SAFETY_MODE,
|
||||||
|
)
|
||||||
|
if safety_score.is_filtered:
|
||||||
|
progress_latents.clear()
|
||||||
|
if not safety_score.is_filtered:
|
||||||
|
if prompt.fix_faces:
|
||||||
|
logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
|
||||||
|
with lc.timing("face enhancement"):
|
||||||
|
gen_img = enhance_faces(gen_img, fidelity=prompt.fix_faces_fidelity)
|
||||||
|
if prompt.upscale:
|
||||||
|
logger.info("Upscaling 🖼 using real-ESRGAN...")
|
||||||
|
with lc.timing("upscaling"):
|
||||||
|
upscaled_img = upscale_image(gen_img)
|
||||||
|
|
||||||
|
# put the newly generated patch back into the original, full-size image
|
||||||
|
if prompt.mask_modify_original and mask_image_orig and starting_image:
|
||||||
|
img_to_add_back_to_original = upscaled_img if upscaled_img else gen_img
|
||||||
|
rebuilt_orig_img = combine_image(
|
||||||
|
original_img=starting_image,
|
||||||
|
generated_img=img_to_add_back_to_original,
|
||||||
|
mask_img=mask_image_orig,
|
||||||
|
)
|
||||||
|
|
||||||
|
if prompt.caption_text:
|
||||||
|
caption_text = prompt.caption_text.format(prompt=prompt.prompt_text)
|
||||||
|
add_caption_to_image(gen_img, caption_text)
|
||||||
|
|
||||||
|
result_images["upscaled"] = upscaled_img
|
||||||
|
result_images["modified_original"] = rebuilt_orig_img
|
||||||
|
result_images["mask_binary"] = mask_image_orig
|
||||||
|
result_images["mask_grayscale"] = mask_grayscale
|
||||||
|
|
||||||
|
result = ImagineResult(
|
||||||
|
img=gen_img,
|
||||||
|
prompt=prompt,
|
||||||
|
is_nsfw=safety_score.is_nsfw,
|
||||||
|
safety_score=safety_score,
|
||||||
|
result_images=result_images,
|
||||||
|
timings=lc.get_timings(),
|
||||||
|
progress_latents=progress_latents.copy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
_most_recent_result = result
|
||||||
|
logger.info(f"Image Generated. Timings: {result.timings_str()}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _prompts_to_embeddings(prompts, model):
|
||||||
|
total_weight = sum(wp.weight for wp in prompts)
|
||||||
|
conditioning = sum(
|
||||||
|
model.get_learned_conditioning(wp.text) * (wp.weight / total_weight)
|
||||||
|
for wp in prompts
|
||||||
|
)
|
||||||
|
return conditioning
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_composition_image(
    prompt,
    target_height,
    target_width,
    cutoff: tuple[int, int] = (512, 512),
    dtype=None,
):
    """Render a shrunken "composition" draft of `prompt`, then scale it up.

    Returns a tuple of (image resized to target dimensions, the raw
    generated image), or ``(None, None)`` when the requested size already
    fits within `cutoff` and no separate composition pass is needed.
    """
    from PIL import Image

    from imaginairy.api.generate_refiners import generate_single_image
    from imaginairy.utils import default, get_default_dtype

    cutoff = normalize_image_size(cutoff)
    if prompt.width <= cutoff[0] and prompt.height <= cutoff[1]:
        # Small enough already; caller should skip the composition phase.
        return None, None

    dtype = default(dtype, get_default_dtype)

    scale = calc_scale_to_fit_within(
        height=prompt.height,
        width=prompt.width,
        max_size=cutoff,
    )
    draft_size = (int(prompt.width * scale), int(prompt.height * scale))

    # Re-run the prompt at the reduced size with post-processing disabled.
    overrides = {
        "size": draft_size,
        "steps": None,
        "upscale": False,
        "fix_faces": False,
        "mask_modify_original": False,
        "allow_compose_phase": False,
    }
    composition_prompt = prompt.full_copy(deep=True, update=overrides)

    result = generate_single_image(composition_prompt, dtype=dtype)
    raw_img = result.images["generated"]

    img = raw_img
    while img.width < target_width:
        # Lazy import: the ESRGAN upscaler is heavy and only needed here.
        from imaginairy.enhancers.upscale_realesrgan import upscale_image

        img = upscale_image(img)

    resized = img.resize(
        (target_width, target_height),
        resample=Image.Resampling.LANCZOS,
    )

    return resized, raw_img
|
@ -5,11 +5,13 @@ from typing import List, Optional
|
|||||||
|
|
||||||
from imaginairy.config import CONTROL_CONFIG_SHORTCUTS
|
from imaginairy.config import CONTROL_CONFIG_SHORTCUTS
|
||||||
from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode, WeightedPrompt
|
from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode, WeightedPrompt
|
||||||
|
from imaginairy.utils.img_utils import calc_scale_to_fit_within
|
||||||
|
from imaginairy.utils.named_resolutions import normalize_image_size
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _generate_single_image(
|
def generate_single_image(
|
||||||
prompt: ImaginePrompt,
|
prompt: ImaginePrompt,
|
||||||
debug_img_callback=None,
|
debug_img_callback=None,
|
||||||
progress_img_callback=None,
|
progress_img_callback=None,
|
||||||
@ -28,8 +30,6 @@ def _generate_single_image(
|
|||||||
|
|
||||||
from imaginairy.api.generate import (
|
from imaginairy.api.generate import (
|
||||||
IMAGINAIRY_SAFETY_MODE,
|
IMAGINAIRY_SAFETY_MODE,
|
||||||
_generate_composition_image,
|
|
||||||
combine_image,
|
|
||||||
)
|
)
|
||||||
from imaginairy.enhancers.clip_masking import get_img_mask
|
from imaginairy.enhancers.clip_masking import get_img_mask
|
||||||
from imaginairy.enhancers.describe_image_blip import generate_caption
|
from imaginairy.enhancers.describe_image_blip import generate_caption
|
||||||
@ -40,6 +40,7 @@ def _generate_single_image(
|
|||||||
from imaginairy.utils import get_device, randn_seeded
|
from imaginairy.utils import get_device, randn_seeded
|
||||||
from imaginairy.utils.img_utils import (
|
from imaginairy.utils.img_utils import (
|
||||||
add_caption_to_image,
|
add_caption_to_image,
|
||||||
|
combine_image,
|
||||||
pillow_fit_image_within,
|
pillow_fit_image_within,
|
||||||
pillow_img_to_torch_image,
|
pillow_img_to_torch_image,
|
||||||
pillow_mask_to_latent_mask,
|
pillow_mask_to_latent_mask,
|
||||||
@ -523,3 +524,64 @@ def prep_control_input(
|
|||||||
)
|
)
|
||||||
controlnet.set_scale(control_input.strength)
|
controlnet.set_scale(control_input.strength)
|
||||||
return controlnet, control_image_t, control_image_disp
|
return controlnet, control_image_t, control_image_disp
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_composition_image(
    prompt,
    target_height,
    target_width,
    cutoff: tuple[int, int] = (512, 512),
    dtype=None,
):
    """Generate a small "composition" pass of `prompt`, then upscale/resize it.

    Returns (image resized to target dimensions, raw generated image), or
    (None, None) when the requested size already fits within `cutoff`.
    """
    from PIL import Image

    from imaginairy.api.generate_refiners import generate_single_image
    from imaginairy.utils import default, get_default_dtype

    cutoff = normalize_image_size(cutoff)
    if prompt.width <= cutoff[0] and prompt.height <= cutoff[1]:
        # Already small enough; no separate composition phase needed.
        return None, None

    dtype = default(dtype, get_default_dtype)

    shrink_scale = calc_scale_to_fit_within(
        height=prompt.height,
        width=prompt.width,
        max_size=cutoff,
    )

    # Copy of the prompt at the reduced size with post-processing disabled,
    # so the draft pass is fast and unembellished.
    composition_prompt = prompt.full_copy(
        deep=True,
        update={
            "size": (
                int(prompt.width * shrink_scale),
                int(prompt.height * shrink_scale),
            ),
            "steps": None,
            "upscale": False,
            "fix_faces": False,
            "mask_modify_original": False,
            "allow_compose_phase": False,
            "caption_text": None,
        },
    )

    result = generate_single_image(composition_prompt, dtype=dtype)
    img = result.images["generated"]
    while img.width < target_width:
        # Lazy import: the ESRGAN upscaler is heavy and only needed here.
        from imaginairy.enhancers.upscale_realesrgan import upscale_image

        img = upscale_image(img)

    # samples = generate_single_image(composition_prompt, return_latent=True)
    # while samples.shape[-1] * 8 < target_width:
    #     samples = upscale_latent(samples)
    #
    # img = model_latent_to_pillow_img(samples)

    img = img.resize(
        (target_width, target_height),
        resample=Image.Resampling.LANCZOS,
    )

    return img, result.images["generated"]
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import platform
|
import platform
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
@ -315,3 +316,7 @@ def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
|
|||||||
current_attribute = getattr(current_attribute, attribute)
|
current_attribute = getattr(current_attribute, attribute)
|
||||||
|
|
||||||
return (current_attribute, current_key) if return_key else current_attribute
|
return (current_attribute, current_key) if return_key else current_attribute
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_normalized(prompt, length=130):
    """Sanitize a prompt string (e.g. for use in filenames).

    Runs of characters outside ``a-zA-Z0-9.,[]-`` collapse to a single
    underscore, and the result is truncated to `length` characters.
    """
    safe = re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)
    return safe[:length]
|
||||||
|
@ -19,6 +19,7 @@ from PIL import Image, ImageDraw, ImageFont
|
|||||||
|
|
||||||
from imaginairy.schema import LazyLoadingImage
|
from imaginairy.schema import LazyLoadingImage
|
||||||
from imaginairy.utils import get_device
|
from imaginairy.utils import get_device
|
||||||
|
from imaginairy.utils.named_resolutions import normalize_image_size
|
||||||
from imaginairy.utils.paths import PKG_ROOT
|
from imaginairy.utils.paths import PKG_ROOT
|
||||||
|
|
||||||
|
|
||||||
@ -221,3 +222,39 @@ def create_halo_effect(
|
|||||||
new_canvas.paste(transparent_image, (0, 0), transparent_image)
|
new_canvas.paste(transparent_image, (0, 0), transparent_image)
|
||||||
|
|
||||||
return new_canvas
|
return new_canvas
|
||||||
|
|
||||||
|
|
||||||
|
def combine_image(original_img, generated_img, mask_img):
    """Composite the generated image back onto the original using the mask."""
    from PIL import Image

    from imaginairy.utils.log_utils import log_img

    target_size = original_img.size
    resized_generated = generated_img.resize(
        target_size,
        resample=Image.Resampling.LANCZOS,
    )
    resized_mask = mask_img.resize(
        target_size,
        resample=Image.Resampling.LANCZOS,
    )
    log_img(resized_mask, "mask for original image size")

    # Where the mask is opaque, keep the original; elsewhere use generated.
    return Image.composite(original_img, resized_generated, resized_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def calc_scale_to_fit_within(height: int, width: int, max_size) -> float:
    """Return the scale factor that fits a (width, height) image in `max_size`.

    Returns 1 when the image already fits; otherwise the largest factor
    that keeps both dimensions within the maximum.
    """
    max_width, max_height = normalize_image_size(max_size)
    already_fits = width <= max_width and height <= max_height
    if already_fits:
        return 1
    return min(max_width / width, max_height / height)
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
_NAMED_RESOLUTIONS = {
|
NAMED_RESOLUTIONS = {
|
||||||
"HD": (1280, 720),
|
"HD": (1280, 720),
|
||||||
"FHD": (1920, 1080),
|
"FHD": (1920, 1080),
|
||||||
"HALF-FHD": (960, 540),
|
"HALF-FHD": (960, 540),
|
||||||
@ -46,7 +46,7 @@ _NAMED_RESOLUTIONS = {
|
|||||||
"SVD": (1024, 576), # stable video diffusion
|
"SVD": (1024, 576), # stable video diffusion
|
||||||
}
|
}
|
||||||
|
|
||||||
_NAMED_RESOLUTIONS = {k.upper(): v for k, v in _NAMED_RESOLUTIONS.items()}
|
NAMED_RESOLUTIONS = {k.upper(): v for k, v in NAMED_RESOLUTIONS.items()}
|
||||||
|
|
||||||
|
|
||||||
def normalize_image_size(resolution: str | int | tuple[int, int]) -> tuple[int, int]:
|
def normalize_image_size(resolution: str | int | tuple[int, int]) -> tuple[int, int]:
|
||||||
@ -66,8 +66,8 @@ def _normalize_image_size(resolution: str | int | tuple[int, int]) -> tuple[int,
|
|||||||
case str():
|
case str():
|
||||||
resolution = resolution.strip().upper()
|
resolution = resolution.strip().upper()
|
||||||
resolution = resolution.replace(" ", "").replace("X", ",").replace("*", ",")
|
resolution = resolution.replace(" ", "").replace("X", ",").replace("*", ",")
|
||||||
if resolution.upper() in _NAMED_RESOLUTIONS:
|
if resolution.upper() in NAMED_RESOLUTIONS:
|
||||||
return _NAMED_RESOLUTIONS[resolution.upper()]
|
return NAMED_RESOLUTIONS[resolution.upper()]
|
||||||
|
|
||||||
# is it WIDTH,HEIGHT format?
|
# is it WIDTH,HEIGHT format?
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user