Mirror of https://github.com/brycedrennan/imaginAIry, synced 2024-10-31 03:20:40 +00:00
refactor: separate controlnet image preprocessing
parent 9a0e0cd1a7
commit c6ac5f553a
@@ -193,72 +193,15 @@ def _generate_single_image(
     controlnets = []
 
     if control_modes:
-        from imaginairy.img_processors.control_modes import CONTROL_MODES
-
         for control_input in control_inputs:
-            if control_input.image_raw is not None:
-                control_image = control_input.image_raw
-            elif control_input.image is not None:
-                control_image = control_input.image
-            control_image = control_image.convert("RGB")
-            log_img(control_image, "control_image_input")
-            control_image_input = pillow_fit_image_within(
-                control_image,
-                max_height=prompt.height,
-                max_width=prompt.width,
+            controlnet, control_image_t, control_image_disp = prep_control_input(
+                control_input=control_input,
+                sd=sd,
+                init_image_t=init_image_t,
+                fit_width=prompt.width,
+                fit_height=prompt.height,
             )
-            if control_input.mode == "inpaint":
-                control_image_input = ImageOps.invert(control_image_input)
-
-            control_image_input_t = pillow_img_to_torch_image(control_image_input)
-            control_image_input_t = control_image_input_t.to(get_device())
-
-            if control_input.image_raw is None:
-                control_prep_function = CONTROL_MODES[control_input.mode]
-                if control_input.mode == "inpaint":
-                    control_image_t = control_prep_function(  # type: ignore
-                        control_image_input_t, init_image_t
-                    )
-                else:
-                    control_image_t = control_prep_function(control_image_input_t)  # type: ignore
-            else:
-                control_image_t = (control_image_input_t + 1) / 2
-
-            control_image_disp = control_image_t * 2 - 1
             result_images[f"control-{control_input.mode}"] = control_image_disp
-            log_img(control_image_disp, "control_image")
-
-            if len(control_image_t.shape) == 3:
-                raise ValueError("Control image must be 4D")
-
-            if control_image_t.shape[1] != 3:
-                raise ValueError("Control image must have 3 channels")
-
-            if (
-                control_input.mode != "inpaint"
-                and control_image_t.min() < 0
-                or control_image_t.max() > 1
-            ):
-                msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
-                raise ValueError(msg)
-
-            if control_image_t.max() == control_image_t.min():
-                msg = f"No control signal found in control image {control_input.mode}."
-                raise ValueError(msg)
-
-            control_config = CONTROL_CONFIG_SHORTCUTS.get(control_input.mode, None)
-            if not control_config:
-                msg = f"Unknown control mode: {control_input.mode}"
-                raise ValueError(msg)
-            from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter
-
-            controlnet = SD1ControlnetAdapter(  # type: ignore
-                name=control_input.mode,
-                target=sd.unet,  # type: ignore
-                weights_location=control_config.weights_location,
-            )
-            controlnet.set_scale(control_input.strength)
-
             controlnets.append((controlnet, control_image_t))
 
     if prompt.allow_compose_phase:
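A note on the value ranges the hunk above relies on: judging by the rescaling in the raw-image branch, image tensors from pillow_img_to_torch_image arrive in [-1, 1], raw control images are rescaled into [0, 1] for conditioning, and the display copy is mapped back to [-1, 1]. A minimal, self-contained sketch of that convention (the helper names here are invented for illustration; the two formulas are copied from the diff):

    import torch

    def to_control_range(img_t: torch.Tensor) -> torch.Tensor:
        # [-1, 1] -> [0, 1]; same as `control_image_t = (control_image_input_t + 1) / 2`
        return (img_t + 1) / 2

    def to_display_range(control_t: torch.Tensor) -> torch.Tensor:
        # [0, 1] -> [-1, 1]; same as `control_image_disp = control_image_t * 2 - 1`
        return control_t * 2 - 1

    img_t = torch.rand(1, 3, 64, 64) * 2 - 1  # stand-in for a loaded image tensor
    ctrl = to_control_range(img_t)
    assert ctrl.min().item() >= 0.0 and ctrl.max().item() <= 1.0
    assert torch.allclose(to_display_range(ctrl), img_t)  # round-trips back to [-1, 1]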
@@ -292,6 +235,24 @@ def _generate_single_image(
         comp_image_t = pillow_img_to_torch_image(comp_image)
         comp_image_t = comp_image_t.to(sd.device, dtype=sd.dtype)
         init_latent = sd.lda.encode(comp_image_t)
+        compose_control_inputs: list[ControlInput] = [
+            # ControlInput(mode="depth", image=comp_image, strength=1),
+            # ControlInput(mode="hed", image=comp_image, strength=1),
+        ]
+        for control_input in compose_control_inputs:
+            (
+                controlnet,
+                control_image_t,
+                control_image_disp,
+            ) = prep_control_input(
+                control_input=control_input,
+                sd=sd,
+                init_image_t=None,
+                fit_width=prompt.width,
+                fit_height=prompt.height,
+            )
+            result_images[f"control-{control_input.mode}"] = control_image_disp
+            controlnets.append((controlnet, control_image_t))
 
     for controlnet, control_image_t in controlnets:
         controlnet.set_controlnet_condition(
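Both the removed inline code and the new helper enforce the same tensor contract before a controlnet is wired up. A self-contained restatement of those checks, for readability (the function name here is invented; the checks and messages are copied from the diff):

    import torch

    def validate_control_tensor(control_image_t: torch.Tensor, mode: str) -> None:
        # Contract from the diff: 4-D (batch, 3 channels, H, W), values in
        # [0, 1], and a non-constant image (otherwise there is no signal).
        if len(control_image_t.shape) == 3:
            raise ValueError("Control image must be 4D")
        if control_image_t.shape[1] != 3:
            raise ValueError("Control image must have 3 channels")
        if (
            mode != "inpaint"
            and control_image_t.min() < 0
            or control_image_t.max() > 1
        ):
            msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
            raise ValueError(msg)
        if control_image_t.max() == control_image_t.min():
            raise ValueError(f"No control signal found in control image {mode}.")

    validate_control_tensor(torch.rand(1, 3, 512, 512), mode="canny")  # passes

Note the precedence in the range check: `and` binds tighter than `or`, so the max() > 1 clause applies in every mode, while the min() < 0 clause is skipped for inpaint.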
@@ -478,3 +439,87 @@ def clear_gpu_cache():
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
+
+
+def prep_control_input(
+    control_input: ControlInput, sd, init_image_t, fit_width, fit_height
+):
+    from PIL import ImageOps
+
+    from imaginairy.utils import get_device
+    from imaginairy.utils.img_utils import (
+        pillow_fit_image_within,
+        pillow_img_to_torch_image,
+    )
+    from imaginairy.utils.log_utils import (
+        log_img,
+    )
+
+    if control_input.image_raw is not None:
+        control_image = control_input.image_raw
+    elif control_input.image is not None:
+        control_image = control_input.image
+    else:
+        raise ValueError("No control image found")
+    assert control_image is not None
+    control_image = control_image.convert("RGB")
+    log_img(control_image, "control_image_input")
+    control_image_input = pillow_fit_image_within(
+        control_image,
+        max_height=fit_height,
+        max_width=fit_width,
+    )
+    if control_input.mode == "inpaint":
+        control_image_input = ImageOps.invert(control_image_input)
+
+    control_image_input_t = pillow_img_to_torch_image(control_image_input)
+    control_image_input_t = control_image_input_t.to(get_device())
+
+    if control_input.image_raw is None:
+        from imaginairy.img_processors.control_modes import CONTROL_MODES
+
+        control_prep_function = CONTROL_MODES[control_input.mode]
+        if control_input.mode == "inpaint":
+            control_image_t = control_prep_function(  # type: ignore
+                control_image_input_t, init_image_t
+            )
+        else:
+            control_image_t = control_prep_function(control_image_input_t)  # type: ignore
+    else:
+        control_image_t = (control_image_input_t + 1) / 2
+
+    control_image_disp = control_image_t * 2 - 1
+
+    log_img(control_image_disp, "control_image")
+
+    if len(control_image_t.shape) == 3:
+        raise ValueError("Control image must be 4D")
+
+    if control_image_t.shape[1] != 3:
+        raise ValueError("Control image must have 3 channels")
+
+    if (
+        control_input.mode != "inpaint"
+        and control_image_t.min() < 0
+        or control_image_t.max() > 1
+    ):
+        msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
+        raise ValueError(msg)
+
+    if control_image_t.max() == control_image_t.min():
+        msg = f"No control signal found in control image {control_input.mode}."
+        raise ValueError(msg)
+
+    control_config = CONTROL_CONFIG_SHORTCUTS.get(control_input.mode, None)
+    if not control_config:
+        msg = f"Unknown control mode: {control_input.mode}"
+        raise ValueError(msg)
+    from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter
+
+    controlnet = SD1ControlnetAdapter(  # type: ignore
+        name=control_input.mode,
+        target=sd.unet,
+        weights_location=control_config.weights_location,
+    )
+    controlnet.set_scale(control_input.strength)
+    return controlnet, control_image_t, control_image_disp
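In prep_control_input, a raw control image (image_raw) skips preprocessing and is only rescaled, while a regular image is routed through a mode-specific function looked up in CONTROL_MODES. A minimal sketch of that dispatch pattern, with a made-up preprocessor standing in for the real ones (the dict layout here is an assumption; only the lookup-and-call pattern comes from the diff):

    from typing import Callable

    import torch

    def fake_depth(img_t: torch.Tensor) -> torch.Tensor:
        # Stand-in for a real depth estimator: collapse to one channel,
        # rescale from [-1, 1] into [0, 1], and repeat back to 3 channels.
        depth = img_t.mean(dim=1, keepdim=True)
        return ((depth + 1) / 2).clamp(0, 1).repeat(1, 3, 1, 1)

    CONTROL_MODES_SKETCH: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
        "depth": fake_depth,
    }

    control_prep_function = CONTROL_MODES_SKETCH["depth"]
    control_image_t = control_prep_function(torch.rand(1, 3, 64, 64) * 2 - 1)
    assert control_image_t.shape == (1, 3, 64, 64)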