feature: allow input of raw control images

pull/276/head
Authored by Bryce 1 year ago, committed by Bryce Drennan
parent 8f56e14dc7
commit d5cff45bff
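
A minimal usage sketch of the new option (the `imagine`/`ImaginePrompt` import path and the "canny" mode name are assumptions; only `control_image`, `control_image_raw`, and `control_mode` come from this diff):

    from PIL import Image

    from imaginairy import ImaginePrompt, imagine  # assumed top-level exports

    # Existing path: the control image is run through a CONTROL_MODES preprocessor.
    auto_prompt = ImaginePrompt(
        "a woman sitting in a chair",
        control_image=Image.open("person.jpg"),
        control_mode="canny",  # assumed mode name
    )

    # New path: supply an already-prepared control map and skip preprocessing.
    raw_prompt = ImaginePrompt(
        "a woman sitting in a chair",
        control_image_raw=Image.open("person_edges.png"),
        control_mode="canny",
    )

    results = list(imagine([auto_prompt, raw_prompt]))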

@@ -378,80 +378,88 @@ def _generate_single_image(
                 noise=noise,
             )
 
         if hasattr(model, "depth_stage_key"):
             # depth model
             depth_t = torch_image_to_depth_map(init_image_t)
             depth_latent = torch.nn.functional.interpolate(
                 depth_t,
                 size=shape[2:],
                 mode="bicubic",
                 align_corners=False,
             )
             result_images["depth_image"] = depth_t
             c_cat.append(depth_latent)
-        elif is_controlnet_model and starting_image:
+        elif is_controlnet_model:
             from imaginairy.img_processors.control_modes import CONTROL_MODES
 
+            if prompt.control_image_raw is not None:
+                control_image = prompt.control_image_raw
+            elif prompt.control_image is not None:
+                control_image = prompt.control_image
+            control_image = control_image.convert("RGB")
+            log_img(control_image, "control_image_input")
             control_image_input = pillow_fit_image_within(
-                prompt.control_image,
+                control_image,
                 max_height=prompt.height,
                 max_width=prompt.width,
             )
             control_image_input_t = pillow_img_to_torch_image(control_image_input)
             control_image_input_t = control_image_input_t.to(get_device())
 
-            control_image = CONTROL_MODES[prompt.control_mode](
-                control_image_input_t
-            )
-
-            if len(control_image.shape) == 3:
-                raise RuntimeError("Control image must be 4D")
-
-            if control_image.shape[1] != 3:
-                raise RuntimeError("Control image must have 3 channels")
-
-            if control_image.min() < 0 or control_image.max() > 1:
-                raise RuntimeError(
-                    f"Control image must be in [0, 1] but we received {control_image.min()} and {control_image.max()}"
-                )
-
-            if control_image.max() <= 0.5:
-                raise RuntimeError("Control image must be in [0, 1]")
-
-            if control_image.min() >= 0.5:
-                raise RuntimeError("Control image must be in [0, 1]")
-
-            control_image_disp = control_image * 2 - 1
-            result_images["control"] = control_image_disp[:, [2, 0, 1], :, :]
-            log_img(control_image_disp, "control_image")
-
-            c_cat.append(control_image)
+            if prompt.control_image_raw is None:
+                control_image_t = CONTROL_MODES[prompt.control_mode](
+                    control_image_input_t
+                )
+            else:
+                control_image_t = (control_image_input_t + 1) / 2
+
+            control_image_disp = control_image_t * 2 - 1
+            result_images["control"] = control_image_disp[:, [2, 0, 1], :, :]
+            log_img(control_image_disp, "control_image")
+
+            if len(control_image_t.shape) == 3:
+                raise RuntimeError("Control image must be 4D")
+
+            if control_image_t.shape[1] != 3:
+                raise RuntimeError("Control image must have 3 channels")
+
+            if control_image_t.min() < 0 or control_image_t.max() > 1:
+                raise RuntimeError(
+                    f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
+                )
+
+            if control_image_t.max() == control_image_t.min():
+                raise RuntimeError("No control signal found in control image.")
+
+            c_cat.append(control_image_t)
         elif hasattr(model, "masked_image_key"):
             # inpainting model
             mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to(
                 get_device()
             )
             inverted_mask = 1 - mask_latent
             masked_image_t = init_image_t * (mask_t < 0.5)
             log_img(masked_image_t, "masked_image")
 
             inverted_mask_latent = torch.nn.functional.interpolate(
                 inverted_mask, size=shape[-2:]
             )
             c_cat.append(inverted_mask_latent)
 
             masked_image_latent = model.get_first_stage_encoding(
                 model.encode_first_stage(masked_image_t)
             )
             c_cat.append(masked_image_latent)
         elif model.cond_stage_key == "edit":
             # pix2pix model
             c_cat = [model.encode_first_stage(init_image_t)]
             c_cat_neutral = [torch.zeros_like(init_latent)]
             denoiser_cls = CFGEditingDenoiser
         if c_cat:
             c_cat = [torch.cat(c_cat, dim=1)]
 
         if c_cat_neutral is None:
             c_cat_neutral = c_cat
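
The new `else` branch skips the CONTROL_MODES preprocessor entirely: `pillow_img_to_torch_image` appears to produce tensors in [-1, 1], so `(control_image_input_t + 1) / 2` rescales the raw control map into the [0, 1] range the checks that follow require. A small standalone sketch of that rescale (the [-1, 1] input range is an assumption inferred from the rescale itself):

    import torch

    # Stand-in for a raw control map as pillow_img_to_torch_image would return it:
    # a 1x3xHxW float tensor assumed to lie in [-1, 1].
    raw_t = torch.rand(1, 3, 64, 64) * 2 - 1

    control_image_t = (raw_t + 1) / 2  # same rescale as the new raw branch
    assert control_image_t.ndim == 4 and control_image_t.shape[1] == 3
    assert 0.0 <= control_image_t.min() and control_image_t.max() <= 1.0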

@@ -195,7 +195,7 @@ class ImaginePrompt:
         ):
             self.control_image = self.init_image
 
-        if self.control_mode and not self.control_image:
+        if self.control_mode and not (self.control_image or self.control_image_raw):
             raise ValueError("You must set `control_image` when using `control_mode`")
 
         if self.mask_image is not None and self.mask_prompt is not None:
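
With the relaxed check, either image field satisfies a `control_mode`. A short sketch of the expected validation behavior (import path and mode name assumed, as above):

    from PIL import Image

    from imaginairy import ImaginePrompt  # assumed top-level export

    edge_map = Image.new("RGB", (512, 512))

    # Passes validation after this change: a raw control image alone is enough.
    ImaginePrompt("a portrait", control_image_raw=edge_map, control_mode="canny")

    # Still raises ValueError: control_mode set with no control image of either kind.
    ImaginePrompt("a portrait", control_mode="canny")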
