|
|
|
@ -2,7 +2,7 @@ import logging
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
from imaginairy.schema import SafetyMode
|
|
|
|
|
from imaginairy.schema import ControlNetInput, SafetyMode
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
@ -195,7 +195,8 @@ def _generate_single_image(
|
|
|
|
|
progress_img_interval_min_s=0.1,
|
|
|
|
|
half_mode=None,
|
|
|
|
|
add_caption=False,
|
|
|
|
|
suppress_inpaint=False,
|
|
|
|
|
# controlnet, finetune, naive, auto
|
|
|
|
|
inpaint_method="finetune",
|
|
|
|
|
return_latent=False,
|
|
|
|
|
):
|
|
|
|
|
import torch.nn
|
|
|
|
@ -243,15 +244,26 @@ def _generate_single_image(
|
|
|
|
|
_, img_type = prompt.mask_image.strip("*").split(".")
|
|
|
|
|
prompt.mask_image = _most_recent_result.images[img_type]
|
|
|
|
|
control_modes = []
|
|
|
|
|
if prompt.control_inputs:
|
|
|
|
|
control_inputs = prompt.control_inputs or []
|
|
|
|
|
control_inputs = control_inputs.copy()
|
|
|
|
|
for_inpainting = bool(prompt.mask_image or prompt.mask_prompt or prompt.outpaint)
|
|
|
|
|
|
|
|
|
|
if control_inputs:
|
|
|
|
|
control_modes = [c.mode for c in prompt.control_inputs]
|
|
|
|
|
if inpaint_method == "auto":
|
|
|
|
|
if prompt.model in {"SD-1.5", "SD-2.0"}:
|
|
|
|
|
inpaint_method = "finetune"
|
|
|
|
|
else:
|
|
|
|
|
inpaint_method = "controlnet"
|
|
|
|
|
|
|
|
|
|
if for_inpainting and inpaint_method == "controlnet":
|
|
|
|
|
control_modes.append("inpaint")
|
|
|
|
|
model = get_diffusion_model(
|
|
|
|
|
weights_location=prompt.model,
|
|
|
|
|
config_path=prompt.model_config_path,
|
|
|
|
|
control_weights_locations=control_modes,
|
|
|
|
|
half_mode=half_mode,
|
|
|
|
|
for_inpainting=(prompt.mask_image or prompt.mask_prompt or prompt.outpaint)
|
|
|
|
|
and not suppress_inpaint,
|
|
|
|
|
for_inpainting=for_inpainting and inpaint_method == "finetune",
|
|
|
|
|
)
|
|
|
|
|
is_controlnet_model = hasattr(model, "control_key")
|
|
|
|
|
|
|
|
|
@ -303,6 +315,7 @@ def _generate_single_image(
|
|
|
|
|
result_images = {}
|
|
|
|
|
seed_everything(prompt.seed)
|
|
|
|
|
noise = randn_seeded(seed=prompt.seed, size=shape).to(get_device())
|
|
|
|
|
control_strengths = []
|
|
|
|
|
|
|
|
|
|
if prompt.init_image:
|
|
|
|
|
starting_image = prompt.init_image
|
|
|
|
@ -330,6 +343,14 @@ def _generate_single_image(
|
|
|
|
|
max_height=prompt.height,
|
|
|
|
|
max_width=prompt.width,
|
|
|
|
|
)
|
|
|
|
|
init_image_t = pillow_img_to_torch_image(init_image)
|
|
|
|
|
init_image_t = init_image_t.to(get_device())
|
|
|
|
|
init_latent = model.get_first_stage_encoding(
|
|
|
|
|
model.encode_first_stage(init_image_t)
|
|
|
|
|
)
|
|
|
|
|
shape = init_latent.shape
|
|
|
|
|
|
|
|
|
|
log_latent(init_latent, "init_latent")
|
|
|
|
|
|
|
|
|
|
if mask_image is not None:
|
|
|
|
|
mask_image = pillow_fit_image_within(
|
|
|
|
@ -349,15 +370,12 @@ def _generate_single_image(
|
|
|
|
|
mask_latent = pillow_mask_to_latent_mask(
|
|
|
|
|
mask_image, downsampling_factor=downsampling_factor
|
|
|
|
|
).to(get_device())
|
|
|
|
|
if inpaint_method == "controlnet":
|
|
|
|
|
result_images["control-inpaint"] = mask_image
|
|
|
|
|
control_inputs.append(
|
|
|
|
|
ControlNetInput(mode="inpaint", image=mask_image)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
init_image_t = pillow_img_to_torch_image(init_image)
|
|
|
|
|
init_image_t = init_image_t.to(get_device())
|
|
|
|
|
init_latent = model.get_first_stage_encoding(
|
|
|
|
|
model.encode_first_stage(init_image_t)
|
|
|
|
|
)
|
|
|
|
|
shape = init_latent.shape
|
|
|
|
|
|
|
|
|
|
log_latent(init_latent, "init_latent")
|
|
|
|
|
seed_everything(prompt.seed)
|
|
|
|
|
noise = randn_seeded(seed=prompt.seed, size=init_latent.shape).to(
|
|
|
|
|
get_device()
|
|
|
|
@ -398,7 +416,7 @@ def _generate_single_image(
|
|
|
|
|
elif is_controlnet_model:
|
|
|
|
|
from imaginairy.img_processors.control_modes import CONTROL_MODES
|
|
|
|
|
|
|
|
|
|
for control_input in prompt.control_inputs:
|
|
|
|
|
for control_input in control_inputs:
|
|
|
|
|
if control_input.image_raw is not None:
|
|
|
|
|
control_image = control_input.image_raw
|
|
|
|
|
elif control_input.image is not None:
|
|
|
|
@ -414,9 +432,13 @@ def _generate_single_image(
|
|
|
|
|
control_image_input_t = control_image_input_t.to(get_device())
|
|
|
|
|
|
|
|
|
|
if control_input.image_raw is None:
|
|
|
|
|
control_image_t = CONTROL_MODES[control_input.mode](
|
|
|
|
|
control_image_input_t
|
|
|
|
|
)
|
|
|
|
|
control_prep_function = CONTROL_MODES[control_input.mode]
|
|
|
|
|
if control_input.mode == "inpaint":
|
|
|
|
|
control_image_t = control_prep_function(
|
|
|
|
|
control_image_input_t, init_image_t
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
control_image_t = control_prep_function(control_image_input_t)
|
|
|
|
|
else:
|
|
|
|
|
control_image_t = (control_image_input_t + 1) / 2
|
|
|
|
|
|
|
|
|
@ -430,10 +452,11 @@ def _generate_single_image(
|
|
|
|
|
if control_image_t.shape[1] != 3:
|
|
|
|
|
raise RuntimeError("Control image must have 3 channels")
|
|
|
|
|
|
|
|
|
|
if control_image_t.min() < 0 or control_image_t.max() > 1:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
|
|
|
|
|
)
|
|
|
|
|
if control_input.mode != "inpaint":
|
|
|
|
|
if control_image_t.min() < 0 or control_image_t.max() > 1:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if control_image_t.max() == control_image_t.min():
|
|
|
|
|
raise RuntimeError(
|
|
|
|
@ -441,6 +464,7 @@ def _generate_single_image(
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
c_cat.append(control_image_t)
|
|
|
|
|
control_strengths.append(control_input.strength)
|
|
|
|
|
|
|
|
|
|
elif hasattr(model, "masked_image_key"):
|
|
|
|
|
# inpainting model
|
|
|
|
@ -481,6 +505,10 @@ def _generate_single_image(
|
|
|
|
|
"c_crossattn": [neutral_conditioning],
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if control_strengths and is_controlnet_model:
|
|
|
|
|
positive_conditioning["control_strengths"] = torch.Tensor(control_strengths)
|
|
|
|
|
neutral_conditioning["control_strengths"] = torch.Tensor(control_strengths)
|
|
|
|
|
|
|
|
|
|
if (
|
|
|
|
|
prompt.allow_compose_phase
|
|
|
|
|
and not is_controlnet_model
|
|
|
|
|