feature: automatic use of inpainting

This feature is disabled because ControlNet inpainting doesn't work well yet. It was disabled by setting `inpaint_method="finetune"`.
pull/333/head
Bryce 1 year ago committed by Bryce Drennan
parent d32e1060cd
commit df25936d6f

@ -468,8 +468,11 @@ A: The AI models are cached in `~/.cache/` (or `HUGGINGFACE_HUB_CACHE`). To dele
## ChangeLog
- feature: [disabled] inpainting controlnet can be used instead of finetuned inpainting model
- The inpainting controlnet doesn't work as well as the finetuned model
- feature: multi-controlnet support. pass in multiple `--control-mode`, `--control-image`, and `--control-image-raw` arguments.
- feature: "better" memory management. If GPU is full, least-recently-used model is moved to RAM.
- feature: "better" memory management. If GPU is full, least-recently-used model is moved to RAM.
- feature: python interface allows configuration of controlnet strength
- fix: hide the "triton" error messages
- feature: show full stack trace on error in cli

@ -2,7 +2,7 @@ import logging
import os
import re
from imaginairy.schema import SafetyMode
from imaginairy.schema import ControlNetInput, SafetyMode
logger = logging.getLogger(__name__)
@ -195,7 +195,8 @@ def _generate_single_image(
progress_img_interval_min_s=0.1,
half_mode=None,
add_caption=False,
suppress_inpaint=False,
# controlnet, finetune, naive, auto
inpaint_method="finetune",
return_latent=False,
):
import torch.nn
@ -243,15 +244,26 @@ def _generate_single_image(
_, img_type = prompt.mask_image.strip("*").split(".")
prompt.mask_image = _most_recent_result.images[img_type]
control_modes = []
if prompt.control_inputs:
control_inputs = prompt.control_inputs or []
control_inputs = control_inputs.copy()
for_inpainting = bool(prompt.mask_image or prompt.mask_prompt or prompt.outpaint)
if control_inputs:
control_modes = [c.mode for c in prompt.control_inputs]
if inpaint_method == "auto":
if prompt.model in {"SD-1.5", "SD-2.0"}:
inpaint_method = "finetune"
else:
inpaint_method = "controlnet"
if for_inpainting and inpaint_method == "controlnet":
control_modes.append("inpaint")
model = get_diffusion_model(
weights_location=prompt.model,
config_path=prompt.model_config_path,
control_weights_locations=control_modes,
half_mode=half_mode,
for_inpainting=(prompt.mask_image or prompt.mask_prompt or prompt.outpaint)
and not suppress_inpaint,
for_inpainting=for_inpainting and inpaint_method == "finetune",
)
is_controlnet_model = hasattr(model, "control_key")
@ -303,6 +315,7 @@ def _generate_single_image(
result_images = {}
seed_everything(prompt.seed)
noise = randn_seeded(seed=prompt.seed, size=shape).to(get_device())
control_strengths = []
if prompt.init_image:
starting_image = prompt.init_image
@ -330,6 +343,14 @@ def _generate_single_image(
max_height=prompt.height,
max_width=prompt.width,
)
init_image_t = pillow_img_to_torch_image(init_image)
init_image_t = init_image_t.to(get_device())
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(init_image_t)
)
shape = init_latent.shape
log_latent(init_latent, "init_latent")
if mask_image is not None:
mask_image = pillow_fit_image_within(
@ -349,15 +370,12 @@ def _generate_single_image(
mask_latent = pillow_mask_to_latent_mask(
mask_image, downsampling_factor=downsampling_factor
).to(get_device())
if inpaint_method == "controlnet":
result_images["control-inpaint"] = mask_image
control_inputs.append(
ControlNetInput(mode="inpaint", image=mask_image)
)
init_image_t = pillow_img_to_torch_image(init_image)
init_image_t = init_image_t.to(get_device())
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(init_image_t)
)
shape = init_latent.shape
log_latent(init_latent, "init_latent")
seed_everything(prompt.seed)
noise = randn_seeded(seed=prompt.seed, size=init_latent.shape).to(
get_device()
@ -398,7 +416,7 @@ def _generate_single_image(
elif is_controlnet_model:
from imaginairy.img_processors.control_modes import CONTROL_MODES
for control_input in prompt.control_inputs:
for control_input in control_inputs:
if control_input.image_raw is not None:
control_image = control_input.image_raw
elif control_input.image is not None:
@ -414,9 +432,13 @@ def _generate_single_image(
control_image_input_t = control_image_input_t.to(get_device())
if control_input.image_raw is None:
control_image_t = CONTROL_MODES[control_input.mode](
control_image_input_t
)
control_prep_function = CONTROL_MODES[control_input.mode]
if control_input.mode == "inpaint":
control_image_t = control_prep_function(
control_image_input_t, init_image_t
)
else:
control_image_t = control_prep_function(control_image_input_t)
else:
control_image_t = (control_image_input_t + 1) / 2
@ -430,10 +452,11 @@ def _generate_single_image(
if control_image_t.shape[1] != 3:
raise RuntimeError("Control image must have 3 channels")
if control_image_t.min() < 0 or control_image_t.max() > 1:
raise RuntimeError(
f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
)
if control_input.mode != "inpaint":
if control_image_t.min() < 0 or control_image_t.max() > 1:
raise RuntimeError(
f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
)
if control_image_t.max() == control_image_t.min():
raise RuntimeError(
@ -441,6 +464,7 @@ def _generate_single_image(
)
c_cat.append(control_image_t)
control_strengths.append(control_input.strength)
elif hasattr(model, "masked_image_key"):
# inpainting model
@ -481,6 +505,10 @@ def _generate_single_image(
"c_crossattn": [neutral_conditioning],
}
if control_strengths and is_controlnet_model:
positive_conditioning["control_strengths"] = torch.Tensor(control_strengths)
neutral_conditioning["control_strengths"] = torch.Tensor(control_strengths)
if (
prompt.allow_compose_phase
and not is_controlnet_model

@ -187,6 +187,26 @@ def shuffle_map_torch(tensor, h=None, w=None, f=256):
return shuffled_tensor.to(device)
def inpaint_prep(mask_image_t, target_image_t):
    """
    Combine an inpainting mask and a target image into a single tensor.

    Masked areas of the output are set to -1; every other pixel carries the
    target image rescaled to the range [0, 1].

    Args:
        mask_image_t: 3-channel torch tensor of shape (B, C, H, W) with pixel
            values in range [-1, 1], where exactly -1 indicates a masked pixel.
        target_image_t: 3-channel torch tensor of shape (B, C, H, W) with
            pixel values in range [-1, 1].

    Returns:
        A new tensor; neither input tensor is modified in place.
    """
    import torch

    # Normalize the target from [-1, 1] to [0, 1] so that -1 is free to act
    # as the "masked" sentinel in the combined tensor.
    target_image_t = (target_image_t + 1.0) / 2.0

    # Only pixels that are exactly -1 in the mask count as masked; everywhere
    # else the normalized target pixel is kept.
    output_image_t = torch.where(mask_image_t == -1, mask_image_t, target_image_t)

    return output_image_t
def noop(img):
    """
    Rescale ``img`` from the [-1, 1] range to [0, 1].

    Despite the name, this is not an identity function: it returns
    ``(img + 1.0) / 2.0``.
    """
    return (img + 1.0) / 2.0
@ -201,6 +221,6 @@ CONTROL_MODES = {
# "scribble": None,
"shuffle": shuffle_map_torch,
"edit": noop,
"inpaint": noop,
"inpaint": inpaint_prep,
"details": noop,
}

@ -394,14 +394,16 @@ class ControlLDM(LatentDiffusion):
merged_control = None
cond_txt = torch.cat(cond["c_crossattn"], 1)
for control_model, c_concat in zip(self.control_models, cond["c_concat"]):
for control_model, c_concat, control_strength in zip(
self.control_models, cond["c_concat"], cond["control_strengths"]
):
cond_hint = torch.cat([c_concat], 1)
control = control_model(
x=x_noisy, hint=cond_hint, timesteps=t, context=cond_txt
)
control = [c * scale for c, scale in zip(control, self.control_scales)]
control_scales = [control_strength] * 13
control = [c * scale for c, scale in zip(control, control_scales)]
if self.global_average_pooling:
control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control]
if merged_control is None:

@ -73,10 +73,11 @@ class LazyLoadingImage:
class ControlNetInput:
def __init__(self, *, mode, image=None, image_raw=None):
def __init__(self, *, mode, image=None, image_raw=None, strength=1):
self.mode = mode
self.image = image
self.image_raw = image_raw
self.strength = strength
def validate(self, default_image=None):
if isinstance(self.image, str):

Binary file not shown.

After

Width:  |  Height:  |  Size: 563 KiB

@ -20,8 +20,10 @@ def test_control_images(filename_base_for_outputs, control_func, control_name):
seed_everything(42)
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png")
img_t = pillow_img_to_torch_image(img)
control_t = control_func(img_t.clone())
if control_name == "inpaint":
control_t = control_func(img_t.clone(), img_t.clone())
else:
control_t = control_func(img_t.clone())
control_img = control_img_to_pillow_img(control_t)
img_path = f"{filename_base_for_outputs}.png"

Loading…
Cancel
Save