|
|
|
@ -1,12 +1,12 @@
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
from typing import TYPE_CHECKING, Callable
|
|
|
|
|
from typing import TYPE_CHECKING, Any, Callable
|
|
|
|
|
|
|
|
|
|
from imaginairy.utils.named_resolutions import normalize_image_size
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from imaginairy.schema import ImaginePrompt
|
|
|
|
|
from imaginairy.schema import ImaginePrompt, LazyLoadingImage
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
@ -335,20 +335,24 @@ def _generate_single_image_compvis(
|
|
|
|
|
]
|
|
|
|
|
SolverCls = SOLVER_LOOKUP[prompt.solver_type.lower()]
|
|
|
|
|
solver = SolverCls(model)
|
|
|
|
|
mask_latent = mask_image = mask_image_orig = mask_grayscale = None
|
|
|
|
|
t_enc = init_latent = control_image = None
|
|
|
|
|
mask_image: Image.Image | LazyLoadingImage | None = None
|
|
|
|
|
mask_latent = mask_image_orig = mask_grayscale = None
|
|
|
|
|
init_latent: torch.Tensor | None = None
|
|
|
|
|
t_enc = None
|
|
|
|
|
starting_image = None
|
|
|
|
|
denoiser_cls = None
|
|
|
|
|
|
|
|
|
|
c_cat = []
|
|
|
|
|
c_cat_neutral = None
|
|
|
|
|
result_images = {}
|
|
|
|
|
result_images: dict[str, torch.Tensor | Image.Image | None] = {}
|
|
|
|
|
assert prompt.seed is not None
|
|
|
|
|
seed_everything(prompt.seed)
|
|
|
|
|
noise = randn_seeded(seed=prompt.seed, size=shape).to(get_device())
|
|
|
|
|
control_strengths = []
|
|
|
|
|
|
|
|
|
|
if prompt.init_image:
|
|
|
|
|
starting_image = prompt.init_image
|
|
|
|
|
assert prompt.init_image_strength is not None
|
|
|
|
|
generation_strength = 1 - prompt.init_image_strength
|
|
|
|
|
|
|
|
|
|
if model.cond_stage_key == "edit" or generation_strength >= 1:
|
|
|
|
@ -367,18 +371,18 @@ def _generate_single_image_compvis(
|
|
|
|
|
starting_image, mask_image = prepare_image_for_outpaint(
|
|
|
|
|
starting_image, mask_image, **outpaint_kwargs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert starting_image is not None
|
|
|
|
|
init_image = pillow_fit_image_within(
|
|
|
|
|
starting_image,
|
|
|
|
|
max_height=prompt.height,
|
|
|
|
|
max_width=prompt.width,
|
|
|
|
|
)
|
|
|
|
|
init_image_t = pillow_img_to_torch_image(init_image)
|
|
|
|
|
init_image_t = init_image_t.to(get_device())
|
|
|
|
|
init_image_t = pillow_img_to_torch_image(init_image).to(get_device())
|
|
|
|
|
init_latent = model.get_first_stage_encoding(
|
|
|
|
|
model.encode_first_stage(init_image_t)
|
|
|
|
|
)
|
|
|
|
|
shape = init_latent.shape
|
|
|
|
|
assert init_latent is not None
|
|
|
|
|
shape = list(init_latent.shape)
|
|
|
|
|
|
|
|
|
|
log_latent(init_latent, "init_latent")
|
|
|
|
|
|
|
|
|
@ -405,9 +409,9 @@ def _generate_single_image_compvis(
|
|
|
|
|
control_inputs.append(
|
|
|
|
|
ControlInput(mode="inpaint", image=mask_image)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert prompt.seed is not None
|
|
|
|
|
seed_everything(prompt.seed)
|
|
|
|
|
noise = randn_seeded(seed=prompt.seed, size=init_latent.shape).to(
|
|
|
|
|
noise = randn_seeded(seed=prompt.seed, size=list(init_latent.shape)).to(
|
|
|
|
|
get_device()
|
|
|
|
|
)
|
|
|
|
|
# noise = noise[:, :, : init_latent.shape[2], : init_latent.shape[3]]
|
|
|
|
@ -451,8 +455,13 @@ def _generate_single_image_compvis(
|
|
|
|
|
control_image = control_input.image_raw
|
|
|
|
|
elif control_input.image is not None:
|
|
|
|
|
control_image = control_input.image
|
|
|
|
|
else:
|
|
|
|
|
raise RuntimeError("Control image must be provided")
|
|
|
|
|
assert control_image is not None
|
|
|
|
|
control_image = control_image.convert("RGB")
|
|
|
|
|
log_img(control_image, "control_image_input")
|
|
|
|
|
assert control_image is not None
|
|
|
|
|
|
|
|
|
|
control_image_input = pillow_fit_image_within(
|
|
|
|
|
control_image,
|
|
|
|
|
max_height=prompt.height,
|
|
|
|
@ -464,11 +473,11 @@ def _generate_single_image_compvis(
|
|
|
|
|
if control_input.image_raw is None:
|
|
|
|
|
control_prep_function = CONTROL_MODES[control_input.mode]
|
|
|
|
|
if control_input.mode == "inpaint":
|
|
|
|
|
control_image_t = control_prep_function(
|
|
|
|
|
control_image_t = control_prep_function( # type: ignore
|
|
|
|
|
control_image_input_t, init_image_t
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
control_image_t = control_prep_function(control_image_input_t)
|
|
|
|
|
control_image_t = control_prep_function(control_image_input_t) # type: ignore
|
|
|
|
|
else:
|
|
|
|
|
control_image_t = (control_image_input_t + 1) / 2
|
|
|
|
|
|
|
|
|
@ -499,6 +508,8 @@ def _generate_single_image_compvis(
|
|
|
|
|
|
|
|
|
|
elif hasattr(model, "masked_image_key"):
|
|
|
|
|
# inpainting model
|
|
|
|
|
assert mask_image_orig is not None
|
|
|
|
|
assert mask_latent is not None
|
|
|
|
|
mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to(
|
|
|
|
|
get_device()
|
|
|
|
|
)
|
|
|
|
@ -519,6 +530,7 @@ def _generate_single_image_compvis(
|
|
|
|
|
elif model.cond_stage_key == "edit":
|
|
|
|
|
# pix2pix model
|
|
|
|
|
c_cat = [model.encode_first_stage(init_image_t)]
|
|
|
|
|
assert init_latent is not None
|
|
|
|
|
c_cat_neutral = [torch.zeros_like(init_latent)]
|
|
|
|
|
denoiser_cls = CFGEditingDenoiser
|
|
|
|
|
if c_cat:
|
|
|
|
@ -527,18 +539,24 @@ def _generate_single_image_compvis(
|
|
|
|
|
if c_cat_neutral is None:
|
|
|
|
|
c_cat_neutral = c_cat
|
|
|
|
|
|
|
|
|
|
positive_conditioning = {
|
|
|
|
|
positive_conditioning_d: dict[str, Any] = {
|
|
|
|
|
"c_concat": c_cat,
|
|
|
|
|
"c_crossattn": [positive_conditioning],
|
|
|
|
|
}
|
|
|
|
|
neutral_conditioning = {
|
|
|
|
|
neutral_conditioning_d: dict[str, Any] = {
|
|
|
|
|
"c_concat": c_cat_neutral,
|
|
|
|
|
"c_crossattn": [neutral_conditioning],
|
|
|
|
|
}
|
|
|
|
|
del neutral_conditioning
|
|
|
|
|
del positive_conditioning
|
|
|
|
|
|
|
|
|
|
if control_strengths and is_controlnet_model:
|
|
|
|
|
positive_conditioning["control_strengths"] = torch.Tensor(control_strengths)
|
|
|
|
|
neutral_conditioning["control_strengths"] = torch.Tensor(control_strengths)
|
|
|
|
|
positive_conditioning_d["control_strengths"] = torch.Tensor(
|
|
|
|
|
control_strengths
|
|
|
|
|
)
|
|
|
|
|
neutral_conditioning_d["control_strengths"] = torch.Tensor(
|
|
|
|
|
control_strengths
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if (
|
|
|
|
|
prompt.allow_compose_phase
|
|
|
|
@ -575,8 +593,8 @@ def _generate_single_image_compvis(
|
|
|
|
|
with lc.timing("sampling"):
|
|
|
|
|
samples = solver.sample(
|
|
|
|
|
num_steps=prompt.steps,
|
|
|
|
|
positive_conditioning=positive_conditioning,
|
|
|
|
|
neutral_conditioning=neutral_conditioning,
|
|
|
|
|
positive_conditioning=positive_conditioning_d,
|
|
|
|
|
neutral_conditioning=neutral_conditioning_d,
|
|
|
|
|
guidance_scale=prompt.prompt_strength,
|
|
|
|
|
t_start=t_enc,
|
|
|
|
|
mask=mask_latent,
|
|
|
|
|