|
|
|
@ -378,80 +378,88 @@ def _generate_single_image(
|
|
|
|
|
noise=noise,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if hasattr(model, "depth_stage_key"):
|
|
|
|
|
# depth model
|
|
|
|
|
depth_t = torch_image_to_depth_map(init_image_t)
|
|
|
|
|
depth_latent = torch.nn.functional.interpolate(
|
|
|
|
|
depth_t,
|
|
|
|
|
size=shape[2:],
|
|
|
|
|
mode="bicubic",
|
|
|
|
|
align_corners=False,
|
|
|
|
|
)
|
|
|
|
|
result_images["depth_image"] = depth_t
|
|
|
|
|
c_cat.append(depth_latent)
|
|
|
|
|
|
|
|
|
|
elif is_controlnet_model and starting_image:
|
|
|
|
|
from imaginairy.img_processors.control_modes import CONTROL_MODES
|
|
|
|
|
|
|
|
|
|
control_image_input = pillow_fit_image_within(
|
|
|
|
|
prompt.control_image,
|
|
|
|
|
max_height=prompt.height,
|
|
|
|
|
max_width=prompt.width,
|
|
|
|
|
)
|
|
|
|
|
control_image_input_t = pillow_img_to_torch_image(control_image_input)
|
|
|
|
|
control_image_input_t = control_image_input_t.to(get_device())
|
|
|
|
|
if hasattr(model, "depth_stage_key"):
|
|
|
|
|
# depth model
|
|
|
|
|
depth_t = torch_image_to_depth_map(init_image_t)
|
|
|
|
|
depth_latent = torch.nn.functional.interpolate(
|
|
|
|
|
depth_t,
|
|
|
|
|
size=shape[2:],
|
|
|
|
|
mode="bicubic",
|
|
|
|
|
align_corners=False,
|
|
|
|
|
)
|
|
|
|
|
result_images["depth_image"] = depth_t
|
|
|
|
|
c_cat.append(depth_latent)
|
|
|
|
|
|
|
|
|
|
elif is_controlnet_model:
|
|
|
|
|
from imaginairy.img_processors.control_modes import CONTROL_MODES
|
|
|
|
|
|
|
|
|
|
if prompt.control_image_raw is not None:
|
|
|
|
|
control_image = prompt.control_image_raw
|
|
|
|
|
elif prompt.control_image is not None:
|
|
|
|
|
control_image = prompt.control_image
|
|
|
|
|
control_image = control_image.convert("RGB")
|
|
|
|
|
log_img(control_image, "control_image_input")
|
|
|
|
|
control_image_input = pillow_fit_image_within(
|
|
|
|
|
control_image,
|
|
|
|
|
max_height=prompt.height,
|
|
|
|
|
max_width=prompt.width,
|
|
|
|
|
)
|
|
|
|
|
control_image_input_t = pillow_img_to_torch_image(control_image_input)
|
|
|
|
|
control_image_input_t = control_image_input_t.to(get_device())
|
|
|
|
|
|
|
|
|
|
control_image = CONTROL_MODES[prompt.control_mode](
|
|
|
|
|
if prompt.control_image_raw is None:
|
|
|
|
|
control_image_t = CONTROL_MODES[prompt.control_mode](
|
|
|
|
|
control_image_input_t
|
|
|
|
|
)
|
|
|
|
|
if len(control_image.shape) == 3:
|
|
|
|
|
raise RuntimeError("Control image must be 4D")
|
|
|
|
|
else:
|
|
|
|
|
control_image_t = (control_image_input_t + 1) / 2
|
|
|
|
|
|
|
|
|
|
if control_image.shape[1] != 3:
|
|
|
|
|
raise RuntimeError("Control image must have 3 channels")
|
|
|
|
|
control_image_disp = control_image_t * 2 - 1
|
|
|
|
|
result_images["control"] = control_image_disp[:, [2, 0, 1], :, :]
|
|
|
|
|
log_img(control_image_disp, "control_image")
|
|
|
|
|
|
|
|
|
|
if control_image.min() < 0 or control_image.max() > 1:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
f"Control image must be in [0, 1] but we received {control_image.min()} and {control_image.max()}"
|
|
|
|
|
)
|
|
|
|
|
if len(control_image_t.shape) == 3:
|
|
|
|
|
raise RuntimeError("Control image must be 4D")
|
|
|
|
|
|
|
|
|
|
if control_image.max() <= 0.5:
|
|
|
|
|
raise RuntimeError("Control image must be in [0, 1]")
|
|
|
|
|
if control_image_t.shape[1] != 3:
|
|
|
|
|
raise RuntimeError("Control image must have 3 channels")
|
|
|
|
|
|
|
|
|
|
if control_image.min() >= 0.5:
|
|
|
|
|
raise RuntimeError("Control image must be in [0, 1]")
|
|
|
|
|
if control_image_t.min() < 0 or control_image_t.max() > 1:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
control_image_disp = control_image * 2 - 1
|
|
|
|
|
result_images["control"] = control_image_disp[:, [2, 0, 1], :, :]
|
|
|
|
|
log_img(control_image_disp, "control_image")
|
|
|
|
|
c_cat.append(control_image)
|
|
|
|
|
if control_image_t.max() == control_image_t.min():
|
|
|
|
|
raise RuntimeError("No control signal found in control image.")
|
|
|
|
|
|
|
|
|
|
elif hasattr(model, "masked_image_key"):
|
|
|
|
|
# inpainting model
|
|
|
|
|
mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to(
|
|
|
|
|
get_device()
|
|
|
|
|
)
|
|
|
|
|
inverted_mask = 1 - mask_latent
|
|
|
|
|
masked_image_t = init_image_t * (mask_t < 0.5)
|
|
|
|
|
log_img(masked_image_t, "masked_image")
|
|
|
|
|
c_cat.append(control_image_t)
|
|
|
|
|
|
|
|
|
|
inverted_mask_latent = torch.nn.functional.interpolate(
|
|
|
|
|
inverted_mask, size=shape[-2:]
|
|
|
|
|
)
|
|
|
|
|
c_cat.append(inverted_mask_latent)
|
|
|
|
|
elif hasattr(model, "masked_image_key"):
|
|
|
|
|
# inpainting model
|
|
|
|
|
mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to(
|
|
|
|
|
get_device()
|
|
|
|
|
)
|
|
|
|
|
inverted_mask = 1 - mask_latent
|
|
|
|
|
masked_image_t = init_image_t * (mask_t < 0.5)
|
|
|
|
|
log_img(masked_image_t, "masked_image")
|
|
|
|
|
|
|
|
|
|
masked_image_latent = model.get_first_stage_encoding(
|
|
|
|
|
model.encode_first_stage(masked_image_t)
|
|
|
|
|
)
|
|
|
|
|
c_cat.append(masked_image_latent)
|
|
|
|
|
|
|
|
|
|
elif model.cond_stage_key == "edit":
|
|
|
|
|
# pix2pix model
|
|
|
|
|
c_cat = [model.encode_first_stage(init_image_t)]
|
|
|
|
|
c_cat_neutral = [torch.zeros_like(init_latent)]
|
|
|
|
|
denoiser_cls = CFGEditingDenoiser
|
|
|
|
|
if c_cat:
|
|
|
|
|
c_cat = [torch.cat(c_cat, dim=1)]
|
|
|
|
|
inverted_mask_latent = torch.nn.functional.interpolate(
|
|
|
|
|
inverted_mask, size=shape[-2:]
|
|
|
|
|
)
|
|
|
|
|
c_cat.append(inverted_mask_latent)
|
|
|
|
|
|
|
|
|
|
masked_image_latent = model.get_first_stage_encoding(
|
|
|
|
|
model.encode_first_stage(masked_image_t)
|
|
|
|
|
)
|
|
|
|
|
c_cat.append(masked_image_latent)
|
|
|
|
|
|
|
|
|
|
elif model.cond_stage_key == "edit":
|
|
|
|
|
# pix2pix model
|
|
|
|
|
c_cat = [model.encode_first_stage(init_image_t)]
|
|
|
|
|
c_cat_neutral = [torch.zeros_like(init_latent)]
|
|
|
|
|
denoiser_cls = CFGEditingDenoiser
|
|
|
|
|
if c_cat:
|
|
|
|
|
c_cat = [torch.cat(c_cat, dim=1)]
|
|
|
|
|
|
|
|
|
|
if c_cat_neutral is None:
|
|
|
|
|
c_cat_neutral = c_cat
|
|
|
|
|