refactor: separate controlnet image preprocessing

This commit is contained in:
Bryce 2023-12-17 22:42:11 -08:00 committed by Bryce Drennan
parent 9a0e0cd1a7
commit c6ac5f553a

View File

@ -193,72 +193,15 @@ def _generate_single_image(
controlnets = []
if control_modes:
from imaginairy.img_processors.control_modes import CONTROL_MODES
for control_input in control_inputs:
if control_input.image_raw is not None:
control_image = control_input.image_raw
elif control_input.image is not None:
control_image = control_input.image
control_image = control_image.convert("RGB")
log_img(control_image, "control_image_input")
control_image_input = pillow_fit_image_within(
control_image,
max_height=prompt.height,
max_width=prompt.width,
controlnet, control_image_t, control_image_disp = prep_control_input(
control_input=control_input,
sd=sd,
init_image_t=init_image_t,
fit_width=prompt.width,
fit_height=prompt.height,
)
if control_input.mode == "inpaint":
control_image_input = ImageOps.invert(control_image_input)
control_image_input_t = pillow_img_to_torch_image(control_image_input)
control_image_input_t = control_image_input_t.to(get_device())
if control_input.image_raw is None:
control_prep_function = CONTROL_MODES[control_input.mode]
if control_input.mode == "inpaint":
control_image_t = control_prep_function( # type: ignore
control_image_input_t, init_image_t
)
else:
control_image_t = control_prep_function(control_image_input_t) # type: ignore
else:
control_image_t = (control_image_input_t + 1) / 2
control_image_disp = control_image_t * 2 - 1
result_images[f"control-{control_input.mode}"] = control_image_disp
log_img(control_image_disp, "control_image")
if len(control_image_t.shape) == 3:
raise ValueError("Control image must be 4D")
if control_image_t.shape[1] != 3:
raise ValueError("Control image must have 3 channels")
if (
control_input.mode != "inpaint"
and control_image_t.min() < 0
or control_image_t.max() > 1
):
msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
raise ValueError(msg)
if control_image_t.max() == control_image_t.min():
msg = f"No control signal found in control image {control_input.mode}."
raise ValueError(msg)
control_config = CONTROL_CONFIG_SHORTCUTS.get(control_input.mode, None)
if not control_config:
msg = f"Unknown control mode: {control_input.mode}"
raise ValueError(msg)
from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter
controlnet = SD1ControlnetAdapter( # type: ignore
name=control_input.mode,
target=sd.unet, # type: ignore
weights_location=control_config.weights_location,
)
controlnet.set_scale(control_input.strength)
controlnets.append((controlnet, control_image_t))
if prompt.allow_compose_phase:
@ -292,6 +235,24 @@ def _generate_single_image(
comp_image_t = pillow_img_to_torch_image(comp_image)
comp_image_t = comp_image_t.to(sd.device, dtype=sd.dtype)
init_latent = sd.lda.encode(comp_image_t)
compose_control_inputs: list[ControlInput] = [
# ControlInput(mode="depth", image=comp_image, strength=1),
# ControlInput(mode="hed", image=comp_image, strength=1),
]
for control_input in compose_control_inputs:
(
controlnet,
control_image_t,
control_image_disp,
) = prep_control_input(
control_input=control_input,
sd=sd,
init_image_t=None,
fit_width=prompt.width,
fit_height=prompt.height,
)
result_images[f"control-{control_input.mode}"] = control_image_disp
controlnets.append((controlnet, control_image_t))
for controlnet, control_image_t in controlnets:
controlnet.set_controlnet_condition(
@ -478,3 +439,87 @@ def clear_gpu_cache():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def prep_control_input(
control_input: ControlInput, sd, init_image_t, fit_width, fit_height
):
from PIL import ImageOps
from imaginairy.utils import get_device
from imaginairy.utils.img_utils import (
pillow_fit_image_within,
pillow_img_to_torch_image,
)
from imaginairy.utils.log_utils import (
log_img,
)
if control_input.image_raw is not None:
control_image = control_input.image_raw
elif control_input.image is not None:
control_image = control_input.image
else:
raise ValueError("No control image found")
assert control_image is not None
control_image = control_image.convert("RGB")
log_img(control_image, "control_image_input")
control_image_input = pillow_fit_image_within(
control_image,
max_height=fit_height,
max_width=fit_width,
)
if control_input.mode == "inpaint":
control_image_input = ImageOps.invert(control_image_input)
control_image_input_t = pillow_img_to_torch_image(control_image_input)
control_image_input_t = control_image_input_t.to(get_device())
if control_input.image_raw is None:
from imaginairy.img_processors.control_modes import CONTROL_MODES
control_prep_function = CONTROL_MODES[control_input.mode]
if control_input.mode == "inpaint":
control_image_t = control_prep_function( # type: ignore
control_image_input_t, init_image_t
)
else:
control_image_t = control_prep_function(control_image_input_t) # type: ignore
else:
control_image_t = (control_image_input_t + 1) / 2
control_image_disp = control_image_t * 2 - 1
log_img(control_image_disp, "control_image")
if len(control_image_t.shape) == 3:
raise ValueError("Control image must be 4D")
if control_image_t.shape[1] != 3:
raise ValueError("Control image must have 3 channels")
if (
control_input.mode != "inpaint"
and control_image_t.min() < 0
or control_image_t.max() > 1
):
msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
raise ValueError(msg)
if control_image_t.max() == control_image_t.min():
msg = f"No control signal found in control image {control_input.mode}."
raise ValueError(msg)
control_config = CONTROL_CONFIG_SHORTCUTS.get(control_input.mode, None)
if not control_config:
msg = f"Unknown control mode: {control_input.mode}"
raise ValueError(msg)
from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter
controlnet = SD1ControlnetAdapter( # type: ignore
name=control_input.mode,
target=sd.unet,
weights_location=control_config.weights_location,
)
controlnet.set_scale(control_input.strength)
return controlnet, control_image_t, control_image_disp