feature: multi-controlnet support at the command line

add controlnet option for edit demo
pull/333/head
Authored by Bryce 1 year ago, committed by Bryce Drennan
parent bcaa000d35
commit d32e1060cd

@@ -468,7 +468,7 @@ A: The AI models are cached in `~/.cache/` (or `HUGGINGFACE_HUB_CACHE`). To dele
## ChangeLog
- feature: multi-controlnet support
- feature: multi-controlnet support. pass in multiple `--control-mode`, `--control-image`, and `--control-image-raw` arguments.
- feature: "better" memory management. If GPU is full, least-recently-used model is moved to RAM.
- fix: hide the "triton" error messages
- feature: show full stack trace on error in cli
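
Usage sketch (not part of the diff): the new multi-controlnet support pairs each `--control-mode` with the `--control-image` or `--control-image-raw` passed alongside it; the Python equivalent uses a list of `ControlNetInput` objects via the `control_inputs` argument introduced below. File paths here are hypothetical placeholders.

    # hedged sketch: two controlnets on one prompt
    from imaginairy import ImaginePrompt, imagine_image_files
    from imaginairy.schema import ControlNetInput

    control_inputs = [
        # an ordinary photo, preprocessed by CONTROL_MODES["hed"]
        ControlNetInput(mode="hed", image="photos/room.jpg"),
        # an already-extracted depth map, used as-is (raw)
        ControlNetInput(mode="depth", image_raw="photos/room_depth.png"),
    ]
    prompt = ImaginePrompt(
        "a cozy cabin interior",
        control_inputs=control_inputs,
        width=512,
        height=512,
    )
    imagine_image_files([prompt], outdir="./outputs")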

@@ -242,13 +242,13 @@ def _generate_single_image(
if isinstance(prompt.mask_image, str) and prompt.mask_image.startswith("*prev"):
_, img_type = prompt.mask_image.strip("*").split(".")
prompt.mask_image = _most_recent_result.images[img_type]
control_modes = []
if prompt.control_inputs:
control_modes = [c.mode for c in prompt.control_inputs]
model = get_diffusion_model(
weights_location=prompt.model,
config_path=prompt.model_config_path,
control_weights_locations=[prompt.control_mode]
if prompt.control_mode
else None,
control_weights_locations=control_modes,
half_mode=half_mode,
for_inpainting=(prompt.mask_image or prompt.mask_prompt or prompt.outpaint)
and not suppress_inpaint,
@@ -398,46 +398,49 @@ def _generate_single_image(
elif is_controlnet_model:
from imaginairy.img_processors.control_modes import CONTROL_MODES
if prompt.control_image_raw is not None:
control_image = prompt.control_image_raw
elif prompt.control_image is not None:
control_image = prompt.control_image
control_image = control_image.convert("RGB")
log_img(control_image, "control_image_input")
control_image_input = pillow_fit_image_within(
control_image,
max_height=prompt.height,
max_width=prompt.width,
)
control_image_input_t = pillow_img_to_torch_image(control_image_input)
control_image_input_t = control_image_input_t.to(get_device())
if prompt.control_image_raw is None:
control_image_t = CONTROL_MODES[prompt.control_mode](
control_image_input_t
for control_input in prompt.control_inputs:
if control_input.image_raw is not None:
control_image = control_input.image_raw
elif control_input.image is not None:
control_image = control_input.image
control_image = control_image.convert("RGB")
log_img(control_image, "control_image_input")
control_image_input = pillow_fit_image_within(
control_image,
max_height=prompt.height,
max_width=prompt.width,
)
else:
control_image_t = (control_image_input_t + 1) / 2
control_image_input_t = pillow_img_to_torch_image(control_image_input)
control_image_input_t = control_image_input_t.to(get_device())
control_image_disp = control_image_t * 2 - 1
result_images["control"] = control_image_disp
log_img(control_image_disp, "control_image")
if control_input.image_raw is None:
control_image_t = CONTROL_MODES[control_input.mode](
control_image_input_t
)
else:
control_image_t = (control_image_input_t + 1) / 2
if len(control_image_t.shape) == 3:
raise RuntimeError("Control image must be 4D")
control_image_disp = control_image_t * 2 - 1
result_images[f"control-{control_input.mode}"] = control_image_disp
log_img(control_image_disp, "control_image")
if control_image_t.shape[1] != 3:
raise RuntimeError("Control image must have 3 channels")
if len(control_image_t.shape) == 3:
raise RuntimeError("Control image must be 4D")
if control_image_t.min() < 0 or control_image_t.max() > 1:
raise RuntimeError(
f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
)
if control_image_t.shape[1] != 3:
raise RuntimeError("Control image must have 3 channels")
if control_image_t.min() < 0 or control_image_t.max() > 1:
raise RuntimeError(
f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
)
if control_image_t.max() == control_image_t.min():
raise RuntimeError("No control signal found in control image.")
if control_image_t.max() == control_image_t.min():
raise RuntimeError(
f"No control signal found in control image {control_input.mode}."
)
c_cat.append(control_image_t)
c_cat.append(control_image_t)
elif hasattr(model, "masked_image_key"):
# inpainting model
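
For reference (not part of the diff): each control input in the loop above is reduced to the same tensor contract before being appended to `c_cat`. A minimal sketch of that contract, assuming `pillow_img_to_torch_image` returns a `(1, 3, H, W)` tensor in `[-1, 1]`:

    import torch

    # stand-in for pillow_img_to_torch_image(...) output
    control_image_input_t = torch.rand(1, 3, 512, 512) * 2 - 1

    # raw control images skip CONTROL_MODES[...] and are only rescaled to [0, 1]
    control_image_t = (control_image_input_t + 1) / 2

    assert control_image_t.ndim == 4                        # "must be 4D"
    assert control_image_t.shape[1] == 3                    # three channels
    assert control_image_t.min() >= 0 and control_image_t.max() <= 1
    assert control_image_t.max() != control_image_t.min()   # some control signal present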

@@ -85,7 +85,23 @@ class ColorShell(HelpColorsMixin, ModShell):
class ImagineColorsCommand(HelpColorsCommand):
_option_order = []
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.help_headers_color = "yellow"
self.help_options_color = "green"
def parse_args(self, ctx, args):
# run the parser for ourselves to preserve the passed order
parser = self.make_parser(ctx)
opts, _, param_order = parser.parse_args(args=list(args))
type(self)._option_order = []
for param in param_order:
# Type check
option = opts[param.name]
if isinstance(option, list):
type(self)._option_order.append((param, option.pop(0)))
# return "normal" parse results
return super().parse_args(ctx, args)

@@ -35,5 +35,5 @@ def edit_demo_cmd(image_paths, outdir, height, width):
configure_logging()
for image_path in image_paths:
create_surprise_me_images(
image_path, outdir=outdir, make_gif=True, width=width, height=height
image_path, outdir=outdir, make_gif=True, width=width, height=height, seed=1
)

@@ -20,7 +20,7 @@ from imaginairy.cli.shared import _imagine_cmd, add_options, common_options
"extracted from the control image. "
"Defaults to the `--init-image`"
),
multiple=False,
multiple=True,
)
@click.option(
"--control-image-raw",
@@ -30,7 +30,7 @@ from imaginairy.cli.shared import _imagine_cmd, add_options, common_options
" expects the already extracted signal. For example the raw control image would be a depth map or"
"pose information."
),
multiple=False,
multiple=True,
)
@click.option(
"--control-mode",
@@ -51,6 +51,7 @@ from imaginairy.cli.shared import _imagine_cmd, add_options, common_options
]
),
help="how the control image is used as signal",
multiple=True,
)
@click.pass_context
def imagine_cmd(
@@ -103,6 +104,42 @@ def imagine_cmd(
Can be invoked via either `aimg imagine` or just `imagine`.
"""
from imaginairy.schema import ControlNetInput, LazyLoadingImage
# hacky method of getting order of control images (mixing raw and normal images)
control_images = [
(o, path)
for o, path in ImagineColorsCommand._option_order # noqa
if o.name in ("control_image", "control_image_raw")
]
control_inputs = []
if control_mode:
for i, cm in enumerate(control_mode):
try:
option = control_images[i]
except IndexError:
option = None
if option is None:
control_image = None
control_image_raw = None
elif option[0].name == "control_image":
control_image = option[1]
control_image_raw = None
if control_image and control_image.startswith("http"):
control_image = LazyLoadingImage(url=control_image)
else:
control_image = None
control_image_raw = option[1]
if control_image_raw and control_image_raw.startswith("http"):
control_image_raw = LazyLoadingImage(url=control_image_raw)
control_inputs.append(
ControlNetInput(
image=control_image,
image_raw=control_image_raw,
mode=cm,
)
)
return _imagine_cmd(
ctx,
prompt_texts,
@@ -144,7 +181,5 @@ def imagine_cmd(
arg_schedules,
make_compilation_animation,
caption_text,
control_image,
control_image_raw,
control_mode,
control_inputs=control_inputs,
)
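
Sketch (not part of the diff) of the pairing logic above, with click internals replaced by plain tuples. Given a hypothetical invocation `--control-image a.jpg --control-mode depth --control-image-raw b.png --control-mode hed`, `_option_order` records the options in call order and each `--control-mode` is matched with the control-image option at the same position; a mode with no matching image falls back to the init image later, during prompt validation.

    # hypothetical, simplified capture of ImagineColorsCommand._option_order
    option_order = [
        ("control_image", "a.jpg"),
        ("control_mode", "depth"),
        ("control_image_raw", "b.png"),
        ("control_mode", "hed"),
    ]
    control_images = [(name, value) for name, value in option_order
                      if name in ("control_image", "control_image_raw")]
    control_modes = [value for name, value in option_order if name == "control_mode"]

    pairs = []
    for i, mode in enumerate(control_modes):
        image_option = control_images[i] if i < len(control_images) else None
        pairs.append((mode, image_option))
    # pairs == [("depth", ("control_image", "a.jpg")),
    #           ("hed", ("control_image_raw", "b.png"))]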

@@ -49,9 +49,7 @@ def _imagine_cmd(
arg_schedules=None,
make_compilation_animation=False,
caption_text="",
control_image=None,
control_image_raw=None,
control_mode="",
control_inputs=None,
):
"""Have the AI generate images. alias:imagine."""
@@ -102,12 +100,6 @@ def _imagine_cmd(
from imaginairy import ImaginePrompt, LazyLoadingImage, imagine_image_files
if control_image and control_image.startswith("http"):
control_image = LazyLoadingImage(url=control_image)
if control_image_raw and control_image_raw.startswith("http"):
control_image_raw = LazyLoadingImage(url=control_image_raw)
new_init_images = []
for _init_image in init_images:
if _init_image and _init_image.startswith("http"):
@@ -141,6 +133,7 @@ def _imagine_cmd(
_tile_mode = "y"
else:
_tile_mode = ""
for _init_image in init_images:
prompt = ImaginePrompt(
next(prompt_iterator),
@@ -148,9 +141,7 @@ def _imagine_cmd(
prompt_strength=prompt_strength,
init_image=_init_image,
init_image_strength=init_image_strength,
control_image=control_image,
control_image_raw=control_image_raw,
control_mode=control_mode,
control_inputs=control_inputs,
seed=seed,
sampler_type=sampler_type,
steps=steps,

@@ -2,18 +2,18 @@ from PIL import Image
from imaginairy import ImaginePrompt, imagine
from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.schema import ControlNetInput
def colorize_img(img):
caption = generate_caption(img)
caption = caption.replace("black and white", "color")
control_input = ControlNetInput(mode="hed", image=img)
prompt = ImaginePrompt(
prompt=caption,
init_image=img,
init_image_strength=0.01,
control_image=img,
control_mode="hed",
control_inputs=[control_input],
negative_prompt="black and white",
# width=img.width,
# height=img.height,
@@ -25,8 +25,7 @@ def colorize_img(img):
prompt=caption,
init_image=colorized_img,
init_image_strength=0.1,
control_image=img,
control_mode="hed",
control_inputs=[control_input],
negative_prompt="black and white",
width=min(img.width, 1024),
height=min(img.height, 1024),

@@ -37,7 +37,10 @@ def load_tensors(tensorfile, map_location=None):
return torch.load(tensorfile, map_location=map_location)
if tensorfile.endswith(".safetensors"):
return load_file(tensorfile, device=map_location)
raise ValueError(f"Unknown tensorfile type: {tensorfile}")
return load_file(tensorfile, device=map_location)
# raise ValueError(f"Unknown tensorfile type: {tensorfile}")
def load_state_dict(weights_location, half_mode=False, device=None):

@@ -72,6 +72,31 @@ class LazyLoadingImage:
return self._lazy_filepath or self._lazy_url
class ControlNetInput:
def __init__(self, *, mode, image=None, image_raw=None):
self.mode = mode
self.image = image
self.image_raw = image_raw
def validate(self, default_image=None):
if isinstance(self.image, str):
if not self.image.startswith("*prev."):
self.image = LazyLoadingImage(filepath=self.image)
if isinstance(self.image_raw, str):
if not self.image_raw.startswith("*prev."):
self.image_raw = LazyLoadingImage(filepath=self.image_raw)
if self.image is None and self.image_raw is None and default_image is not None:
self.image = default_image
if self.image is None and self.image_raw is None:
raise ValueError("You must specify either image or image_raw")
if self.image is not None and self.image_raw is not None:
raise ValueError("You cannot specify both image and image_raw")
class WeightedPrompt:
def __init__(self, text, weight=1):
self.text = text
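
A short sketch (not part of the diff) of the `ControlNetInput.validate` rules above: string paths are wrapped in `LazyLoadingImage` (except `*prev.` references), a missing image falls back to `default_image`, and `image`/`image_raw` are mutually exclusive.

    from PIL import Image
    from imaginairy.schema import ControlNetInput

    init_img = Image.new("RGB", (64, 64))

    c = ControlNetInput(mode="hed")         # no image given
    c.validate(default_image=init_img)      # falls back to the default (init) image
    assert c.image is init_img

    both = ControlNetInput(mode="hed", image=init_img, image_raw=init_img)
    try:
        both.validate()                     # image and image_raw together is an error
    except ValueError:
        pass
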
@@ -95,9 +120,7 @@ class ImaginePrompt:
prompt_strength=7.5,
init_image=None, # Pillow Image, LazyLoadingImage, or filepath str
init_image_strength=None,
control_image=None,
control_image_raw=None,
control_mode=None,
control_inputs=None,
mask_prompt=None,
mask_image=None,
mask_mode=MaskMode.REPLACE,
@@ -125,9 +148,7 @@ class ImaginePrompt:
self.prompt_strength = prompt_strength
self.init_image = init_image
self.init_image_strength = init_image_strength
self.control_image = control_image
self.control_image_raw = control_image_raw
self.control_mode = control_mode
self.control_inputs = control_inputs
self._orig_seed = seed
self.seed = seed
self.steps = steps
@@ -168,16 +189,6 @@ class ImaginePrompt:
self.tile_mode = self.tile_mode.lower()
assert self.tile_mode in ("", "x", "y", "xy")
if isinstance(self.control_image, str):
if not self.control_image.startswith("*prev."):
self.control_image = LazyLoadingImage(filepath=self.control_image)
if isinstance(self.control_image_raw, str):
if not self.control_image_raw.startswith("*prev."):
self.control_image_raw = LazyLoadingImage(
filepath=self.control_image_raw
)
if isinstance(self.init_image, str):
if not self.init_image.startswith("*prev."):
self.init_image = LazyLoadingImage(filepath=self.init_image)
@@ -186,23 +197,13 @@ class ImaginePrompt:
if not self.mask_image.startswith("*prev."):
self.mask_image = LazyLoadingImage(filepath=self.mask_image)
if self.control_image is not None and self.control_image_raw is not None:
raise ValueError(
"You can only set one of `control_image` and `control_image_raw`"
)
if self.control_image is not None and self.init_image is None:
self.init_image = self.control_image
if (
self.control_mode
and self.control_image is None
and self.init_image is not None
):
self.control_image = self.init_image
if self.control_inputs:
for control_input in self.control_inputs:
control_input.validate(default_image=self.init_image)
if self.control_mode and not (self.control_image or self.control_image_raw):
raise ValueError("You must set `control_image` when using `control_mode`")
if self.init_image is None:
if self.control_inputs[0].image:
self.init_image = self.control_inputs[0].image
if self.mask_image is not None and self.mask_prompt is not None:
raise ValueError("You can only set one of `mask_image` and `mask_prompt`")
@@ -211,7 +212,7 @@ class ImaginePrompt:
self.model = config.DEFAULT_MODEL
if self.init_image_strength is None:
if self.control_mode is not None:
if self.control_inputs:
self.init_image_strength = 0.0
elif self.outpaint or self.mask_image or self.mask_prompt:
self.init_image_strength = 0.0
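
Sketch (not part of the diff) of the resulting defaults, assuming `ImaginePrompt` construction stays lightweight (validation only, no model loading): with control inputs and no explicit init image, the first control image doubles as the init image and `init_image_strength` defaults to 0.0.

    from PIL import Image
    from imaginairy import ImaginePrompt
    from imaginairy.schema import ControlNetInput

    img = Image.new("RGB", (512, 512))
    prompt = ImaginePrompt(
        "a watercolor painting of a fox",
        control_inputs=[ControlNetInput(mode="hed", image=img)],
    )
    assert prompt.init_image is img             # borrowed from the first control input
    assert prompt.init_image_strength == 0.0    # controlnet prompts default to 0.0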

@@ -10,6 +10,7 @@ from imaginairy import ImaginePrompt, LazyLoadingImage, imagine_image_files
from imaginairy.animations import make_gif_animation
from imaginairy.enhancers.facecrop import detect_faces
from imaginairy.img_utils import add_caption_to_image, pillow_fit_image_within
from imaginairy.schema import ControlNetInput
preserve_head_kwargs = {
"mask_prompt": "head|face",
@@ -17,7 +18,7 @@ preserve_head_kwargs = {
}
generic_prompts = [
("add confetti", 7.5, {}),
("add confetti", 6, {}),
# ("add sparkles", 14, {}),
("make it christmas", 15, preserve_head_kwargs),
("make it halloween", 15, {}),
@@ -125,7 +126,9 @@ person_prompt_configs = [
]
def surprise_me_prompts(img, person=None, width=None, height=None, steps=30):
def surprise_me_prompts(
img, person=None, width=None, height=None, steps=30, seed=None, use_controlnet=False
):
if isinstance(img, str):
if img.startswith("http"):
img = LazyLoadingImage(url=img)
@@ -137,18 +140,36 @@ def surprise_me_prompts(img, person=None, width=None, height=None, steps=30):
prompts = []
for prompt_text, strength, kwargs in generic_prompts:
prompts.append(
ImaginePrompt(
prompt_text,
init_image=img,
prompt_strength=strength,
model="edit",
steps=steps,
width=width,
height=height,
**kwargs,
if use_controlnet:
control_input = ControlNetInput(
mode="edit",
)
prompts.append(
ImaginePrompt(
prompt_text,
init_image=img,
init_image_strength=0.05,
prompt_strength=strength,
control_inputs=[control_input],
steps=steps,
width=width,
height=height,
**kwargs,
)
)
else:
prompts.append(
ImaginePrompt(
prompt_text,
init_image=img,
prompt_strength=strength,
model="edit",
steps=steps,
width=width,
height=height,
**kwargs,
)
)
)
if person:
for prompt_subconfigs in person_prompt_configs:
@@ -156,24 +177,44 @@ def surprise_me_prompts(img, person=None, width=None, height=None, steps=30):
prompt_subconfigs = [prompt_subconfigs]
for prompt_subconfig in prompt_subconfigs:
prompt_text, strength, kwargs = prompt_subconfig
prompts.append(
ImaginePrompt(
prompt_text,
init_image=img,
prompt_strength=strength,
model="edit",
steps=steps,
width=width,
height=height,
**kwargs, # noqa
if use_controlnet:
control_input = ControlNetInput(
mode="edit",
)
prompts.append(
ImaginePrompt(
prompt_text,
init_image=img,
init_image_strength=0.05,
prompt_strength=strength,
control_inputs=[control_input],
steps=steps,
width=width,
height=height,
seed=seed,
**kwargs, # noqa
)
)
else:
prompts.append(
ImaginePrompt(
prompt_text,
init_image=img,
prompt_strength=strength,
model="edit",
steps=steps,
width=width,
height=height,
seed=seed,
**kwargs, # noqa
)
)
)
return prompts
def create_surprise_me_images(
img, outdir, person=None, make_gif=True, width=None, height=None
img, outdir, person=None, make_gif=True, width=None, height=None, seed=None
):
if isinstance(img, str):
if img.startswith("http"):
@@ -181,7 +222,9 @@ def create_surprise_me_images(
else:
img = LazyLoadingImage(filepath=img)
prompts = surprise_me_prompts(img, person=person, width=width, height=height)
prompts = surprise_me_prompts(
img, person=person, width=width, height=height, seed=seed
)
generated_filenames = imagine_image_files(
prompts,

@@ -6,7 +6,7 @@ from imaginairy import LazyLoadingImage
from imaginairy.api import imagine, imagine_image_files
from imaginairy.img_processors.control_modes import CONTROL_MODES
from imaginairy.img_utils import pillow_fit_image_within
from imaginairy.schema import ImaginePrompt
from imaginairy.schema import ControlNetInput, ImaginePrompt
from imaginairy.utils import get_device
from . import TESTS_FOLDER
@@ -193,12 +193,12 @@ def test_img_to_img_fruit_2_gold(
result = next(imagine(prompt))
threshold_lookup = {
"k_dpm_2_a": 31000,
"k_dpm_2_a": 32000,
"k_euler_a": 18000,
"k_dpm_adaptive": 13000,
"k_dpmpp_2s": 16000,
}
threshold = threshold_lookup.get(sampler_type, 14000)
threshold = threshold_lookup.get(sampler_type, 16000)
pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}__orig.jpg")
img_path = f"{filename_base_for_outputs}.png"
@@ -324,14 +324,18 @@ control_modes = list(CONTROL_MODES.keys())
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
def test_controlnet(filename_base_for_outputs, control_mode):
prompt_text = "a photo of a woman sitting on a bench"
control_input = ControlNetInput(
mode=control_mode,
image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png"),
)
prompt = ImaginePrompt(
prompt_text,
control_image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png"),
width=512,
height=512,
steps=15,
seed=0,
control_mode=control_mode,
control_inputs=[control_input],
fix_faces=True,
)
prompt.steps = 1
