feature: large refactor

- add type hints
- replace the separate width/height parameters with a single size parameter
- rename ControlNetInput => ControlInput
- simplify ImagineResult
pull/411/head^2
Authored by Bryce 6 months ago, committed by Bryce Drennan
parent db85f0898a
commit 2bd6cb264b
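At a glance, callers now construct prompts with the renamed fields. A minimal sketch (the prompt text is illustrative; the argument names are taken from the diff below):

```python
from imaginairy import ImaginePrompt
from imaginairy.schema import ControlInput  # previously ControlNetInput

prompt = ImaginePrompt(
    "a photo of a red barn",     # illustrative prompt text
    size="512x512",              # replaces width=/height=; named sizes like "4k" also accepted
    solver_type="dpmpp",         # replaces sampler_type
    model_weights="sd15",        # replaces model
    model_architecture="sd15",   # replaces model_config_path
)
```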

@@ -787,7 +787,7 @@ Use with `--model SD-2.1` or `--model SD-2.0-v`
 **6.1.0**
 - feature: use different default steps and image sizes depending on sampler and model selected
 - fix: #110 use proper version in image metadata
-- refactor: samplers all have their own class that inherits from ImageSampler
+- refactor: solvers all have their own class that inherits from ImageSolver
 - feature: 🎉🎉🎉 Stable Diffusion 2.0
 - `--model SD-2.0` to use (it makes worse images than 1.5 though...)
 - Tested on macOS and Linux

@@ -11,10 +11,11 @@
 - allow selection of output video format
 - chain multiple operations together imggen => videogen
 - make sure terminal output on windows doesn't suck
-- add karras schedule to refiners
+- add interface for loading diffusers weights
 - add method to show cache size
 - add method to clear model cache
 - add method to clear cached items not recently used (does diffusers have one?)
+- https://huggingface.co/playgroundai/playground-v2-1024px-aesthetic

 ### Old Todo

@@ -1,35 +1,37 @@
 import logging
 import os
 import re
+from typing import TYPE_CHECKING, Callable

-from imaginairy.schema import ControlNetInput, SafetyMode
+if TYPE_CHECKING:
+    from imaginairy.schema import ImaginePrompt

 logger = logging.getLogger(__name__)

 # leave undocumented. I'd ask that no one publicize this flag. Just want a
 # slight barrier to entry. Please don't use this is any way that's gonna cause
 # the media or politicians to freak out about AI...
-IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.STRICT)
+IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", "strict")
 if IMAGINAIRY_SAFETY_MODE in {"disabled", "classify"}:
-    IMAGINAIRY_SAFETY_MODE = SafetyMode.RELAXED
+    IMAGINAIRY_SAFETY_MODE = "relaxed"
 elif IMAGINAIRY_SAFETY_MODE == "filter":
-    IMAGINAIRY_SAFETY_MODE = SafetyMode.STRICT
+    IMAGINAIRY_SAFETY_MODE = "strict"

 # we put this in the global scope so it can be used in the interactive shell
 _most_recent_result = None


 def imagine_image_files(
-    prompts,
-    outdir,
-    precision="autocast",
-    record_step_images=False,
-    output_file_extension="jpg",
-    print_caption=False,
-    make_gif=False,
-    make_compare_gif=False,
-    return_filename_type="generated",
-    videogen=False,
+    prompts: "list[ImaginePrompt] | ImaginePrompt",
+    outdir: str,
+    precision: str = "autocast",
+    record_step_images: bool = False,
+    output_file_extension: str = "jpg",
+    print_caption: bool = False,
+    make_gif: bool = False,
+    make_compare_gif: bool = False,
+    return_filename_type: str = "generated",
+    videogen: bool = False,
 ):
     from PIL import ImageDraw
@@ -46,6 +48,9 @@ def imagine_image_files(
     if output_file_extension not in {"jpg", "png"}:
         raise ValueError("Must output a png or jpg")

+    if not isinstance(prompts, list):
+        prompts = [prompts]
+
     def _record_step(img, description, image_count, step_count, prompt):
         steps_path = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}")
         os.makedirs(steps_path, exist_ok=True)
@@ -74,7 +79,7 @@ def imagine_image_files(
         if prompt.init_image:
             img_str = f"_img2img-{prompt.init_image_strength}"
         basefilename = (
-            f"{base_count:06}_{prompt.seed}_{prompt.sampler_type.replace('_', '')}{prompt.steps}_"
+            f"{base_count:06}_{prompt.seed}_{prompt.solver_type.replace('_', '')}{prompt.steps}_"
             f"PS{prompt.prompt_strength}{img_str}_{prompt_normalized(prompt.prompt_text)}"
         )
@@ -139,15 +144,15 @@ def imagine_image_files(


 def imagine(
-    prompts,
-    precision="autocast",
-    debug_img_callback=None,
-    progress_img_callback=None,
-    progress_img_interval_steps=3,
+    prompts: "list[ImaginePrompt] | str | ImaginePrompt",
+    precision: str = "autocast",
+    debug_img_callback: Callable | None = None,
+    progress_img_callback: Callable | None = None,
+    progress_img_interval_steps: int = 3,
     progress_img_interval_min_s=0.1,
     half_mode=None,
-    add_caption=False,
-    unsafe_retry_count=1,
+    add_caption: bool = False,
+    unsafe_retry_count: int = 1,
 ):
     import torch.nn
@@ -209,7 +214,7 @@ def imagine(


 def _generate_single_image_compvis(
-    prompt,
+    prompt: "ImaginePrompt",
     debug_img_callback=None,
     progress_img_callback=None,
     progress_img_interval_steps=3,
@@ -248,9 +253,9 @@ def _generate_single_image_compvis(
     from imaginairy.modules.midas.api import torch_image_to_depth_map
     from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint
     from imaginairy.safety import create_safety_score
-    from imaginairy.samplers import SAMPLER_LOOKUP
+    from imaginairy.samplers import SOLVER_LOOKUP
     from imaginairy.samplers.editing import CFGEditingDenoiser
-    from imaginairy.schema import ImaginePrompt, ImagineResult
+    from imaginairy.schema import ControlInput, ImagineResult, MaskMode
     from imaginairy.utils import get_device, randn_seeded

     latent_channels = 4
@@ -326,8 +331,8 @@ def _generate_single_image_compvis(
         prompt.height // downsampling_factor,
         prompt.width // downsampling_factor,
     ]
-    SamplerCls = SAMPLER_LOOKUP[prompt.sampler_type.lower()]
-    sampler = SamplerCls(model)
+    SolverCls = SOLVER_LOOKUP[prompt.solver_type.lower()]
+    solver = SolverCls(model)
     mask_latent = mask_image = mask_image_orig = mask_grayscale = None
     t_enc = init_latent = control_image = None
     starting_image = None
@@ -385,7 +390,7 @@ def _generate_single_image_compvis(
             log_img(mask_image, "init mask")
-            if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
+            if prompt.mask_mode == MaskMode.REPLACE:
                 mask_image = ImageOps.invert(mask_image)
             mask_image_orig = mask_image
@@ -396,7 +401,7 @@ def _generate_single_image_compvis(
             if inpaint_method == "controlnet":
                 result_images["control-inpaint"] = mask_image
                 control_inputs.append(
-                    ControlNetInput(mode="inpaint", image=mask_image)
+                    ControlInput(mode="inpaint", image=mask_image)
                 )
     seed_everything(prompt.seed)
@@ -543,7 +548,7 @@ def _generate_single_image_compvis(
                 prompt=prompt,
                 target_height=init_image.height,
                 target_width=init_image.width,
-                cutoff=get_model_default_image_size(prompt.model),
+                cutoff=get_model_default_image_size(prompt.model_architecture),
             )
         else:
             comp_image = _generate_composition_image(
@@ -563,7 +568,7 @@ def _generate_single_image_compvis(
                 model.encode_first_stage(comp_image_t)
             )
     with lc.timing("sampling"):
-        samples = sampler.sample(
+        samples = solver.sample(
             num_steps=prompt.steps,
             positive_conditioning=positive_conditioning,
             neutral_conditioning=neutral_conditioning,
@@ -711,8 +716,10 @@ def _generate_composition_image(
     composition_prompt = prompt.full_copy(
         deep=True,
         update={
-            "width": int(prompt.width * shrink_scale),
-            "height": int(prompt.height * shrink_scale),
+            "size": (
+                int(prompt.width * shrink_scale),
+                int(prompt.height * shrink_scale),
+            ),
             "steps": None,
             "upscale": False,
             "fix_faces": False,

@@ -1,15 +1,16 @@
 import logging
 from typing import List, Optional

-from imaginairy import WeightedPrompt
-from imaginairy.config import CONTROLNET_CONFIG_SHORTCUTS
+from imaginairy import ImaginePrompt, WeightedPrompt
+from imaginairy.config import CONTROL_CONFIG_SHORTCUTS
 from imaginairy.model_manager import load_controlnet_adapter
+from imaginairy.schema import MaskMode

 logger = logging.getLogger(__name__)


 def _generate_single_image(
-    prompt,
+    prompt: ImaginePrompt,
     debug_img_callback=None,
     progress_img_callback=None,
     progress_img_interval_steps=3,
@@ -55,7 +56,7 @@ def _generate_single_image(
     )
     from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint
     from imaginairy.safety import create_safety_score
-    from imaginairy.samplers import SamplerName
+    from imaginairy.samplers import SolverName
     from imaginairy.schema import ImaginePrompt, ImagineResult
     from imaginairy.utils import get_device, randn_seeded
@@ -76,8 +77,8 @@ def _generate_single_image(
     control_modes = [c.mode for c in prompt.control_inputs]
     sd = get_diffusion_model_refiners(
-        weights_location=prompt.model,
-        config_path=prompt.model_config_path,
+        weights_location=prompt.model_weights,
+        model_architecture=prompt.model_architecture,
         control_weights_locations=tuple(control_modes),
         dtype=dtype,
         for_inpainting=for_inpainting and inpaint_method == "finetune",
@@ -90,7 +91,7 @@ def _generate_single_image(
     mask_image = None
     mask_image_orig = None
-    prompt = prompt.make_concrete_copy()
+    prompt: ImaginePrompt = prompt.make_concrete_copy()

     def latent_logger(latents):
         progress_latents.append(latents)
@@ -171,7 +172,7 @@ def _generate_single_image(
             log_img(mask_image, "init mask")
-            if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
+            if prompt.mask_mode == MaskMode.REPLACE:
                 mask_image = ImageOps.invert(mask_image)
             mask_image_orig = mask_image
@@ -182,7 +183,7 @@ def _generate_single_image(
             # if inpaint_method == "controlnet":
             #     result_images["control-inpaint"] = mask_image
             #     control_inputs.append(
-            #         ControlNetInput(mode="inpaint", image=mask_image)
+            #         ControlInput(mode="inpaint", image=mask_image)
             #     )
     seed_everything(prompt.seed)
@@ -194,7 +195,6 @@ def _generate_single_image(
     controlnets = []
     if control_modes:
-        control_strengths = []
         from imaginairy.img_processors.control_modes import CONTROL_MODES

         for control_input in control_inputs:
@@ -231,10 +231,10 @@ def _generate_single_image(
             log_img(control_image_disp, "control_image")
             if len(control_image_t.shape) == 3:
-                raise RuntimeError("Control image must be 4D")
+                raise ValueError("Control image must be 4D")
             if control_image_t.shape[1] != 3:
-                raise RuntimeError("Control image must have 3 channels")
+                raise ValueError("Control image must have 3 channels")
             if (
                 control_input.mode != "inpaint"
@@ -242,21 +242,20 @@ def _generate_single_image(
                 or control_image_t.max() > 1
             ):
                 msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
-                raise RuntimeError(msg)
+                raise ValueError(msg)
             if control_image_t.max() == control_image_t.min():
                 msg = f"No control signal found in control image {control_input.mode}."
-                raise RuntimeError(msg)
+                raise ValueError(msg)

-            control_strengths.append(control_input.strength)
-            control_weights_path = CONTROLNET_CONFIG_SHORTCUTS.get(
-                control_input.mode, None
-            ).weights_url
+            control_config = CONTROL_CONFIG_SHORTCUTS.get(control_input.mode, None)
+            if not control_config:
+                msg = f"Unknown control mode: {control_input.mode}"
+                raise ValueError(msg)

             controlnet = load_controlnet_adapter(
                 name=control_input.mode,
-                control_weights_location=control_weights_path,
+                control_weights_location=control_config.weights_location,
                 target_unet=sd.unet,
                 scale=control_input.strength,
             )
@@ -268,7 +267,7 @@ def _generate_single_image(
             prompt=prompt,
             target_height=init_image.height,
             target_width=init_image.width,
-            cutoff=get_model_default_image_size(prompt.model),
+            cutoff=get_model_default_image_size(prompt.model_architecture),
             dtype=dtype,
         )
     else:
@@ -276,7 +275,7 @@ def _generate_single_image(
             prompt=prompt,
             target_height=prompt.height,
             target_width=prompt.width,
-            cutoff=get_model_default_image_size(prompt.model),
+            cutoff=get_model_default_image_size(prompt.model_architecture),
             dtype=dtype,
         )
         if comp_image is not None:
@@ -296,12 +295,12 @@ def _generate_single_image(
                 control_image_t.to(device=sd.device, dtype=sd.dtype)
             )
             controlnet.inject()
-    if prompt.sampler_type.lower() == SamplerName.K_DPMPP_2M:
+    if prompt.solver_type.lower() == SolverName.DPMPP:
         sd.scheduler = DPMSolver(num_inference_steps=prompt.steps)
-    elif prompt.sampler_type.lower() == SamplerName.DDIM:
+    elif prompt.solver_type.lower() == SolverName.DDIM:
         sd.scheduler = DDIM(num_inference_steps=prompt.steps)
     else:
-        msg = f"Unknown sampler type: {prompt.sampler_type}"
+        msg = f"Unknown solver type: {prompt.solver_type}"
         raise ValueError(msg)
     sd.scheduler.to(device=sd.device, dtype=sd.dtype)
     sd.set_num_inference_steps(prompt.steps)
@@ -414,18 +413,24 @@ def _generate_single_image(
             caption_text = prompt.caption_text.format(prompt=prompt.prompt_text)
             add_caption_to_image(gen_img, caption_text)

+    # todo: do something smarter
+    result_images.update(
+        {
+            "upscaled": upscaled_img,
+            "modified_original": rebuilt_orig_img,
+            "mask_binary": mask_image_orig,
+            "mask_grayscale": mask_grayscale,
+        }
+    )
+
     result = ImagineResult(
         img=gen_img,
         prompt=prompt,
-        upscaled_img=upscaled_img,
         is_nsfw=safety_score.is_nsfw,
         safety_score=safety_score,
-        modified_original=rebuilt_orig_img,
-        mask_binary=mask_image_orig,
-        mask_grayscale=mask_grayscale,
         result_images=result_images,
-        timings={},
-        progress_latents=[],
+        timings={},  # todo
+        progress_latents=[],  # todo
     )

     _most_recent_result = result
@@ -441,6 +446,9 @@ def _generate_single_image(
 def _prompts_to_embeddings(prompts, text_encoder):
     import torch

+    if not prompts:
+        prompts = [WeightedPrompt(text="")]
+
     total_weight = sum(wp.weight for wp in prompts)
     if str(text_encoder.device) == "cpu":
         text_encoder = text_encoder.to(dtype=torch.float32)

@@ -31,7 +31,7 @@ remove_option(edit_options, "allow_compose_phase")
 @click.option(
     "--model-weights-path",
     "--model",
-    help=f"Model to use. Should be one of {', '.join(config.MODEL_SHORT_NAMES)}, or a path to custom weights.",
+    help=f"Model to use. Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.",
     show_default=True,
     default="SD-1.5",
 )
@@ -53,15 +53,13 @@ def edit_cmd(
     outdir,
     output_file_extension,
     repeats,
-    height,
-    width,
     size,
     steps,
     seed,
     upscale,
     fix_faces,
     fix_faces_fidelity,
-    sampler_type,
+    solver,
     log_level,
     quiet,
     show_work,
@@ -76,7 +74,7 @@ def edit_cmd(
     caption,
     precision,
     model_weights_path,
-    model_config_path,
+    model_architecture,
     prompt_library_path,
     version,
     make_gif,
@@ -95,11 +93,11 @@ def edit_cmd(
     Same as calling `aimg imagine --model edit --init-image my-dog.jpg --init-image-strength 1` except this command
     can batch edit images.
     """
-    from imaginairy.schema import ControlNetInput
+    from imaginairy.schema import ControlInput

     allow_compose_phase = False
     control_inputs = [
-        ControlNetInput(
+        ControlInput(
             image=None,
             image_raw=None,
             mode="edit",
@@ -116,15 +114,13 @@ def edit_cmd(
         outdir,
         output_file_extension,
         repeats,
-        height,
-        width,
         size,
         steps,
         seed,
         upscale,
         fix_faces,
         fix_faces_fidelity,
-        sampler_type,
+        solver,
         log_level,
         quiet,
         show_work,
@@ -140,7 +136,7 @@ def edit_cmd(
         caption,
         precision,
         model_weights_path,
-        model_config_path,
+        model_architecture,
         prompt_library_path,
         version,
         make_gif,

@@ -77,15 +77,13 @@ def imagine_cmd(
     outdir,
     output_file_extension,
     repeats,
-    height,
-    width,
     size,
     steps,
     seed,
     upscale,
     fix_faces,
     fix_faces_fidelity,
-    sampler_type,
+    solver,
     log_level,
     quiet,
     show_work,
@@ -101,7 +99,7 @@ def imagine_cmd(
     caption,
     precision,
     model_weights_path,
-    model_config_path,
+    model_architecture,
     prompt_library_path,
     version,
     make_gif,
@@ -120,7 +118,7 @@ def imagine_cmd(
     Can be invoked via either `aimg imagine` or just `imagine`.
     """
-    from imaginairy.schema import ControlNetInput, LazyLoadingImage
+    from imaginairy.schema import ControlInput, LazyLoadingImage

     # hacky method of getting order of control images (mixing raw and normal images)
     control_images = [
@@ -128,13 +126,18 @@ def imagine_cmd(
         for o, path in ImagineColorsCommand._option_order
         if o.name in ("control_image", "control_image_raw")
     ]
+    control_strengths = [
+        strength
+        for o, strength in ImagineColorsCommand._option_order
+        if o.name == "control_strength"
+    ]
     control_inputs = []
     if control_mode:
         for i, cm in enumerate(control_mode):
-            try:
-                option = control_images[i]
-            except IndexError:
-                option = None
+            option = index_default(control_images, i, None)
+            control_strength = index_default(control_strengths, i, 1.0)
             if option is None:
                 control_image = None
                 control_image_raw = None
@@ -149,10 +152,10 @@ def imagine_cmd(
             if control_image_raw and control_image_raw.startswith("http"):
                 control_image_raw = LazyLoadingImage(url=control_image_raw)
             control_inputs.append(
-                ControlNetInput(
+                ControlInput(
                     image=control_image,
                     image_raw=control_image_raw,
-                    strength=float(control_strength[i]),
+                    strength=float(control_strength),
                     mode=cm,
                 )
             )
@@ -167,15 +170,13 @@ def imagine_cmd(
         outdir,
         output_file_extension,
         repeats,
-        height,
-        width,
         size,
         steps,
         seed,
         upscale,
         fix_faces,
         fix_faces_fidelity,
-        sampler_type,
+        solver,
         log_level,
         quiet,
         show_work,
@@ -191,7 +192,7 @@ def imagine_cmd(
         caption,
         precision,
         model_weights_path,
-        model_config_path,
+        model_architecture,
         prompt_library_path,
         version,
         make_gif,
@@ -204,5 +205,12 @@ def imagine_cmd(
     )


+def index_default(items, index, default):
+    try:
+        return items[index]
+    except IndexError:
+        return default
+
+
 if __name__ == "__main__":
     imagine_cmd()
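index_default is a small helper introduced above to pair control modes with optional images and strengths; a usage sketch (the file name is made up):

```python
control_images = ["depth-map.png"]

index_default(control_images, 0, None)  # "depth-map.png"
index_default(control_images, 2, None)  # None instead of raising IndexError
index_default([], 0, 1.0)               # 1.0, the fallback control strength
```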

@@ -45,8 +45,6 @@ aimg.add_command(describe_cmd, name="describe")
 aimg.add_command(edit_cmd, name="edit")
 aimg.add_command(edit_demo_cmd, name="edit-demo")
 aimg.add_command(imagine_cmd, name="imagine")
-# aimg.add_command(prep_images_cmd, name="prep-images")
-# aimg.add_command(prune_ckpt_cmd, name="prune-ckpt")
 aimg.add_command(upscale_cmd, name="upscale")
 aimg.add_command(run_server_cmd, name="server")
 aimg.add_command(videogen_cmd, name="videogen")
@@ -85,14 +83,14 @@ def model_list_cmd():
     from imaginairy import config

     print(f"{'ALIAS': <10} {'NAME': <18} {'DESCRIPTION'}")
-    for model_config in config.MODEL_CONFIGS:
+    for model_config in config.MODEL_WEIGHT_CONFIGS:
         print(
             f"{model_config.alias: <10} {model_config.short_name: <18} {model_config.description}"
         )

     print("\nCONTROL MODES:")
     print(f"{'ALIAS': <10} {'NAME': <18} {'CONTROL TYPE'}")
-    for control_mode in config.CONTROLNET_CONFIGS:
+    for control_mode in config.CONTROL_CONFIGS:
         print(
             f"{control_mode.alias: <10} {control_mode.short_name: <18} {control_mode.control_type}"
         )

@@ -34,15 +34,13 @@ def _imagine_cmd(
     outdir,
     output_file_extension,
     repeats,
-    height,
-    width,
     size,
     steps,
     seed,
     upscale,
     fix_faces,
     fix_faces_fidelity,
-    sampler_type,
+    solver,
     log_level,
     quiet,
     show_work,
@@ -58,7 +56,7 @@ def _imagine_cmd(
     caption,
     precision,
     model_weights_path,
-    model_config_path,
+    model_architecture,
     prompt_library_path,
     version=False,
     make_gif=False,
@@ -96,15 +94,6 @@ def _imagine_cmd(
     configure_logging(log_level)

-    if (height is not None or width is not None) and size is not None:
-        msg = "You cannot specify both --size and --height/--width. Please choose one."
-        raise ValueError(msg)
-
-    if size is not None:
-        from imaginairy.utils.named_resolutions import get_named_resolution
-
-        width, height = get_named_resolution(size)
-
     init_images = [init_image] if isinstance(init_image, str) else init_image

     from imaginairy.utils import glob_expand_paths
@@ -171,10 +160,9 @@ def _imagine_cmd(
             init_image_strength=init_image_strength,
             control_inputs=control_inputs,
             seed=seed,
-            sampler_type=sampler_type,
+            solver_type=solver,
             steps=steps,
-            height=height,
-            width=width,
+            size=size,
             mask_image=mask_image,
             mask_prompt=mask_prompt,
             mask_mode=mask_mode,
@@ -185,8 +173,8 @@ def _imagine_cmd(
             fix_faces_fidelity=fix_faces_fidelity,
             tile_mode=_tile_mode,
             allow_compose_phase=allow_compose_phase,
-            model=model_weights_path,
-            model_config_path=model_config_path,
+            model_weights=model_weights_path,
+            model_architecture=model_architecture,
             caption_text=caption_text,
         )

     from imaginairy.prompt_schedules import (
@@ -318,28 +306,12 @@ common_options = [
         type=int,
         help="How many times to repeat the renders. If you provide two prompts and --repeat=3 then six images will be generated.",
     ),
-    click.option(
-        "-h",
-        "--height",
-        default=None,
-        show_default=True,
-        type=int,
-        help="Image height. Should be multiple of 8.",
-    ),
-    click.option(
-        "-w",
-        "--width",
-        default=None,
-        show_default=True,
-        type=int,
-        help="Image width. Should be multiple of 8.",
-    ),
     click.option(
         "--size",
         default=None,
         show_default=True,
         type=str,
-        help="Image size as a string. Can be a named size or WIDTHxHEIGHT format. Should be multiple of 8. Examples: 512x512, 4k, UHD, 8k, ",
+        help="Image size as a string. Can be a named size, WIDTHxHEIGHT, or single integer. Should be multiple of 8. Examples: 512x512, 4k, UHD, 8k, 512, 1080p",
     ),
     click.option(
         "--steps",
@@ -363,18 +335,18 @@ common_options = [
         help="How faithful to the original should face enhancement be. 1 = best fidelity, 0 = best looking face.",
     ),
     click.option(
-        "--sampler-type",
+        "--solver",
         "--sampler",
-        default=config.DEFAULT_SAMPLER,
+        default=config.DEFAULT_SOLVER,
         show_default=True,
-        type=click.Choice(config.SAMPLER_TYPE_OPTIONS),
-        help="What sampling strategy to use.",
+        type=click.Choice(config.SOLVER_TYPE_NAMES, case_sensitive=False),
+        help="Solver algorithm to generate the image with. (AKA 'Sampler' or 'Scheduler' in other libraries.",
     ),
     click.option(
         "--log-level",
         default="INFO",
         show_default=True,
-        type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"]),
+        type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"], case_sensitive=False),
         help="What level of logs to show.",
     ),
     click.option(
@@ -429,7 +401,7 @@ common_options = [
         "--mask-mode",
         default="replace",
         show_default=True,
-        type=click.Choice(["keep", "replace"]),
+        type=click.Choice(["keep", "replace"], case_sensitive=False),
         help="Should we replace the masked area or keep it?",
     ),
     click.option(
@@ -458,20 +430,20 @@ common_options = [
     click.option(
         "--precision",
         help="Evaluate at this precision.",
-        type=click.Choice(["full", "autocast"]),
+        type=click.Choice(["full", "autocast"], case_sensitive=False),
         default="autocast",
         show_default=True,
     ),
     click.option(
         "--model-weights-path",
         "--model",
-        help=f"Model to use. Should be one of {', '.join(config.MODEL_SHORT_NAMES)}, or a path to custom weights.",
+        help=f"Model to use. Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.",
         show_default=True,
-        default=config.DEFAULT_MODEL,
+        default=config.DEFAULT_MODEL_WEIGHTS,
     ),
     click.option(
-        "--model-config-path",
-        help="Model config file to use. If a model name is specified, the appropriate config will be used.",
+        "--model-architecture",
+        help="Model architecture. When specifying custom weights the model architecture must be specified. (sd15, sdxl, etc).",
        show_default=True,
         default=None,
     ),
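In practice the renamed options are used along the lines of `aimg imagine "a cozy cabin" --size 1080p --solver dpmpp`, and custom weights now need an explicit architecture, e.g. `--model ./my-weights.safetensors --model-architecture sd15` (illustrative invocations, not taken from the diff).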

@@ -44,9 +44,9 @@ logger = logging.getLogger(__name__)
     "--model-weights-path",
     "--model",
     "model",
-    help=f"Model to use. Should be one of {', '.join(config.MODEL_SHORT_NAMES)}, or a path to custom weights.",
+    help=f"Model to use. Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.",
     show_default=True,
-    default=config.DEFAULT_MODEL,
+    default=config.DEFAULT_MODEL_WEIGHTS,
 )
 @click.option(
     "--person",

@@ -4,7 +4,7 @@ from PIL import Image, ImageEnhance, ImageStat

 from imaginairy import ImaginePrompt, imagine
 from imaginairy.enhancers.describe_image_blip import generate_caption
-from imaginairy.schema import ControlNetInput
+from imaginairy.schema import ControlInput

 logger = logging.getLogger(__name__)
@@ -23,7 +23,7 @@ def colorize_img(img, max_width=1024, max_height=1024, caption=None):
         caption = caption.replace(" old ", " ")
         logger.info(caption)
     control_inputs = [
-        ControlNetInput(mode="colorize", image=img, strength=2),
+        ControlInput(mode="colorize", image=img, strength=2),
     ]
     prompt_add = ". color photo, sharp-focus, highly detailed, intricate, Canon 5D"
     prompt = ImaginePrompt(

@@ -1,7 +1,9 @@
 from dataclasses import dataclass
+from typing import Any, List

-DEFAULT_MODEL = "SD-1.5"
-DEFAULT_SAMPLER = "ddim"
+DEFAULT_MODEL_WEIGHTS = "sd15"
+DEFAULT_MODEL_ARCHITECTURE = "sd15"
+DEFAULT_SOLVER = "ddim"

 DEFAULT_NEGATIVE_PROMPT = (
     "Ugly, duplication, duplicates, mutilation, deformed, mutilated, mutation, twisted body, disfigured, bad anatomy, "
@@ -12,229 +14,306 @@ DEFAULT_NEGATIVE_PROMPT = (
     "grainy, blurred, blurry, writing, calligraphy, signature, text, watermark, bad art,"
 )

-SPLITMEM_ENABLED = False
+midas_url = "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt"


 @dataclass
-class ModelConfig:
-    description: str
-    short_name: str
-    config_path: str
-    weights_url: str
-    default_image_size: int
-    forced_attn_precision: str = "default"
-    default_negative_prompt: str = DEFAULT_NEGATIVE_PROMPT
-    alias: str = None
+class ModelArchitecture:
+    name: str
+    aliases: List[str]
+    output_modality: str
+    defaults: dict[str, Any]
+    config_path: str | None = None


-midas_url = "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt"
+@dataclass
+class ModelWeightsConfig:
+    name: str
+    aliases: List[str]
+    architecture: ModelArchitecture
+    defaults: dict[str, Any]
+    weights_location: str


-MODEL_CONFIGS = [
-    ModelConfig(
-        description="Stable Diffusion 1.5",
-        short_name="SD-1.5",
+MODEL_ARCHITECTURES = [
+    ModelArchitecture(
+        name="Stable Diffusion 1.5",
+        aliases=["sd15", "sd-15", "sd1.5", "sd-1.5"],
+        output_modality="image",
+        defaults={"size": "512"},
         config_path="configs/stable-diffusion-v1.yaml",
-        weights_url="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt",
-        default_image_size=512,
-        alias="sd15",
     ),
-    ModelConfig(
-        description="Stable Diffusion 1.5 - Inpainting",
-        short_name="SD-1.5-inpaint",
+    ModelArchitecture(
+        name="Stable Diffusion 1.5 - Inpainting",
+        aliases=[
+            "sd15inpaint",
+            "sd15-inpaint",
+            "sd-15-inpaint",
+            "sd1.5inpaint",
+            "sd1.5-inpaint",
+            "sd-1.5-inpaint",
+        ],
+        output_modality="image",
+        defaults={"size": "512"},
         config_path="configs/stable-diffusion-v1-inpaint.yaml",
-        weights_url="https://huggingface.co/julienacquaviva/inpainting/resolve/2155ff7fe38b55f4c0d99c2f1ab9b561f8311ca7/sd-v1-5-inpainting.ckpt",
-        default_image_size=512,
-        alias="sd15in",
     ),
-    # ModelConfig(
-    #     description="Instruct Pix2Pix - Photo Editing",
-    #     short_name="instruct-pix2pix",
-    #     config_path="configs/instruct-pix2pix.yaml",
-    #     weights_url="https://huggingface.co/imaginairy/instruct-pix2pix/resolve/ea0009b3d0d4888f410a40bd06d69516d0b5a577/instruct-pix2pix-00-22000-pruned.ckpt",
-    #     default_image_size=512,
-    #     default_negative_prompt="",
-    #     alias="edit",
-    # ),
-    ModelConfig(
-        description="OpenJourney V1",
-        short_name="openjourney-v1",
-        config_path="configs/stable-diffusion-v1.yaml",
-        weights_url="https://huggingface.co/prompthero/openjourney/resolve/7428477dad893424c92f6ea1cc29d45f6d1448c1/mdjrny-v4.safetensors",
-        default_image_size=512,
-        default_negative_prompt="",
-        alias="oj1",
-    ),
-    ModelConfig(
-        description="OpenJourney V2",
-        short_name="openjourney-v2",
-        config_path="configs/stable-diffusion-v1.yaml",
-        weights_url="https://huggingface.co/prompthero/openjourney-v2/resolve/47257274a40e93dab7fbc0cd2cfd5f5704cfeb60/openjourney-v2.ckpt",
-        default_image_size=512,
-        default_negative_prompt="",
-        alias="oj2",
-    ),
-    ModelConfig(
-        description="OpenJourney V4",
-        short_name="openjourney-v4",
-        config_path="configs/stable-diffusion-v1.yaml",
-        weights_url="https://huggingface.co/prompthero/openjourney/resolve/e291118e93d5423dc88ac1ed93c02362b17d698f/mdjrny-v4.safetensors",
-        default_image_size=512,
-        default_negative_prompt="",
-        alias="oj4",
+    ModelArchitecture(
+        name="Stable Diffusion XL",
+        aliases=["sdxl", "sd-xl"],
+        output_modality="image",
+        defaults={"size": "512"},
+    ),
+    ModelArchitecture(
+        name="Stable Video Diffusion",
+        aliases=["svd", "stablevideo"],
+        output_modality="video",
+        defaults={"size": "1024x576"},
+        config_path="configs/svd.yaml",
+    ),
+    ModelArchitecture(
+        name="Stable Video Diffusion - Image Decoder",
+        aliases=["svd-image-decoder", "svd-imdec"],
+        output_modality="video",
+        defaults={"size": "1024x576"},
+        config_path="configs/svd_image_decoder.yaml",
+    ),
+    ModelArchitecture(
+        name="Stable Video Diffusion - XT",
+        aliases=["svd-xt", "svd25f", "svd-25f", "stablevideoxt", "svdxt"],
+        output_modality="video",
+        defaults={"size": "1024x576"},
+        config_path="configs/svd_xt.yaml",
+    ),
+    ModelArchitecture(
+        name="Stable Video Diffusion - XT - Image Decoder",
+        aliases=[
+            "svd-xt-image-decoder",
+            "svd-xt-imdec",
+            "svd-25f-imdec",
+            "svdxt-imdec",
+            "svdxtimdec",
+            "svd25fimdec",
+            "svdxtimdec",
+        ],
+        output_modality="video",
+        defaults={"size": "1024x576"},
+        config_path="configs/svd_xt_image_decoder.yaml",
     ),
 ]

+MODEL_ARCHITECTURE_LOOKUP = {}
+for m in MODEL_ARCHITECTURES:
+    for a in m.aliases:
+        MODEL_ARCHITECTURE_LOOKUP[a] = m

-video_models = [
-    {
-        "short_name": "svd",
-        "description": "Stable Video Diffusion",
-        "default_frames": 14,
-        "default_steps": 25,
-        "config_path": "configs/svd.yaml",
-        "weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd.fp16.safetensors",
-    },
-    {
-        "short_name": "svd_image_decoder",
-        "description": "Stable Video Diffusion - Image Decoder",
-        "default_frames": 14,
-        "default_steps": 25,
-        "config_path": "configs/svd_image_decoder.yaml",
-        "weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_image_decoder.fp16.safetensors",
-    },
-    {
-        "short_name": "svd_xt",
-        "description": "Stable Video Diffusion - XT",
-        "default_frames": 25,
-        "default_steps": 30,
-        "config_path": "configs/svd_xt.yaml",
-        "weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt.fp16.safetensors",
-    },
-    {
-        "short_name": "svd_xt_image_decoder",
-        "description": "Stable Video Diffusion - XT - Image Decoder",
-        "default_frames": 25,
-        "default_steps": 30,
-        "config_path": "configs/svd_xt_image_decoder.yaml",
-        "weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt_image_decoder.fp16.safetensors",
-    },
+MODEL_WEIGHT_CONFIGS = [
+    ModelWeightsConfig(
+        name="Stable Diffusion 1.5",
+        aliases=MODEL_ARCHITECTURE_LOOKUP["sd15"].aliases,
+        architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
+        defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
+        weights_location="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt",
+    ),
+    ModelWeightsConfig(
+        name="Stable Diffusion 1.5 - Inpainting",
+        aliases=MODEL_ARCHITECTURE_LOOKUP["sd15inpaint"].aliases,
+        architecture=MODEL_ARCHITECTURE_LOOKUP["sd15inpaint"],
+        defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
+        weights_location="https://huggingface.co/julienacquaviva/inpainting/resolve/2155ff7fe38b55f4c0d99c2f1ab9b561f8311ca7/sd-v1-5-inpainting.ckpt",
+    ),
+    ModelWeightsConfig(
+        name="OpenJourney V1",
+        aliases=["openjourney-v1", "oj1", "ojv1", "openjourney1"],
+        architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
+        defaults={"negative_prompt": "poor quality"},
+        weights_location="https://huggingface.co/prompthero/openjourney/resolve/7428477dad893424c92f6ea1cc29d45f6d1448c1/mdjrny-v4.safetensors",
+    ),
+    ModelWeightsConfig(
+        name="OpenJourney V2",
+        aliases=["openjourney-v2", "oj2", "ojv2", "openjourney2"],
+        architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
+        weights_location="https://huggingface.co/prompthero/openjourney-v2/resolve/47257274a40e93dab7fbc0cd2cfd5f5704cfeb60/openjourney-v2.ckpt",
+        defaults={"negative_prompt": "poor quality"},
+    ),
+    ModelWeightsConfig(
+        name="OpenJourney V4",
+        aliases=["openjourney-v4", "oj4", "ojv4", "openjourney4", "openjourney", "oj"],
+        architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
+        weights_location="https://huggingface.co/prompthero/openjourney/resolve/e291118e93d5423dc88ac1ed93c02362b17d698f/mdjrny-v4.safetensors",
+        defaults={"negative_prompt": "poor quality"},
+    ),
+    # Video Weights
+    ModelWeightsConfig(
+        name="Stable Video Diffusion",
+        aliases=MODEL_ARCHITECTURE_LOOKUP["svd"].aliases,
+        architecture=MODEL_ARCHITECTURE_LOOKUP["svd"],
+        weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd.fp16.safetensors",
+        defaults={"frames": 14, "steps": 25},
+    ),
+    ModelWeightsConfig(
+        name="Stable Video Diffusion - Image Decoder",
+        aliases=MODEL_ARCHITECTURE_LOOKUP["svd-image-decoder"].aliases,
+        architecture=MODEL_ARCHITECTURE_LOOKUP["svd-image-decoder"],
+        weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_image_decoder.fp16.safetensors",
+        defaults={"frames": 14, "steps": 25},
+    ),
+    ModelWeightsConfig(
+        name="Stable Video Diffusion - XT",
+        aliases=MODEL_ARCHITECTURE_LOOKUP["svdxt"].aliases,
+        architecture=MODEL_ARCHITECTURE_LOOKUP["svdxt"],
+        weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt.fp16.safetensors",
+        defaults={"frames": 25, "steps": 30},
+    ),
+    ModelWeightsConfig(
+        name="Stable Video Diffusion - XT - Image Decoder",
+        aliases=MODEL_ARCHITECTURE_LOOKUP["svd-xt-image-decoder"].aliases,
+        architecture=MODEL_ARCHITECTURE_LOOKUP["svd-xt-image-decoder"],
+        weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt_image_decoder.fp16.safetensors",
+        defaults={"frames": 25, "steps": 30},
+    ),
 ]
-video_models = {m["short_name"]: m for m in video_models}

-MODEL_CONFIG_SHORTCUTS = {m.short_name: m for m in MODEL_CONFIGS}
-for m in MODEL_CONFIGS:
-    if m.alias:
-        MODEL_CONFIG_SHORTCUTS[m.alias] = m
+MODEL_WEIGHT_CONFIG_LOOKUP = {}
+for m in MODEL_WEIGHT_CONFIGS:
+    for a in m.aliases:
+        MODEL_WEIGHT_CONFIG_LOOKUP[a] = m

-MODEL_CONFIG_SHORTCUTS["openjourney"] = MODEL_CONFIG_SHORTCUTS["openjourney-v2"]
-MODEL_CONFIG_SHORTCUTS["oj"] = MODEL_CONFIG_SHORTCUTS["openjourney-v2"]
-
-MODEL_SHORT_NAMES = sorted(MODEL_CONFIG_SHORTCUTS.keys())
+IMAGE_WEIGHTS_SHORT_NAMES = [
+    k
+    for k, mw in MODEL_WEIGHT_CONFIG_LOOKUP.items()
+    if mw.architecture.output_modality == "image"
+]
+IMAGE_WEIGHTS_SHORT_NAMES.sort()


 @dataclass
-class ControlNetConfig:
-    short_name: str
+class ControlConfig:
+    name: str
+    aliases: List[str]
     control_type: str
     config_path: str
-    weights_url: str
-    alias: str = None
+    weights_location: str


-CONTROLNET_CONFIGS = [
-    ControlNetConfig(
-        short_name="canny15",
+CONTROL_CONFIGS = [
+    ControlConfig(
+        name="Canny Edge Control",
+        aliases=["canny", "canny15"],
         control_type="canny",
         config_path="configs/control-net-v15.yaml",
-        weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/115a470d547982438f70198e353a921996e2e819/diffusion_pytorch_model.fp16.safetensors",
-        alias="canny",
+        weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/115a470d547982438f70198e353a921996e2e819/diffusion_pytorch_model.fp16.safetensors",
     ),
-    ControlNetConfig(
-        short_name="depth15",
+    ControlConfig(
+        name="Depth Control",
+        aliases=["depth", "depth15"],
         control_type="depth",
         config_path="configs/control-net-v15.yaml",
-        weights_url="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/539f99181d33db39cf1af2e517cd8056785f0a87/diffusion_pytorch_model.fp16.safetensors",
-        alias="depth",
+        weights_location="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/539f99181d33db39cf1af2e517cd8056785f0a87/diffusion_pytorch_model.fp16.safetensors",
     ),
-    ControlNetConfig(
-        short_name="normal15",
+    ControlConfig(
+        name="Normal Map Control",
+        aliases=["normal", "normal15"],
         control_type="normal",
         config_path="configs/control-net-v15.yaml",
-        weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/cb7296e6587a219068e9d65864e38729cd862aa8/diffusion_pytorch_model.fp16.safetensors",
-        alias="normal",
+        weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/cb7296e6587a219068e9d65864e38729cd862aa8/diffusion_pytorch_model.fp16.safetensors",
     ),
-    ControlNetConfig(
-        short_name="hed15",
+    ControlConfig(
+        name="Soft Edge Control (HED)",
+        aliases=["hed", "hed15"],
         control_type="hed",
         config_path="configs/control-net-v15.yaml",
-        weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/b5bcad0c48e9b12f091968cf5eadbb89402d6bc9/diffusion_pytorch_model.fp16.safetensors",
-        alias="hed",
+        weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/b5bcad0c48e9b12f091968cf5eadbb89402d6bc9/diffusion_pytorch_model.fp16.safetensors",
     ),
-    ControlNetConfig(
-        short_name="openpose15",
+    ControlConfig(
+        name="Pose Control",
         control_type="openpose",
         config_path="configs/control-net-v15.yaml",
-        weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/9ae9f970358db89e211b87c915f9535c6686d5ba/diffusion_pytorch_model.fp16.safetensors",
-        alias="openpose",
+        weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/9ae9f970358db89e211b87c915f9535c6686d5ba/diffusion_pytorch_model.fp16.safetensors",
+        aliases=["openpose", "pose", "pose15", "openpose15"],
     ),
-    ControlNetConfig(
-        short_name="shuffle15",
+    ControlConfig(
+        name="Shuffle Control",
         control_type="shuffle",
         config_path="configs/control-net-v15-pool.yaml",
-        weights_url="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/8cf275970f984acf5cc0fdfa537db8be098936a3/diffusion_pytorch_model.fp16.safetensors",
-        alias="shuffle",
+        weights_location="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/8cf275970f984acf5cc0fdfa537db8be098936a3/diffusion_pytorch_model.fp16.safetensors",
+        aliases=["shuffle", "shuffle15"],
     ),
     # "instruct pix2pix"
-    ControlNetConfig(
-        short_name="edit15",
+    ControlConfig(
+        name="Edit Prompt Control",
+        aliases=["edit", "edit15"],
         control_type="edit",
         config_path="configs/control-net-v15.yaml",
-        weights_url="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/1fed6ebb905c61929a60514830eb05b039969d6d/diffusion_pytorch_model.fp16.safetensors",
-        alias="edit",
+        weights_location="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/1fed6ebb905c61929a60514830eb05b039969d6d/diffusion_pytorch_model.fp16.safetensors",
     ),
-    ControlNetConfig(
-        short_name="inpaint15",
+    ControlConfig(
+        name="Inpaint Control",
+        aliases=["inpaint", "inpaint15"],
         control_type="inpaint",
         config_path="configs/control-net-v15.yaml",
-        weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/c96e03a807e64135568ba8aecb66b3a306ec73bd/diffusion_pytorch_model.fp16.safetensors",
-        alias="inpaint",
+        weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/c96e03a807e64135568ba8aecb66b3a306ec73bd/diffusion_pytorch_model.fp16.safetensors",
     ),
-    ControlNetConfig(
-        short_name="details15",
+    ControlConfig(
+        name="Details Control (Upscale Tile)",
+        aliases=["details", "details15"],
         control_type="details",
         config_path="configs/control-net-v15.yaml",
-        weights_url="https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/3f877705c37010b7221c3d10743307d6b5b6efac/diffusion_pytorch_model.bin",
-        alias="details",
+        weights_location="https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/3f877705c37010b7221c3d10743307d6b5b6efac/diffusion_pytorch_model.bin",
     ),
-    ControlNetConfig(
-        short_name="colorize15",
+    ControlConfig(
+        name="Brightness Control (Colorize)",
+        aliases=["colorize", "colorize15"],
         control_type="colorize",
         config_path="configs/control-net-v15.yaml",
-        weights_url="https://huggingface.co/ioclab/control_v1p_sd15_brightness/resolve/8509361eb1ba89c03839040ed8c75e5f11bbd9c5/diffusion_pytorch_model.safetensors",
-        alias="colorize",
+        weights_location="https://huggingface.co/ioclab/control_v1p_sd15_brightness/resolve/8509361eb1ba89c03839040ed8c75e5f11bbd9c5/diffusion_pytorch_model.safetensors",
     ),
 ]

-CONTROLNET_CONFIG_SHORTCUTS = {}
-for m in CONTROLNET_CONFIGS:
-    if m.alias:
-        CONTROLNET_CONFIG_SHORTCUTS[m.alias] = m
-
-for m in CONTROLNET_CONFIGS:
-    CONTROLNET_CONFIG_SHORTCUTS[m.short_name] = m
-
-SAMPLER_TYPE_OPTIONS = [
-    # "plms",
-    "ddim",
-    "k_dpmpp_2m"
-    # "k_dpm_fast",
-    # "k_dpm_adaptive",
-    # "k_lms",
-    # "k_dpm_2",
-    # "k_dpm_2_a",
-    # "k_dpmpp_2m",
-    # "k_dpmpp_2s_a",
-    # "k_euler",
-    # "k_euler_a",
-    # "k_heun",
+CONTROL_CONFIG_SHORTCUTS: dict[str, ControlConfig] = {}
+for m in CONTROL_CONFIGS:
+    for a in m.aliases:
+        CONTROL_CONFIG_SHORTCUTS[a] = m
+
+
+@dataclass
+class SolverConfig:
+    name: str
+    short_name: str
+    aliases: List[str]
+    papers: List[str]
+    implementations: List[str]
+
+
+SOLVER_CONFIGS = [
+    SolverConfig(
+        name="DDIM",
+        short_name="DDIM",
+        aliases=["ddim"],
+        papers=["https://arxiv.org/abs/2010.02502"],
+        implementations=[
+            "https://github.com/ermongroup/ddim",
+            "https://github.com/Stability-AI/stablediffusion/blob/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/models/diffusion/ddim.py#L10",
+            "https://github.com/huggingface/diffusers/blob/76c645d3a641c879384afcb43496f0b7db8cc5cb/src/diffusers/schedulers/scheduling_ddim.py#L131",
+        ],
+    ),
+    SolverConfig(
+        name="DPM-Solver++",
+        short_name="DPMPP",
+        aliases=["dpmpp", "dpm++", "dpmsolver"],
+        papers=["https://arxiv.org/abs/2211.01095"],
+        implementations=[
+            "https://github.com/LuChengTHU/dpm-solver/blob/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/dpm_solver_pytorch.py#L337",
+            "https://github.com/apple/ml-stable-diffusion/blob/7449ce46a4b23c94413b714704202e4ea4c55080/swift/StableDiffusion/pipeline/DPMSolverMultistepScheduler.swift#L27",
+            "https://github.com/crowsonkb/k-diffusion/blob/045515774882014cc14c1ba2668ab5bad9cbf7c0/k_diffusion/sampling.py#L509",
+        ],
+    ),
 ]
+
+SOLVER_TYPE_NAMES = [s.aliases[0] for s in SOLVER_CONFIGS]
+
+SOLVER_LOOKUP = {}
+for s in SOLVER_CONFIGS:
+    for a in s.aliases:
+        SOLVER_LOOKUP[a.lower()] = s
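A short sketch of how the new lookup tables resolve aliases (values follow directly from the config above):

```python
from imaginairy import config

config.MODEL_WEIGHT_CONFIG_LOOKUP["oj4"].name                            # "OpenJourney V4"
config.MODEL_WEIGHT_CONFIG_LOOKUP["sd15"].architecture.output_modality   # "image"
config.CONTROL_CONFIG_SHORTCUTS["canny"].control_type                    # "canny"
config.SOLVER_LOOKUP["dpm++"].short_name                                 # "DPMPP"
config.SOLVER_TYPE_NAMES                                                 # ["ddim", "dpmpp"]
```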

@@ -26,7 +26,7 @@ class StableStudioStyle(BaseModel):
     image: Optional[HttpUrl] = None


-class StableStudioSampler(BaseModel):
+class StableStudioSolver(BaseModel):
     id: str
     name: Optional[str] = None
@@ -55,7 +55,7 @@ class StableStudioInput(BaseModel, extra=Extra.forbid):
     style: Optional[str] = None
     width: Optional[int] = None
     height: Optional[int] = None
-    sampler: Optional[StableStudioSampler] = None
+    solver: Optional[StableStudioSolver] = None
     cfg_scale: Optional[float] = Field(None, alias="cfgScale")
     steps: Optional[int] = None
     seed: Optional[int] = None
@@ -88,14 +88,14 @@ class StableStudioInput(BaseModel, extra=Extra.forbid):
         mask_image = self.mask_image.blob if self.mask_image else None

-        sampler_type = self.sampler.id if self.sampler else None
+        solver_type = self.solver.id if self.solver else None

         return ImaginePrompt(
             prompt=positive_prompt,
             prompt_strength=self.cfg_scale,
             negative_prompt=negative_prompt,
             model=self.model,
-            sampler_type=sampler_type,
+            solver_type=solver_type,
             seed=self.seed,
             steps=self.steps,
             height=self.height,

@@ -8,7 +8,7 @@ from imaginairy.http_app.stablestudio.models import (
     StableStudioBatchResponse,
     StableStudioImage,
     StableStudioModel,
-    StableStudioSampler,
+    StableStudioSolver,
 )
 from imaginairy.http_app.utils import generate_image_b64

@@ -37,11 +37,14 @@ async def generate(studio_request: StableStudioBatchRequest):
 @router.get("/samplers")
 async def list_samplers():
-    from imaginairy.config import SAMPLER_TYPE_OPTIONS
+    from imaginairy.config import SOLVER_CONFIGS

     sampler_objs = []
-    for sampler_type in SAMPLER_TYPE_OPTIONS:
-        sampler_obj = StableStudioSampler(id=sampler_type, name=sampler_type)
+
+    for solver_config in SOLVER_CONFIGS:
+        sampler_obj = StableStudioSolver(
+            id=solver_config.aliases[0], name=solver_config.aliases[0]
+        )
         sampler_objs.append(sampler_obj)

     return sampler_objs
@@ -49,10 +52,10 @@ async def list_samplers():

 @router.get("/models")
 async def list_models():
-    from imaginairy.config import MODEL_CONFIGS
+    from imaginairy.config import MODEL_WEIGHT_CONFIGS

     model_objs = []
-    for model_config in MODEL_CONFIGS:
+    for model_config in MODEL_WEIGHT_CONFIGS:
         if "inpaint" in model_config.description.lower():
             continue
         model_obj = StableStudioModel(
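With this change the `/samplers` route is driven by SOLVER_CONFIGS, so it should return one entry per solver config, e.g. `{"id": "ddim", "name": "ddim"}` and `{"id": "dpmpp", "name": "dpmpp"}`, rather than one per hard-coded sampler string.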

@@ -17,7 +17,7 @@ from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter, SD1UNe
 from safetensors.torch import load_file

 from imaginairy import config as iconfig
-from imaginairy.config import MODEL_SHORT_NAMES
+from imaginairy.config import IMAGE_WEIGHTS_SHORT_NAMES, ModelArchitecture
 from imaginairy.modules import attention
 from imaginairy.paths import PKG_ROOT
 from imaginairy.utils import get_device, instantiate_from_config
@@ -66,7 +66,7 @@ def load_state_dict(weights_location, half_mode=False, device=None):
     except FileNotFoundError as e:
         if e.errno == 2:
             logger.error(
-                f'Error: "{ckpt_path}" not a valid path to model weights.\nPreconfigured models you can use: {MODEL_SHORT_NAMES}.'
+                f'Error: "{ckpt_path}" not a valid path to model weights.\nPreconfigured models you can use: {IMAGE_WEIGHTS_SHORT_NAMES}.'
             )
             sys.exit(1)
         raise
@@ -149,7 +149,7 @@ def add_controlnet(base_state_dict, controlnet_state_dict):

 def get_diffusion_model(
-    weights_location=iconfig.DEFAULT_MODEL,
+    weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
     config_path="configs/stable-diffusion-v1.yaml",
     control_weights_locations=None,
     half_mode=None,
@@ -174,7 +174,7 @@ def get_diffusion_model(
             f"Failed to load inpainting model. Attempting to fall-back to standard model. {e!s}"
         )
         return _get_diffusion_model(
-            iconfig.DEFAULT_MODEL,
+            iconfig.DEFAULT_MODEL_WEIGHTS,
             config_path,
             half_mode,
             for_inpainting=False,
@@ -184,8 +184,8 @@ def get_diffusion_model(

 def _get_diffusion_model(
-    weights_location=iconfig.DEFAULT_MODEL,
-    config_path="configs/stable-diffusion-v1.yaml",
+    weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
+    model_architecture="configs/stable-diffusion-v1.yaml",
     half_mode=None,
     for_inpainting=False,
     control_weights_locations=None,
@@ -197,24 +197,20 @@ def _get_diffusion_model(
     """
     global MOST_RECENTLY_LOADED_MODEL

-    (
-        model_config,
-        weights_location,
-        config_path,
-        control_weights_locations,
-    ) = resolve_model_paths(
-        weights_path=weights_location,
-        config_path=config_path,
-        control_weights_paths=control_weights_locations,
+    model_weights_config = resolve_model_weights_config(
+        model_weights=weights_location,
+        default_model_architecture=model_architecture,
         for_inpainting=for_inpainting,
     )
     # some models need the attention calculated in float32
-    if model_config is not None:
-        attention.ATTENTION_PRECISION_OVERRIDE = model_config.forced_attn_precision
+    if model_weights_config is not None:
+        attention.ATTENTION_PRECISION_OVERRIDE = (
+            model_weights_config.forced_attn_precision
+        )
     else:
         attention.ATTENTION_PRECISION_OVERRIDE = "default"

     diffusion_model = _load_diffusion_model(
-        config_path=config_path,
+        config_path=model_weights_config.architecture.config_path,
         weights_location=weights_location,
         half_mode=half_mode,
     )
@@ -229,8 +225,8 @@ def _get_diffusion_model(

 def get_diffusion_model_refiners(
-    weights_location=iconfig.DEFAULT_MODEL,
-    config_path="configs/stable-diffusion-v1.yaml",
+    weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
+    model_architecture=None,
     control_weights_locations=None,
     dtype=None,
     for_inpainting=False,
@@ -243,8 +239,8 @@ def get_diffusion_model_refiners(
     try:
         return _get_diffusion_model_refiners(
             weights_location,
-            config_path,
-            for_inpainting,
+            model_architecture=model_architecture,
+            for_inpainting=for_inpainting,
             dtype=dtype,
             control_weights_locations=control_weights_locations,
         )
@@ -254,8 +250,8 @@ def get_diffusion_model_refiners(
             f"Failed to load inpainting model. Attempting to fall-back to standard model. {e!s}"
         )
         return _get_diffusion_model_refiners(
-            iconfig.DEFAULT_MODEL,
-            config_path,
+            iconfig.DEFAULT_MODEL_WEIGHTS,
+            model_architecture=model_architecture,
             dtype=dtype,
             for_inpainting=False,
             control_weights_locations=control_weights_locations,
@@ -264,8 +260,8 @@ def get_diffusion_model_refiners(

 def _get_diffusion_model_refiners(
-    weights_location=iconfig.DEFAULT_MODEL,
-    config_path="configs/stable-diffusion-v1.yaml",
+    weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
+    model_architecture=None,
     for_inpainting=False,
     control_weights_locations=None,
     device=None,
@@ -279,7 +275,7 @@ def _get_diffusion_model_refiners(
     sd = _get_diffusion_model_refiners_only(
         weights_location=weights_location,
-        config_path=config_path,
+        model_architecture=model_architecture,
         for_inpainting=for_inpainting,
         device=device,
dtype=dtype, dtype=dtype,
@ -290,10 +286,9 @@ def _get_diffusion_model_refiners(
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def _get_diffusion_model_refiners_only( def _get_diffusion_model_refiners_only(
weights_location=iconfig.DEFAULT_MODEL, weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
config_path="configs/stable-diffusion-v1.yaml", model_architecture=None,
for_inpainting=False, for_inpainting=False,
control_weights_locations=None,
device=None, device=None,
dtype=torch.float16, dtype=torch.float16,
): ):
@ -312,28 +307,17 @@ def _get_diffusion_model_refiners_only(
device = device or get_device() device = device or get_device()
( model_weights_config = resolve_model_weights_config(
model_config, model_weights=weights_location,
weights_location, default_model_architecture=model_architecture,
config_path,
control_weights_locations,
) = resolve_model_paths(
weights_path=weights_location,
config_path=config_path,
control_weights_paths=control_weights_locations,
for_inpainting=for_inpainting, for_inpainting=for_inpainting,
) )
# some models need the attention calculated in float32
if model_config is not None:
attention.ATTENTION_PRECISION_OVERRIDE = model_config.forced_attn_precision
else:
attention.ATTENTION_PRECISION_OVERRIDE = "default"
( (
vae_weights, vae_weights,
unet_weights, unet_weights,
text_encoder_weights, text_encoder_weights,
) = load_stable_diffusion_compvis_weights(weights_location) ) = load_stable_diffusion_compvis_weights(model_weights_config.weights_location)
if for_inpainting: if for_inpainting:
unet = SD1UNet(in_channels=9) unet = SD1UNet(in_channels=9)
@ -422,58 +406,69 @@ def load_controlnet(control_weights_location, half_mode):
return controlnet return controlnet
def resolve_model_paths( def resolve_model_weights_config(
weights_path=iconfig.DEFAULT_MODEL, model_weights: str,
config_path=None, default_model_architecture: str | None = None,
control_weights_paths=None, for_inpainting: bool = False,
for_inpainting=False, ) -> iconfig.ModelWeightsConfig:
):
"""Resolve weight and config path if they happen to be shortcuts.""" """Resolve weight and config path if they happen to be shortcuts."""
model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get(weights_path, None)
model_metadata_c = iconfig.MODEL_CONFIG_SHORTCUTS.get(config_path, None) if for_inpainting:
model_weights_config = iconfig.MODEL_WEIGHT_CONFIG_LOOKUP.get(
control_weights_paths = control_weights_paths or [] f"{model_weights.lower()}-inpaint", None
control_net_metadatas = [ )
iconfig.CONTROLNET_CONFIG_SHORTCUTS.get(control_weights_path, None) if model_weights_config:
for control_weights_path in control_weights_paths return model_weights_config
]
model_weights_config = iconfig.MODEL_WEIGHT_CONFIG_LOOKUP.get(
if not control_net_metadatas and for_inpainting: model_weights.lower(), None
model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get( )
f"{weights_path}-inpaint", model_metadata_w if model_weights_config:
return model_weights_config
if not default_model_architecture:
msg = "You must specify the model architecture when loading custom weights."
raise ValueError(msg)
default_model_architecture = default_model_architecture.lower()
model_architecture_config = None
if for_inpainting:
model_architecture_config = iconfig.MODEL_ARCHITECTURE_LOOKUP.get(
f"{default_model_architecture}-inpaint", None
) )
model_metadata_c = iconfig.MODEL_CONFIG_SHORTCUTS.get(
f"{config_path}-inpaint", model_metadata_c if not model_architecture_config:
model_architecture_config = iconfig.MODEL_ARCHITECTURE_LOOKUP.get(
default_model_architecture, None
)
if model_architecture_config is None:
msg = f"Invalid model architecture: {default_model_architecture}"
raise ValueError(msg)
model_weights_config = iconfig.ModelWeightsConfig(
name="Custom Loaded",
aliases=[],
architecture=model_architecture_config,
weights_location=model_weights,
defaults={},
)
return model_weights_config
def get_model_default_image_size(model_architecture: str | ModelArchitecture):
if isinstance(model_architecture, str):
model_architecture = iconfig.MODEL_WEIGHT_CONFIG_LOOKUP.get(
model_architecture, None
) )
default_size = None
if model_architecture:
default_size = model_architecture.defaults.get("size")
if model_metadata_w: if default_size is None:
if config_path is None: default_size = 512
config_path = model_metadata_w.config_path return default_size
weights_path = model_metadata_w.weights_url
if model_metadata_c:
config_path = model_metadata_c.config_path
if config_path is None:
config_path = iconfig.MODEL_CONFIG_SHORTCUTS[iconfig.DEFAULT_MODEL].config_path
if control_net_metadatas:
if "stable-diffusion-v1" not in config_path:
msg = "Control net is only supported for stable diffusion v1. Please use a different model."
raise ValueError(msg)
control_weights_paths = [cnm.weights_url for cnm in control_net_metadatas]
config_path = control_net_metadatas[0].config_path
model_metadata = model_metadata_w or model_metadata_c
logger.debug(f"Loading model weights from: {weights_path}")
logger.debug(f"Loading model config from: {config_path}")
return model_metadata, weights_path, config_path, control_weights_paths
def get_model_default_image_size(weights_location):
model_config = iconfig.MODEL_CONFIG_SHORTCUTS.get(weights_location, None)
if model_config:
return model_config.default_image_size
return 512
def get_current_diffusion_model(): def get_current_diffusion_model():

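The old `resolve_model_paths` four-tuple is replaced by a single `ModelWeightsConfig`. A hedged usage sketch, assuming this branch of imaginairy is importable; the custom weights path and architecture alias are hypothetical:

```python
# Sketch only: assumes the refactored imaginairy from this branch is installed.
from imaginairy.model_manager import resolve_model_weights_config

# Known alias: returns the preconfigured ModelWeightsConfig for those weights.
weights_config = resolve_model_weights_config(model_weights="openjourney-v1")
print(weights_config.name, weights_config.weights_location)

# Custom weights require an explicit architecture, otherwise a ValueError is raised.
custom = resolve_model_weights_config(
    model_weights="/path/to/custom.ckpt",       # hypothetical checkpoint path
    default_model_architecture="sd15",          # hypothetical architecture alias
)
```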
@ -1,21 +1,20 @@
from imaginairy.samplers import kdiff from imaginairy.samplers.base import SolverName # noqa
from imaginairy.samplers.base import SamplerName # noqa from imaginairy.samplers.ddim import DDIMSolver
from imaginairy.samplers.ddim import DDIMSampler
SAMPLERS = [ SOLVERS = [
# PLMSSampler, # PLMSSolver,
DDIMSampler, DDIMSolver,
# kdiff.DPMFastSampler, # kdiff.DPMFastSampler,
# kdiff.DPMAdaptiveSampler, # kdiff.DPMAdaptiveSampler,
# kdiff.LMSSampler, # kdiff.LMSSampler,
# kdiff.DPM2Sampler, # kdiff.DPM2Sampler,
# kdiff.DPM2AncestralSampler, # kdiff.DPM2AncestralSampler,
kdiff.DPMPP2MSampler, # kdiff.DPMPP2MSampler,
# kdiff.DPMPP2SAncestralSampler, # kdiff.DPMPP2SAncestralSampler,
# kdiff.EulerSampler, # kdiff.EulerSampler,
# kdiff.EulerAncestralSampler, # kdiff.EulerAncestralSampler,
# kdiff.HeunSampler, # kdiff.HeunSampler,
] ]
SAMPLER_LOOKUP = {sampler.short_name: sampler for sampler in SAMPLERS} SOLVER_LOOKUP = {s.short_name: s for s in SOLVERS}
SAMPLER_TYPE_OPTIONS = [sampler.short_name for sampler in SAMPLERS] SOLVER_TYPE_OPTIONS = [s.short_name for s in SOLVERS]

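With the rename, solvers are registered by short name. A small lookup sketch, assuming this branch of imaginairy is importable:

```python
# Sketch only: assumes the refactored imaginairy from this branch is installed.
from imaginairy.samplers import SOLVER_LOOKUP, SolverName

solver_cls = SOLVER_LOOKUP[SolverName.DDIM]        # "ddim"
print(solver_cls.name, solver_cls.default_steps)   # DDIM solver metadata, default_steps == 50
```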
@ -16,9 +16,10 @@ from imaginairy.utils import get_device
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SamplerName: class SolverName:
PLMS = "plms" PLMS = "plms"
DDIM = "ddim" DDIM = "ddim"
DPMPP = "dpmpp"
K_DPM_FAST = "k_dpm_fast" K_DPM_FAST = "k_dpm_fast"
K_DPM_ADAPTIVE = "k_dpm_adaptive" K_DPM_ADAPTIVE = "k_dpm_adaptive"
K_LMS = "k_lms" K_LMS = "k_lms"
@ -31,7 +32,7 @@ class SamplerName:
K_HEUN = "k_heun" K_HEUN = "k_heun"
class ImageSampler(ABC): class ImageSolver(ABC):
short_name: str short_name: str
name: str name: str
default_steps: int default_steps: int

@ -8,9 +8,9 @@ from tqdm import tqdm
from imaginairy.log_utils import increment_step, log_latent from imaginairy.log_utils import increment_step, log_latent
from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
from imaginairy.samplers.base import ( from imaginairy.samplers.base import (
ImageSampler, ImageSolver,
NoiseSchedule, NoiseSchedule,
SamplerName, SolverName,
get_noise_prediction, get_noise_prediction,
mask_blend, mask_blend,
) )
@ -19,14 +19,14 @@ from imaginairy.utils import get_device
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DDIMSampler(ImageSampler): class DDIMSolver(ImageSolver):
""" """
Denoising Diffusion Implicit Models. Denoising Diffusion Implicit Models.
https://arxiv.org/abs/2010.02502 https://arxiv.org/abs/2010.02502
""" """
short_name = SamplerName.DDIM short_name = SolverName.DDIM
name = "Denoising Diffusion Implicit Models" name = "Denoising Diffusion Implicit Models"
default_steps = 50 default_steps = 50

@ -6,8 +6,8 @@ from torch import nn
from imaginairy.log_utils import increment_step, log_latent from imaginairy.log_utils import increment_step, log_latent
from imaginairy.samplers.base import ( from imaginairy.samplers.base import (
ImageSampler, ImageSolver,
SamplerName, SolverName,
get_noise_prediction, get_noise_prediction,
mask_blend, mask_blend,
) )
@ -57,7 +57,7 @@ def sample_dpm_fast(model, x, sigmas, extra_args=None, disable=False, callback=N
) )
class KDiffusionSampler(ImageSampler, ABC): class KDiffusionSolver(ImageSolver, ABC):
sampler_func: callable sampler_func: callable
def __init__(self, model): def __init__(self, model):
@ -98,9 +98,9 @@ class KDiffusionSampler(ImageSampler, ABC):
# see https://github.com/crowsonkb/k-diffusion/issues/43#issuecomment-1305195666 # see https://github.com/crowsonkb/k-diffusion/issues/43#issuecomment-1305195666
if self.short_name in ( if self.short_name in (
SamplerName.K_DPM_2, SolverName.K_DPM_2,
SamplerName.K_DPMPP_2M, SolverName.K_DPMPP_2M,
SamplerName.K_DPM_2_ANCESTRAL, SolverName.K_DPM_2_ANCESTRAL,
): ):
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
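A standalone illustration of the tensor operation applied above (dropping the second-to-last sigma while keeping the final one, per the linked k-diffusion issue):

```python
import torch

# Toy sigma schedule: descending noise levels ending at 0.
sigmas = torch.tensor([14.6, 7.1, 3.5, 1.2, 0.4, 0.0])

# Drop the second-to-last sigma, keep the trailing zero, as done for the
# K_DPM_2 / K_DPMPP_2M / K_DPM_2_ANCESTRAL solvers above.
trimmed = torch.cat([sigmas[:-2], sigmas[-1:]])
print(trimmed)  # tensor([14.6000,  7.1000,  3.5000,  1.2000,  0.0000])
```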
@ -152,73 +152,73 @@ class KDiffusionSampler(ImageSampler, ABC):
# #
# class DPMFastSampler(KDiffusionSampler): # class DPMFastSampler(KDiffusionSolver):
# short_name = SamplerName.K_DPM_FAST # short_name = SolverName.K_DPM_FAST
# name = "Diffusion probabilistic models - fast" # name = "Diffusion probabilistic models - fast"
# default_steps = 15 # default_steps = 15
# sampler_func = staticmethod(sample_dpm_fast) # sampler_func = staticmethod(sample_dpm_fast)
# #
# #
# class DPMAdaptiveSampler(KDiffusionSampler): # class DPMAdaptiveSampler(KDiffusionSolver):
# short_name = SamplerName.K_DPM_ADAPTIVE # short_name = SolverName.K_DPM_ADAPTIVE
# name = "Diffusion probabilistic models - adaptive" # name = "Diffusion probabilistic models - adaptive"
# default_steps = 40 # default_steps = 40
# sampler_func = staticmethod(sample_dpm_adaptive) # sampler_func = staticmethod(sample_dpm_adaptive)
# #
# #
# class DPM2Sampler(KDiffusionSampler): # class DPM2Sampler(KDiffusionSolver):
# short_name = SamplerName.K_DPM_2 # short_name = SolverName.K_DPM_2
# name = "Diffusion probabilistic models - 2" # name = "Diffusion probabilistic models - 2"
# default_steps = 40 # default_steps = 40
# sampler_func = staticmethod(k_sampling.sample_dpm_2) # sampler_func = staticmethod(k_sampling.sample_dpm_2)
# #
# #
# class DPM2AncestralSampler(KDiffusionSampler): # class DPM2AncestralSampler(KDiffusionSolver):
# short_name = SamplerName.K_DPM_2_ANCESTRAL # short_name = SolverName.K_DPM_2_ANCESTRAL
# name = "Diffusion probabilistic models - 2 ancestral" # name = "Diffusion probabilistic models - 2 ancestral"
# default_steps = 40 # default_steps = 40
# sampler_func = staticmethod(k_sampling.sample_dpm_2_ancestral) # sampler_func = staticmethod(k_sampling.sample_dpm_2_ancestral)
# #
class DPMPP2MSampler(KDiffusionSampler): class DPMPP2MSampler(KDiffusionSolver):
short_name = SamplerName.K_DPMPP_2M short_name = SolverName.K_DPMPP_2M
name = "Diffusion probabilistic models - 2m" name = "Diffusion probabilistic models - 2m"
default_steps = 15 default_steps = 15
sampler_func = staticmethod(k_sampling.sample_dpmpp_2m) sampler_func = staticmethod(k_sampling.sample_dpmpp_2m)
# #
# class DPMPP2SAncestralSampler(KDiffusionSampler): # class DPMPP2SAncestralSampler(KDiffusionSolver):
# short_name = SamplerName.K_DPMPP_2S_ANCESTRAL # short_name = SolverName.K_DPMPP_2S_ANCESTRAL
# name = "Ancestral sampling with DPM-Solver++(2S) second-order steps." # name = "Ancestral sampling with DPM-Solver++(2S) second-order steps."
# default_steps = 15 # default_steps = 15
# sampler_func = staticmethod(k_sampling.sample_dpmpp_2s_ancestral) # sampler_func = staticmethod(k_sampling.sample_dpmpp_2s_ancestral)
# #
# #
# class EulerSampler(KDiffusionSampler): # class EulerSampler(KDiffusionSolver):
# short_name = SamplerName.K_EULER # short_name = SolverName.K_EULER
# name = "Algorithm 2 (Euler steps) from Karras et al. (2022)" # name = "Algorithm 2 (Euler steps) from Karras et al. (2022)"
# default_steps = 40 # default_steps = 40
# sampler_func = staticmethod(k_sampling.sample_euler) # sampler_func = staticmethod(k_sampling.sample_euler)
# #
# #
# class EulerAncestralSampler(KDiffusionSampler): # class EulerAncestralSampler(KDiffusionSolver):
# short_name = SamplerName.K_EULER_ANCESTRAL # short_name = SolverName.K_EULER_ANCESTRAL
# name = "Euler ancestral" # name = "Euler ancestral"
# default_steps = 40 # default_steps = 40
# sampler_func = staticmethod(k_sampling.sample_euler_ancestral) # sampler_func = staticmethod(k_sampling.sample_euler_ancestral)
# #
# #
# class HeunSampler(KDiffusionSampler): # class HeunSampler(KDiffusionSolver):
# short_name = SamplerName.K_HEUN # short_name = SolverName.K_HEUN
# name = "Algorithm 2 (Heun steps) from Karras et al. (2022)." # name = "Algorithm 2 (Heun steps) from Karras et al. (2022)."
# default_steps = 40 # default_steps = 40
# sampler_func = staticmethod(k_sampling.sample_heun) # sampler_func = staticmethod(k_sampling.sample_heun)
# #
# #
# class LMSSampler(KDiffusionSampler): # class LMSSampler(KDiffusionSolver):
# short_name = SamplerName.K_LMS # short_name = SolverName.K_LMS
# name = "LMS" # name = "LMS"
# default_steps = 40 # default_steps = 40
# sampler_func = staticmethod(k_sampling.sample_lms) # sampler_func = staticmethod(k_sampling.sample_lms)

@ -8,9 +8,9 @@ from tqdm import tqdm
from imaginairy.log_utils import increment_step, log_latent from imaginairy.log_utils import increment_step, log_latent
from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
from imaginairy.samplers.base import ( from imaginairy.samplers.base import (
ImageSampler, ImageSolver,
NoiseSchedule, NoiseSchedule,
SamplerName, SolverName,
get_noise_prediction, get_noise_prediction,
mask_blend, mask_blend,
) )
@ -19,7 +19,7 @@ from imaginairy.utils import get_device
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PLMSSampler(ImageSampler): class PLMSSolver(ImageSolver):
""" """
probabilistic least-mean-squares. probabilistic least-mean-squares.
@ -29,7 +29,7 @@ class PLMSSampler(ImageSampler):
https://github.com/luping-liu/PNDM https://github.com/luping-liu/PNDM
""" """
short_name = SamplerName.PLMS short_name = SolverName.PLMS
name = "probabilistic least-mean-squares sampler" name = "probabilistic least-mean-squares sampler"
default_steps = 40 default_steps = 40

@ -7,43 +7,32 @@ import logging
import os.path import os.path
import random import random
from datetime import datetime, timezone from datetime import datetime, timezone
from enum import Enum
from io import BytesIO from io import BytesIO
from typing import TYPE_CHECKING, Any, List, Literal, Optional from typing import TYPE_CHECKING, Any, List
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
ConfigDict,
Field, Field,
GetCoreSchemaHandler, GetCoreSchemaHandler,
field_validator, field_validator,
model_validator, model_validator,
) )
from pydantic_core import core_schema from pydantic_core import core_schema
from typing_extensions import Self
from imaginairy import config from imaginairy import config
if TYPE_CHECKING: if TYPE_CHECKING:
from pathlib import Path
from PIL import Image from PIL import Image
else:
Image = Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def save_image_as_base64(image: "Image.Image") -> str:
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_bytes = buffered.getvalue()
return base64.b64encode(img_bytes).decode()
def load_image_from_base64(image_str: str) -> "Image.Image":
from PIL import Image
img_bytes = base64.b64decode(image_str)
return Image.open(io.BytesIO(img_bytes))
class InvalidUrlError(ValueError): class InvalidUrlError(ValueError):
pass pass
@ -52,7 +41,12 @@ class LazyLoadingImage:
"""Image file encoded as base64 string.""" """Image file encoded as base64 string."""
def __init__( def __init__(
self, *, filepath=None, url=None, img: Image = None, b64: Optional[str] = None self,
*,
filepath=None,
url=None,
img: "Image.Image" = None,
b64: str | None = None,
): ):
if not filepath and not url and not img and not b64: if not filepath and not url and not img and not b64:
msg = "You must specify a url or filepath or img or base64 string" msg = "You must specify a url or filepath or img or base64 string"
@ -208,10 +202,10 @@ class LazyLoadingImage:
return f"<LazyLoadingImage RENDER EXCEPTION*{e}*>" return f"<LazyLoadingImage RENDER EXCEPTION*{e}*>"
class ControlNetInput(BaseModel): class ControlInput(BaseModel):
mode: str mode: str
image: Optional[LazyLoadingImage] = None image: LazyLoadingImage | None = None
image_raw: Optional[LazyLoadingImage] = None image_raw: LazyLoadingImage | None = None
strength: float = Field(1, ge=0, le=1000) strength: float = Field(1, ge=0, le=1000)
# @field_validator("image", "image_raw", mode="before") # @field_validator("image", "image_raw", mode="before")
@ -233,8 +227,8 @@ class ControlNetInput(BaseModel):
@field_validator("mode") @field_validator("mode")
def mode_validate(cls, v): def mode_validate(cls, v):
if v not in config.CONTROLNET_CONFIG_SHORTCUTS: if v not in config.CONTROL_CONFIG_SHORTCUTS:
valid_modes = list(config.CONTROLNET_CONFIG_SHORTCUTS.keys()) valid_modes = list(config.CONTROL_CONFIG_SHORTCUTS.keys())
valid_modes = ", ".join(valid_modes) valid_modes = ", ".join(valid_modes)
msg = f"Invalid controlnet mode: '{v}'. Valid modes are: {valid_modes}" msg = f"Invalid controlnet mode: '{v}'. Valid modes are: {valid_modes}"
raise ValueError(msg) raise ValueError(msg)
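A short construction sketch for the renamed model, assuming this branch of imaginairy is importable; "edit" is a mode used elsewhere in this diff:

```python
# Sketch only: assumes the refactored imaginairy from this branch is installed.
from imaginairy.schema import ControlInput

# mode must be a key of config.CONTROL_CONFIG_SHORTCUTS
control = ControlInput(mode="edit", strength=2)
print(control.mode, control.strength)
```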
@ -249,43 +243,51 @@ class WeightedPrompt(BaseModel):
return f"{self.weight}*({self.text})" return f"{self.weight}*({self.text})"
class MaskMode(str, Enum):
REPLACE = "replace"
KEEP = "keep"
MaskInput = MaskMode | str
PromptInput = str | WeightedPrompt | list[WeightedPrompt] | list[str] | None
class ImaginePrompt(BaseModel, protected_namespaces=()): class ImaginePrompt(BaseModel, protected_namespaces=()):
prompt: Optional[List[WeightedPrompt]] = Field(default=None, validate_default=True) model_config = ConfigDict(extra="forbid", validate_assignment=True)
negative_prompt: Optional[List[WeightedPrompt]] = Field(
default=None, validate_default=True prompt: List[WeightedPrompt] = Field(default=None, validate_default=True)
) negative_prompt: List[WeightedPrompt] = Field(default=None, validate_default=True)
prompt_strength: Optional[float] = Field( prompt_strength: float = Field(default=7.5, le=50, ge=-50, validate_default=True)
default=7.5, le=10_000, ge=-10_000, validate_default=True init_image: LazyLoadingImage | None = Field(
)
init_image: Optional[LazyLoadingImage] = Field(
None, description="base64 encoded image", validate_default=True None, description="base64 encoded image", validate_default=True
) )
init_image_strength: Optional[float] = Field( init_image_strength: float | None = Field(
ge=0, le=1, default=None, validate_default=True ge=0, le=1, default=None, validate_default=True
) )
control_inputs: List[ControlNetInput] = Field( control_inputs: List[ControlInput] = Field(
default_factory=list, validate_default=True default_factory=list, validate_default=True
) )
mask_prompt: Optional[str] = Field( mask_prompt: str | None = Field(
default=None, default=None,
description="text description of the things to be masked", description="text description of the things to be masked",
validate_default=True, validate_default=True,
) )
mask_image: Optional[LazyLoadingImage] = Field(default=None, validate_default=True) mask_image: LazyLoadingImage | None = Field(default=None, validate_default=True)
mask_mode: Optional[Literal["keep", "replace"]] = "replace" mask_mode: MaskMode = MaskMode.REPLACE
mask_modify_original: bool = True mask_modify_original: bool = True
outpaint: Optional[str] = "" outpaint: str | None = ""
model: str = Field(default=config.DEFAULT_MODEL, validate_default=True) model_architecture: str | None = None
model_config_path: Optional[str] = None model_weights: str = Field(
sampler_type: str = Field(default=config.DEFAULT_SAMPLER, validate_default=True) default=config.DEFAULT_MODEL_WEIGHTS, validate_default=True
seed: Optional[int] = Field(default=None, validate_default=True) )
steps: Optional[int] = Field(default=None, validate_default=True) solver_type: str = Field(default=config.DEFAULT_SOLVER, validate_default=True)
height: Optional[int] = Field(None, ge=1, le=100_000, validate_default=True) seed: int | None = Field(default=None, validate_default=True)
width: Optional[int] = Field(None, ge=1, le=100_000, validate_default=True) steps: int | None = Field(default=None, validate_default=True)
size: tuple[int, int] | None = Field(default=None, validate_default=True)
upscale: bool = False upscale: bool = False
fix_faces: bool = False fix_faces: bool = False
fix_faces_fidelity: Optional[float] = Field(0.2, ge=0, le=1, validate_default=True) fix_faces_fidelity: float | None = Field(0.2, ge=0, le=1, validate_default=True)
conditioning: Optional[str] = None conditioning: str | None = None
tile_mode: str = "" tile_mode: str = ""
allow_compose_phase: bool = True allow_compose_phase: bool = True
is_intermediate: bool = False is_intermediate: bool = False
@ -294,22 +296,87 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
"", description="text to be overlaid on the image", validate_default=True "", description="text to be overlaid on the image", validate_default=True
) )
class MaskMode: def __init__(
REPLACE = "replace" self,
KEEP = "keep" prompt: PromptInput = "",
*,
def __init__(self, prompt=None, **kwargs): negative_prompt: PromptInput = None,
# allows `prompt` to be positional prompt_strength: float | None = 7.5,
super().__init__(prompt=prompt, **kwargs) init_image: LazyLoadingImage | None = None,
init_image_strength: float | None = None,
control_inputs: List[ControlInput] | None = None,
mask_prompt: str | None = None,
mask_image: LazyLoadingImage | None = None,
mask_mode: MaskInput = MaskMode.REPLACE,
mask_modify_original: bool = True,
outpaint: str | None = "",
model_architecture: str | None = None,
model_weights: str = config.DEFAULT_MODEL_WEIGHTS,
solver_type: str = config.DEFAULT_SOLVER,
seed: int | None = None,
steps: int | None = None,
size: int | str | tuple[int, int] | None = None,
upscale: bool = False,
fix_faces: bool = False,
fix_faces_fidelity: float | None = 0.2,
conditioning: str | None = None,
tile_mode: str = "",
allow_compose_phase: bool = True,
is_intermediate: bool = False,
collect_progress_latents: bool = False,
caption_text: str = "",
):
super().__init__(
prompt=prompt,
negative_prompt=negative_prompt,
prompt_strength=prompt_strength,
init_image=init_image,
init_image_strength=init_image_strength,
control_inputs=control_inputs,
mask_prompt=mask_prompt,
mask_image=mask_image,
mask_mode=mask_mode,
mask_modify_original=mask_modify_original,
outpaint=outpaint,
model_architecture=model_architecture,
model_weights=model_weights,
solver_type=solver_type,
seed=seed,
steps=steps,
size=size,
upscale=upscale,
fix_faces=fix_faces,
fix_faces_fidelity=fix_faces_fidelity,
conditioning=conditioning,
tile_mode=tile_mode,
allow_compose_phase=allow_compose_phase,
is_intermediate=is_intermediate,
collect_progress_latents=collect_progress_latents,
caption_text=caption_text,
)
@field_validator("prompt", "negative_prompt", mode="before") @field_validator("prompt", "negative_prompt", mode="before")
@classmethod def make_into_weighted_prompts(
def make_into_weighted_prompts(cls, v): cls,
if isinstance(v, str): value: PromptInput,
v = [WeightedPrompt(text=v)] ) -> list[WeightedPrompt]:
elif isinstance(v, WeightedPrompt): match value:
v = [v] case None:
return v return []
case str():
if value:
return [WeightedPrompt(text=value)]
else:
return []
case WeightedPrompt():
return [value]
case list():
if all(isinstance(item, str) for item in value):
return [WeightedPrompt(text=p) for p in value]
elif all(isinstance(item, WeightedPrompt) for item in value):
return value
raise ValueError("Invalid prompt input")
@field_validator("prompt", "negative_prompt", mode="after") @field_validator("prompt", "negative_prompt", mode="after")
    @classmethod @classmethod
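Since `prompt` stays positional and accepts several shapes, a short sketch of the normalization under the new schema, assuming this branch of imaginairy is importable:

```python
# Sketch only: assumes the refactored imaginairy from this branch is installed.
from imaginairy.schema import ImaginePrompt, WeightedPrompt

# A bare string is normalized into a single WeightedPrompt.
p1 = ImaginePrompt("a scenic landscape", solver_type="ddim", seed=1)
print(p1.prompt_text)        # "a scenic landscape"

# Lists of WeightedPrompts are accepted as well.
p2 = ImaginePrompt([WeightedPrompt(text="a scenic landscape"),
                    WeightedPrompt(text="sunset", weight=0.5)])
print(p2.prompt_text)        # weighted prompts joined with "|"
```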
@ -328,16 +395,20 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
@model_validator(mode="after") @model_validator(mode="after")
def validate_negative_prompt(self): def validate_negative_prompt(self):
if self.negative_prompt is None: if (
model_config = config.MODEL_CONFIG_SHORTCUTS.get(self.model, None) self.negative_prompt == [WeightedPrompt(text="")]
if model_config: or self.negative_prompt == []
self.negative_prompt = [ ):
WeightedPrompt(text=model_config.default_negative_prompt) model_weight_config = config.MODEL_WEIGHT_CONFIG_LOOKUP.get(
] self.model_weights, None
else: )
self.negative_prompt = [ default_negative_prompt = config.DEFAULT_NEGATIVE_PROMPT
WeightedPrompt(text=config.DEFAULT_NEGATIVE_PROMPT) if model_weight_config:
] default_negative_prompt = model_weight_config.defaults.get(
"negative_prompt", default_negative_prompt
)
self.negative_prompt = [WeightedPrompt(text=default_negative_prompt)]
return self return self
@field_validator("prompt_strength") @field_validator("prompt_strength")
@ -426,10 +497,10 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
raise ValueError(msg) raise ValueError(msg)
return v return v
@field_validator("model", mode="before") @field_validator("model_weights", mode="before")
def set_default_diffusion_model(cls, v): def set_default_diffusion_model(cls, v):
if v is None: if v is None:
return config.DEFAULT_MODEL return config.DEFAULT_MODEL_WEIGHTS
return v return v
@ -444,33 +515,32 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return v return v
@field_validator("sampler_type", mode="after") @field_validator("solver_type", mode="after")
def validate_sampler_type(cls, v, info: core_schema.FieldValidationInfo): def validate_solver_type(cls, v, info: core_schema.FieldValidationInfo):
from imaginairy.samplers import SamplerName from imaginairy.samplers import SolverName
if v is None: if v is None:
v = config.DEFAULT_SAMPLER v = config.DEFAULT_SOLVER
v = v.lower() v = v.lower()
if info.data.get("model") == "SD-2.0-v" and v == SamplerName.PLMS: if info.data.get("model") == "SD-2.0-v" and v == SolverName.PLMS:
raise ValueError("PLMS sampler is not supported for SD-2.0-v model.") raise ValueError("PLMS solvers is not supported for SD-2.0-v model.")
if info.data.get("model") == "edit" and v in ( if info.data.get("model") == "edit" and v in (
SamplerName.PLMS, SolverName.PLMS,
SamplerName.DDIM, SolverName.DDIM,
): ):
msg = "PLMS and DDIM samplers are not supported for pix2pix edit model." msg = "PLMS and DDIM solvers are not supported for pix2pix edit model."
raise ValueError(msg) raise ValueError(msg)
return v return v
@field_validator("steps") @field_validator("steps")
def validate_steps(cls, v, info: core_schema.FieldValidationInfo): def validate_steps(cls, v, info: core_schema.FieldValidationInfo):
from imaginairy.samplers import SAMPLER_LOOKUP steps_lookup = {"ddim": 50, "dpmpp": 20}
if v is None: if v is None:
SamplerCls = SAMPLER_LOOKUP[info.data["sampler_type"]] v = steps_lookup[info.data["solver_type"]]
v = SamplerCls.default_steps
return int(v) return int(v)
@ -486,14 +556,26 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return self return self
@field_validator("height", "width") @field_validator("size", mode="before")
def validate_image_size(cls, v, info: core_schema.FieldValidationInfo): def validate_image_size(cls, v, info: core_schema.FieldValidationInfo):
from imaginairy.model_manager import get_model_default_image_size from imaginairy.model_manager import get_model_default_image_size
from imaginairy.utils.named_resolutions import normalize_image_size
if v is None: if v is None:
v = get_model_default_image_size(info.data["model"]) v = get_model_default_image_size(info.data["model_architecture"])
return v width, height = normalize_image_size(v)
min_size = 8
max_size = 100_000
if not min_size <= width <= max_size:
msg = f"Width must be between {min_size} and {max_size}. Got: {width}"
raise ValueError(msg)
if not min_size <= height <= max_size:
msg = f"Height must be between {min_size} and {max_size}. Got: {height}"
raise ValueError(msg)
return width, height
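The `size` field replaces the old separate `height`/`width` fields (those survive as read-only properties further down in this diff). A short sketch, assuming this branch of imaginairy is importable:

```python
# Sketch only: assumes the refactored imaginairy from this branch is installed.
from imaginairy.schema import ImaginePrompt

p = ImaginePrompt("a cat", size="svd")   # named resolutions resolve via normalize_image_size
print(p.size, p.width, p.height)         # (1024, 576) 1024 576

p = ImaginePrompt("a cat", size=(640, 448))
print(p.size)                            # (640, 448); values outside 8..100_000 are rejected
```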
@field_validator("caption_text", mode="before") @field_validator("caption_text", mode="before")
def validate_caption_text(cls, v): def validate_caption_text(cls, v):
@ -507,7 +589,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return self.prompt return self.prompt
@property @property
def prompt_text(self): def prompt_text(self) -> str:
if not self.prompt: if not self.prompt:
return "" return ""
if len(self.prompt) == 1: if len(self.prompt) == 1:
@ -515,18 +597,31 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return "|".join(str(p) for p in self.prompt) return "|".join(str(p) for p in self.prompt)
@property @property
def negative_prompt_text(self): def negative_prompt_text(self) -> str:
if not self.negative_prompt: if not self.negative_prompt:
return "" return ""
if len(self.negative_prompt) == 1: if len(self.negative_prompt) == 1:
return self.negative_prompt[0].text return self.negative_prompt[0].text
return "|".join(str(p) for p in self.negative_prompt) return "|".join(str(p) for p in self.negative_prompt)
@property
def width(self) -> int:
return self.size[0]
@property
def height(self) -> int:
return self.size[1]
def prompt_description(self): def prompt_description(self):
return ( return (
f'"{self.prompt_text}" {self.width}x{self.height}px ' f'"{self.prompt_text}" {self.width}x{self.height}px '
f'negative-prompt:"{self.negative_prompt_text}" ' f'negative-prompt:"{self.negative_prompt_text}" '
f"seed:{self.seed} prompt-strength:{self.prompt_strength} steps:{self.steps} sampler-type:{self.sampler_type} init-image-strength:{self.init_image_strength} model:{self.model}" f"seed:{self.seed} "
f"prompt-strength:{self.prompt_strength} "
f"steps:{self.steps} solver-type:{self.solver_type} "
f"init-image-strength:{self.init_image_strength} "
f"arch:{self.model_architecture} "
f"weights: {self.model_weights}"
) )
def logging_dict(self): def logging_dict(self):
@ -547,7 +642,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
new_prompt = new_prompt.model_validate(dict(new_prompt)) new_prompt = new_prompt.model_validate(dict(new_prompt))
return new_prompt return new_prompt
def make_concrete_copy(self): def make_concrete_copy(self) -> Self:
seed = self.seed if self.seed is not None else random.randint(1, 1_000_000_000) seed = self.seed if self.seed is not None else random.randint(1, 1_000_000_000)
return self.full_copy( return self.full_copy(
deep=False, deep=False,
@ -574,10 +669,6 @@ class ImagineResult:
prompt: ImaginePrompt, prompt: ImaginePrompt,
is_nsfw, is_nsfw,
safety_score, safety_score,
upscaled_img=None,
modified_original=None,
mask_binary=None,
mask_grayscale=None,
result_images=None, result_images=None,
timings=None, timings=None,
progress_latents=None, progress_latents=None,
@ -594,20 +685,10 @@ class ImagineResult:
self.images = {"generated": img} self.images = {"generated": img}
if upscaled_img:
self.images["upscaled"] = upscaled_img
if modified_original:
self.images["modified_original"] = modified_original
if mask_binary:
self.images["mask_binary"] = mask_binary
if mask_grayscale:
self.images["mask_grayscale"] = mask_grayscale
if result_images: if result_images:
for img_type, r_img in result_images.items(): for img_type, r_img in result_images.items():
if r_img is None:
continue
if isinstance(r_img, torch.Tensor): if isinstance(r_img, torch.Tensor):
if r_img.shape[1] == 4: if r_img.shape[1] == 4:
r_img = model_latent_to_pillow_img(r_img) r_img = model_latent_to_pillow_img(r_img)
@ -620,7 +701,6 @@ class ImagineResult:
# for backward compat # for backward compat
self.img = img self.img = img
self.upscaled_img = upscaled_img
self.is_nsfw = is_nsfw self.is_nsfw = is_nsfw
self.safety_score = safety_score self.safety_score = safety_score
@ -628,7 +708,7 @@ class ImagineResult:
self.torch_backend = get_device() self.torch_backend = get_device()
self.hardware_name = get_hardware_description(get_device()) self.hardware_name = get_hardware_description(get_device())
def md5(self): def md5(self) -> str:
return hashlib.md5(self.img.tobytes()).hexdigest() return hashlib.md5(self.img.tobytes()).hexdigest()
def metadata_dict(self): def metadata_dict(self):
@ -636,19 +716,19 @@ class ImagineResult:
"prompt": self.prompt.logging_dict(), "prompt": self.prompt.logging_dict(),
} }
def timings_str(self): def timings_str(self) -> str:
if not self.timings: if not self.timings:
return "" return ""
return " ".join(f"{k}:{v:.2f}s" for k, v in self.timings.items()) return " ".join(f"{k}:{v:.2f}s" for k, v in self.timings.items())
def _exif(self): def _exif(self) -> "Image.Exif":
from PIL import Image from PIL import Image
exif = Image.Exif() exif = Image.Exif()
exif[ExifCodes.ImageDescription] = self.prompt.prompt_description() exif[ExifCodes.ImageDescription] = self.prompt.prompt_description()
exif[ExifCodes.UserComment] = json.dumps(self.metadata_dict()) exif[ExifCodes.UserComment] = json.dumps(self.metadata_dict())
        # help future web scrapers not ingest AI-generated art         # help future web scrapers not ingest AI-generated art
sd_version = self.prompt.model sd_version = self.prompt.model_weights
if len(sd_version) > 20: if len(sd_version) > 20:
sd_version = "custom weights" sd_version = "custom weights"
exif[ExifCodes.Software] = f"Imaginairy / Stable Diffusion {sd_version}" exif[ExifCodes.Software] = f"Imaginairy / Stable Diffusion {sd_version}"
@ -656,7 +736,7 @@ class ImagineResult:
exif[ExifCodes.HostComputer] = f"{self.torch_backend}:{self.hardware_name}" exif[ExifCodes.HostComputer] = f"{self.torch_backend}:{self.hardware_name}"
return exif return exif
def save(self, save_path, image_type="generated"): def save(self, save_path: "Path | str", image_type: str = "generated") -> None:
img = self.images.get(image_type, None) img = self.images.get(image_type, None)
if img is None: if img is None:
msg = f"Image of type {image_type} not stored. Options are: {self.images.keys()}" msg = f"Image of type {image_type} not stored. Options are: {self.images.keys()}"
@ -665,6 +745,6 @@ class ImagineResult:
img.convert("RGB").save(save_path, exif=self._exif()) img.convert("RGB").save(save_path, exif=self._exif())
class SafetyMode: class SafetyMode(str, Enum):
STRICT = "strict" STRICT = "strict"
RELAXED = "relaxed" RELAXED = "relaxed"

@ -10,7 +10,7 @@ from imaginairy import ImaginePrompt, LazyLoadingImage, imagine_image_files
from imaginairy.animations import make_gif_animation from imaginairy.animations import make_gif_animation
from imaginairy.enhancers.facecrop import detect_faces from imaginairy.enhancers.facecrop import detect_faces
from imaginairy.img_utils import add_caption_to_image, pillow_fit_image_within from imaginairy.img_utils import add_caption_to_image, pillow_fit_image_within
from imaginairy.schema import ControlNetInput from imaginairy.schema import ControlInput
preserve_head_kwargs = { preserve_head_kwargs = {
"mask_prompt": "head|face", "mask_prompt": "head|face",
@ -142,7 +142,7 @@ def surprise_me_prompts(
for prompt_text, strength, kwargs in generic_prompts: for prompt_text, strength, kwargs in generic_prompts:
if use_controlnet: if use_controlnet:
strength = 5 strength = 5
control_input = ControlNetInput(mode="edit", strength=2) control_input = ControlInput(mode="edit", strength=2)
prompts.append( prompts.append(
ImaginePrompt( ImaginePrompt(
prompt_text, prompt_text,
@ -163,7 +163,7 @@ def surprise_me_prompts(
prompt_text, prompt_text,
init_image=img, init_image=img,
prompt_strength=strength, prompt_strength=strength,
model="edit", model_weights="edit",
steps=steps, steps=steps,
width=width, width=width,
height=height, height=height,
@ -178,7 +178,7 @@ def surprise_me_prompts(
for prompt_subconfig in prompt_subconfigs: for prompt_subconfig in prompt_subconfigs:
prompt_text, strength, kwargs = prompt_subconfig prompt_text, strength, kwargs = prompt_subconfig
if use_controlnet: if use_controlnet:
control_input = ControlNetInput( control_input = ControlInput(
mode="edit", mode="edit",
) )
prompts.append( prompts.append(
@ -201,7 +201,7 @@ def surprise_me_prompts(
prompt_text, prompt_text,
init_image=img, init_image=img,
prompt_strength=strength, prompt_strength=strength,
model="edit", model_weights="edit",
steps=steps, steps=steps,
width=width, width=width,
height=height, height=height,

@ -0,0 +1,533 @@
import datetime
import logging
import os
import signal
import time
from functools import partial
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor
try:
from pytorch_lightning.strategies import DDPStrategy
except ImportError:
# let's not break all of imaginairy just because a training import doesn't exist in an older version of PL
    # Use pytorch-lightning >= 1.6.0 to make this work
DDPStrategy = None
import contextlib
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.utils.data import DataLoader, Dataset
from imaginairy import config
from imaginairy.model_manager import get_diffusion_model
from imaginairy.training_tools.single_concept import SingleConceptDataset
from imaginairy.utils import get_device, instantiate_from_config
mod_logger = logging.getLogger(__name__)
referenced_by_string = [LearningRateMonitor]
class WrappedDataset(Dataset):
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset."""
def __init__(self, dataset):
self.data = dataset
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def worker_init_fn(_):
worker_info = torch.utils.data.get_worker_info()
dataset = worker_info.dataset
worker_id = worker_info.id
if isinstance(dataset, SingleConceptDataset):
# split_size = dataset.num_records // worker_info.num_workers
# reset num_records to the true number to retain reliable length information
# dataset.sample_ids = dataset.valid_ids[
# worker_id * split_size : (worker_id + 1) * split_size
# ]
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
return np.random.seed(np.random.get_state()[1][0] + worker_id)
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(
self,
batch_size,
train=None,
validation=None,
test=None,
predict=None,
wrap=False,
num_workers=None,
shuffle_test_loader=False,
use_worker_init_fn=False,
shuffle_val_dataloader=False,
num_val_workers=0,
):
super().__init__()
self.batch_size = batch_size
self.dataset_configs = {}
self.num_workers = num_workers if num_workers is not None else batch_size * 2
if num_val_workers is None:
self.num_val_workers = self.num_workers
else:
self.num_val_workers = num_val_workers
self.use_worker_init_fn = use_worker_init_fn
if train is not None:
self.dataset_configs["train"] = train
self.train_dataloader = self._train_dataloader
if validation is not None:
self.dataset_configs["validation"] = validation
self.val_dataloader = partial(
self._val_dataloader, shuffle=shuffle_val_dataloader
)
if test is not None:
self.dataset_configs["test"] = test
self.test_dataloader = partial(
self._test_dataloader, shuffle=shuffle_test_loader
)
if predict is not None:
self.dataset_configs["predict"] = predict
self.predict_dataloader = self._predict_dataloader
self.wrap = wrap
self.datasets = None
def prepare_data(self):
for data_cfg in self.dataset_configs.values():
instantiate_from_config(data_cfg)
def setup(self, stage=None):
self.datasets = {
k: instantiate_from_config(c) for k, c in self.dataset_configs.items()
}
if self.wrap:
self.datasets = {k: WrappedDataset(v) for k, v in self.datasets.items()}
def _train_dataloader(self):
is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
if is_iterable_dataset or self.use_worker_init_fn:
pass
else:
pass
return DataLoader(
self.datasets["train"],
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
worker_init_fn=worker_init_fn,
)
def _val_dataloader(self, shuffle=False):
if (
isinstance(self.datasets["validation"], SingleConceptDataset)
or self.use_worker_init_fn
):
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets["validation"],
batch_size=self.batch_size,
num_workers=self.num_val_workers,
worker_init_fn=init_fn,
shuffle=shuffle,
)
def _test_dataloader(self, shuffle=False):
is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
is_iterable_dataset = False
# do not shuffle dataloader for iterable dataset
shuffle = shuffle and (not is_iterable_dataset)
return DataLoader(
self.datasets["test"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
shuffle=shuffle,
)
def _predict_dataloader(self, shuffle=False):
if (
isinstance(self.datasets["predict"], SingleConceptDataset)
or self.use_worker_init_fn
):
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets["predict"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
)
class SetupCallback(Callback):
def __init__(
self,
resume,
now,
logdir,
ckptdir,
cfgdir,
):
super().__init__()
self.resume = resume
self.now = now
self.logdir = logdir
self.ckptdir = ckptdir
self.cfgdir = cfgdir
def on_keyboard_interrupt(self, trainer, pl_module):
if trainer.global_rank == 0:
mod_logger.info("Stopping execution and saving final checkpoint.")
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def on_fit_start(self, trainer, pl_module):
if trainer.global_rank == 0:
# Create logdirs and save configs
os.makedirs(self.logdir, exist_ok=True)
os.makedirs(self.ckptdir, exist_ok=True)
os.makedirs(self.cfgdir, exist_ok=True)
else:
# ModelCheckpoint callback created log directory --- remove it
if not self.resume and os.path.exists(self.logdir):
dst, name = os.path.split(self.logdir)
dst = os.path.join(dst, "child_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
with contextlib.suppress(FileNotFoundError):
os.rename(self.logdir, dst)
class ImageLogger(Callback):
def __init__(
self,
batch_frequency,
max_images,
clamp=True,
increase_log_steps=True,
rescale=True,
disabled=False,
log_on_batch_idx=False,
log_first_step=False,
log_images_kwargs=None,
log_all_val=False,
concept_label=None,
):
super().__init__()
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = {}
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
self.disabled = disabled
self.log_on_batch_idx = log_on_batch_idx
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.log_first_step = log_first_step
self.log_all_val = log_all_val
self.concept_label = concept_label
@rank_zero_only
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "logs", "images", split)
for k in images:
grid = torchvision.utils.make_grid(images[k], nrow=4)
if self.rescale:
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = (
f"{k}_gs-{global_step:06}_e-{current_epoch:06}_b-{batch_idx:06}.png"
)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
def log_img(self, pl_module, batch, batch_idx, split="train"):
# always generate the concept label
batch["txt"][0] = self.concept_label
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
if self.log_all_val and split == "val":
should_log = True
else:
should_log = self.check_frequency(check_idx)
if (
should_log
and (batch_idx % self.batch_freq == 0)
and hasattr(pl_module, "log_images")
and callable(pl_module.log_images)
and self.max_images > 0
):
logger = type(pl_module.logger)
is_train = pl_module.training
if is_train:
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(
batch, split=split, **self.log_images_kwargs
)
for k in images:
N = min(images[k].shape[0], self.max_images)
images[k] = images[k][:N]
if isinstance(images[k], torch.Tensor):
images[k] = images[k].detach().cpu()
if self.clamp:
images[k] = torch.clamp(images[k], -1.0, 1.0)
self.log_local(
pl_module.logger.save_dir,
split,
images,
pl_module.global_step,
pl_module.current_epoch,
batch_idx,
)
logger_log_images = self.logger_log_images.get(
logger, lambda *args, **kwargs: None
)
logger_log_images(pl_module, images, pl_module.global_step, split)
if is_train:
pl_module.train()
def check_frequency(self, check_idx):
if (check_idx % self.batch_freq) == 0 and (
check_idx > 0 or self.log_first_step
):
return True
return False
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
self.log_img(pl_module, batch, batch_idx, split="train")
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
):
if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split="val")
if (
hasattr(pl_module, "calibrate_grad_norm")
and (pl_module.calibrate_grad_norm and batch_idx % 25 == 0)
and batch_idx > 0
):
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
class CUDACallback(Callback):
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
def on_train_epoch_start(self, trainer, pl_module):
# Reset the memory use counter
if "cuda" in get_device():
torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
torch.cuda.synchronize(trainer.strategy.root_device.index)
self.start_time = time.time()
def on_train_epoch_end(self, trainer, pl_module):
if "cuda" in get_device():
torch.cuda.synchronize(trainer.strategy.root_device.index)
max_memory = (
torch.cuda.max_memory_allocated(trainer.strategy.root_device.index)
/ 2**20
)
epoch_time = time.time() - self.start_time
try:
max_memory = trainer.training_type_plugin.reduce(max_memory)
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
except AttributeError:
pass
def train_diffusion_model(
concept_label,
concept_images_dir,
class_label,
class_images_dir,
weights_location=config.DEFAULT_MODEL_WEIGHTS,
logdir="logs",
learning_rate=1e-6,
accumulate_grad_batches=32,
resume=None,
):
"""
Train a diffusion model on a single concept.
    accumulate_grad_batches is used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf
"""
if DDPStrategy is None:
msg = "Please install pytorch-lightning>=1.6.0 to train a model"
raise ImportError(msg)
batch_size = 1
seed = 23
num_workers = 1
num_val_workers = 0
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # noqa: DTZ005
logdir = os.path.join(logdir, now)
ckpt_output_dir = os.path.join(logdir, "checkpoints")
cfg_output_dir = os.path.join(logdir, "configs")
seed_everything(seed)
model = get_diffusion_model(
weights_location=weights_location, half_mode=False, for_training=True
)._model
model.learning_rate = learning_rate * accumulate_grad_batches * batch_size
# add callback which sets up log directory
default_callbacks_cfg = {
"setup_callback": {
"target": "imaginairy.train.SetupCallback",
"params": {
"resume": False,
"now": now,
"logdir": logdir,
"ckptdir": ckpt_output_dir,
"cfgdir": cfg_output_dir,
},
},
"image_logger": {
"target": "imaginairy.train.ImageLogger",
"params": {
"batch_frequency": 10,
"max_images": 1,
"clamp": True,
"increase_log_steps": False,
"log_first_step": True,
"log_all_val": True,
"concept_label": concept_label,
"log_images_kwargs": {
"use_ema_scope": True,
"inpaint": False,
"plot_progressive_rows": False,
"plot_diffusion_rows": False,
"N": 1,
"unconditional_guidance_scale:": 7.5,
"unconditional_guidance_label": [""],
"ddim_steps": 20,
},
},
},
"learning_rate_logger": {
"target": "imaginairy.train.LearningRateMonitor",
"params": {
"logging_interval": "step",
# "log_momentum": True
},
},
"cuda_callback": {"target": "imaginairy.train.CUDACallback"},
}
default_modelckpt_cfg = {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckpt_output_dir,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
"every_n_train_steps": 50,
"save_top_k": -1,
"monitor": None,
},
}
modelckpt_cfg = OmegaConf.create(default_modelckpt_cfg)
default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
callbacks_cfg = OmegaConf.create(default_callbacks_cfg)
dataset_config = {
"concept_label": concept_label,
"concept_images_dir": concept_images_dir,
"class_label": class_label,
"class_images_dir": class_images_dir,
"image_transforms": [
{
"target": "torchvision.transforms.Resize",
"params": {"size": 512, "interpolation": 3},
},
{"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}},
],
}
data_module_config = {
"batch_size": batch_size,
"num_workers": num_workers,
"num_val_workers": num_val_workers,
"train": {
"target": "imaginairy.training_tools.single_concept.SingleConceptDataset",
"params": dataset_config,
},
}
trainer = Trainer(
benchmark=True,
num_sanity_val_steps=0,
accumulate_grad_batches=accumulate_grad_batches,
strategy=DDPStrategy(),
callbacks=[instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg],
gpus=1,
default_root_dir=".",
)
trainer.logdir = logdir
data = DataModuleFromConfig(**data_module_config)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
def melk(*args, **kwargs):
if trainer.global_rank == 0:
mod_logger.info("Summoning checkpoint.")
ckpt_path = os.path.join(ckpt_output_dir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
signal.signal(signal.SIGUSR1, melk)
try:
try:
trainer.fit(model, data)
except Exception:
melk()
raise
finally:
mod_logger.info(trainer.profiler.summary())

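A hedged usage sketch for the new training entry point, assuming this branch of imaginairy plus its training extras (pytorch-lightning >= 1.6.0) are installed; the labels and directories are hypothetical:

```python
# Sketch only: assumes the refactored imaginairy from this branch is installed,
# with a CUDA device available for training.
from imaginairy.train import train_diffusion_model

train_diffusion_model(
    concept_label="a photo of a bowler hat",   # hypothetical concept label
    concept_images_dir="data/concept",         # hypothetical directory of concept images
    class_label="a photo of a hat",            # hypothetical class label
    class_images_dir="data/class",             # hypothetical directory of class images
    logdir="logs",
    learning_rate=1e-6,
    accumulate_grad_batches=32,
)
```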
@ -170,7 +170,21 @@ def replace_value_at_path(data, path, new_value):
parent = get_path(data, path[:-1]) parent = get_path(data, path[:-1])
last_key = path[-1] last_key = path[-1]
if new_value == NODE_DELETE: if new_value == NODE_DELETE:
del parent[last_key] if isinstance(parent, tuple):
grandparent = get_path(data, path[:-2])
grandparent_key = path[-2]
new_parent = list(parent)
del new_parent[last_key]
grandparent[grandparent_key] = tuple(new_parent)
else:
del parent[last_key]
else: else:
parent[last_key] = new_value if isinstance(parent, tuple):
grandparent = get_path(data, path[:-2])
grandparent_key = path[-2]
new_parent = list(parent)
new_parent[last_key] = new_value
grandparent[grandparent_key] = tuple(new_parent)
else:
parent[last_key] = new_value
return data return data

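The tuple branch above exists because tuples cannot be mutated in place; the parent has to be rebuilt and reattached to the grandparent. A standalone sketch of that pattern, independent of the project's `get_path`/`NODE_DELETE` helpers:

```python
def replace_in_tuple_field(data: dict, key: str, index: int, new_value):
    """Rebuild an immutable tuple value so one element can be replaced."""
    as_list = list(data[key])     # tuples cannot be assigned into, so copy to a list
    as_list[index] = new_value
    data[key] = tuple(as_list)    # reattach the rebuilt tuple to the parent container
    return data


print(replace_in_tuple_field({"size": (512, 512)}, "size", 0, 768))  # {'size': (768, 512)}
```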
@ -43,27 +43,37 @@ _NAMED_RESOLUTIONS = {
"SVD": (1024, 576), # stable video diffusion "SVD": (1024, 576), # stable video diffusion
} }
_NAMED_RESOLUTIONS = {k.upper(): v for k, v in _NAMED_RESOLUTIONS.items()}
def get_named_resolution(resolution: str):
resolution = resolution.upper()
size = _NAMED_RESOLUTIONS.get(resolution) def normalize_image_size(resolution: str | int | tuple[int, int]) -> tuple[int, int]:
match resolution:
if size is None: case (int(), int()):
# is it WIDTHxHEIGHT format? size = resolution
try: case int():
width, height = resolution.split("X") size = resolution, resolution
size = (int(width), int(height)) case str():
except ValueError: resolution = resolution.strip().upper()
pass resolution = resolution.replace(" ", "").replace("X", ",").replace("*", ",")
size = _NAMED_RESOLUTIONS.get(resolution.upper())
if size is None: if size is None:
# is it just a single number? # is it WIDTH,HEIGHT format?
with contextlib.suppress(ValueError): try:
size = (int(resolution), int(resolution)) width, height = resolution.split(",")
size = int(width), int(height)
if size is None: except ValueError:
msg = f"Unknown resolution: {resolution}" pass
if size is None:
# is it just a single number?
with contextlib.suppress(ValueError):
size = (int(resolution), int(resolution))
if size is None:
msg = f"Invalid resolution: '{resolution}'"
raise ValueError(msg)
case _:
msg = f"Invalid resolution: {resolution!r}"
raise ValueError(msg)
if size[0] <= 0 or size[1] <= 0:
msg = f"Invalid resolution: {resolution!r}"
raise ValueError(msg) raise ValueError(msg)
return size return size

@@ -87,7 +87,7 @@ def generate_video(
         device="cpu",
         num_frames=num_frames,
         num_steps=num_steps,
-        weights_url=video_model_config["weights_url"],
+        weights_url=video_model_config["weights_location"],
     )
     torch.manual_seed(seed)

@@ -3,7 +3,7 @@ import safetensors
 from imaginairy.model_manager import (
     get_cached_url_path,
     open_weights,
-    resolve_model_paths,
+    resolve_model_weights_config,
 )
 from imaginairy.weight_management import utils
 from imaginairy.weight_management.pattern_collapse import find_state_dict_key_patterns
@@ -11,15 +11,12 @@ from imaginairy.weight_management.utils import save_model_info


 def save_compvis_patterns():
-    (
-        model_metadata,
-        weights_url,
-        config_path,
-        control_weights_paths,
-    ) = resolve_model_paths(
-        weights_path="openjourney-v1",
+    model_weights_config = resolve_model_weights_config(
+        model_weights="openjourney-v1",
+    )
+    weights_path = get_cached_url_path(
+        model_weights_config.weights_location, category="weights"
     )
-    weights_path = get_cached_url_path(weights_url, category="weights")
     with safetensors.safe_open(weights_path, "pytorch") as f:
         weights_keys = f.keys()
@@ -98,7 +95,7 @@ def save_weight_info(
     model_name, component_name, format_name, weights_url=None, weights_keys=None
 ):
     if weights_keys is None and weights_url is None:
-        msg = "Either weights_keys or weights_url must be provided"
+        msg = "Either weights_keys or weights_location must be provided"
         raise ValueError(msg)
     if weights_keys is None:
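For orientation, a hedged sketch (not part of the diff) of the new weights-resolution flow used above; the openjourney alias is the example already present in this file, and the architecture fallback mirrors tests/test_model_manager.py further down:

    from imaginairy.model_manager import get_cached_url_path, resolve_model_weights_config

    weights_config = resolve_model_weights_config(model_weights="openjourney-v1")
    weights_path = get_cached_url_path(weights_config.weights_location, category="weights")

    # a bare checkpoint path has no aliases and falls back to the supplied architecture
    foo_config = resolve_model_weights_config(
        model_weights="foo.ckpt", default_model_architecture="sd15"
    )
    assert "sd15" in foo_config.architecture.aliases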

@@ -12,7 +12,6 @@ from urllib3 import HTTPConnectionPool
 from imaginairy import ImaginePrompt, api, imagine
 from imaginairy.log_utils import configure_logging, suppress_annoying_logs_and_warnings
-from imaginairy.samplers import SAMPLER_TYPE_OPTIONS
 from imaginairy.utils import (
     fix_torch_group_norm,
     fix_torch_nn_layer_norm,
@@ -26,13 +25,13 @@ if "pytest" in str(sys.argv):
 logger = logging.getLogger(__name__)

-SAMPLERS_FOR_TESTING = SAMPLER_TYPE_OPTIONS
-if get_device() == "mps:0":
-    SAMPLERS_FOR_TESTING = ["plms", "k_euler_a"]
-elif get_device() == "cpu":
-    SAMPLERS_FOR_TESTING = []
+# SOLVERS_FOR_TESTING = SOLVER_TYPE_OPTIONS
+# if get_device() == "mps:0":
+#     SOLVERS_FOR_TESTING = ["plms", "k_euler_a"]
+# elif get_device() == "cpu":
+#     SOLVERS_FOR_TESTING = []

-SAMPLERS_FOR_TESTING = ["ddim", "k_dpmpp_2m"]
+SOLVERS_FOR_TESTING = ["ddim", "dpmpp"]


 @pytest.fixture(scope="session", autouse=True)
@@ -90,8 +89,8 @@ def filename_base_for_orig_outputs(request):
     return filename_base


-@pytest.fixture(params=SAMPLERS_FOR_TESTING)
-def sampler_type(request):
+@pytest.fixture(params=SOLVERS_FOR_TESTING)
+def solver_type(request):
     return request.param
@@ -118,11 +117,10 @@ def default_model_loaded():
     """
     prompt = ImaginePrompt(
         "dogs lying on a hot pink couch",
-        width=64,
-        height=64,
+        size=64,
         steps=2,
         seed=1,
-        sampler_type="ddim",
+        solver_type="ddim",
     )

     next(imagine(prompt))
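A small sketch (not part of the diff) of how the parametrized fixture above is consumed; pytest runs each test once per entry in SOLVERS_FOR_TESTING, and the test name here is hypothetical:

    from imaginairy import ImaginePrompt, imagine

    def test_solver_smoke(solver_type):
        prompt = ImaginePrompt("a red apple", size=64, steps=2, solver_type=solver_type)
        next(imagine(prompt))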

[8 binary image files added, not shown (239–569 KiB each).]
@@ -6,22 +6,22 @@ from imaginairy import LazyLoadingImage
 from imaginairy.api import imagine, imagine_image_files
 from imaginairy.img_processors.control_modes import CONTROL_MODES
 from imaginairy.img_utils import pillow_fit_image_within
-from imaginairy.schema import ControlNetInput, ImaginePrompt
+from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode
 from imaginairy.utils import get_device

 from . import TESTS_FOLDER
 from .utils import assert_image_similar_to_expectation


-def test_imagine(sampler_type, filename_base_for_outputs):
+def test_imagine(solver_type, filename_base_for_outputs):
     prompt_text = "a scenic old-growth forest with diffuse light poking through the canopy. high resolution nature photography"
     prompt = ImaginePrompt(
-        prompt_text, width=512, height=512, steps=20, seed=1, sampler_type=sampler_type
+        prompt_text, size=512, steps=20, seed=1, solver_type=solver_type
     )
     result = next(imagine(prompt))
     threshold_lookup = {"k_dpm_2_a": 26000}
-    threshold = threshold_lookup.get(sampler_type, 10000)
+    threshold = threshold_lookup.get(solver_type, 10000)

     img_path = f"{filename_base_for_outputs}.png"
     assert_image_similar_to_expectation(
@@ -49,25 +49,25 @@ def test_model_versions(filename_base_for_orig_outputs, model_version):
         ImaginePrompt(
             prompt_text,
             seed=1,
-            model=model_version,
+            model_weights=model_version,
         )
     )

     threshold = 35000

-    for i, result in enumerate(imagine(prompts)):
-        img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model}.png"
+    results = list(imagine(prompts))
+    for i, result in enumerate(results):
+        img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model_weights}.png"
         result.img.save(img_path)

-    for i, result in enumerate(imagine(prompts)):
-        img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model}.png"
+    for i, result in enumerate(results):
+        img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model_weights}.png"
         assert_image_similar_to_expectation(
             result.img, img_path=img_path, threshold=threshold
         )


 def test_img2img_beach_to_sunset(
-    sampler_type, filename_base_for_outputs, filename_base_for_orig_outputs
+    solver_type, filename_base_for_outputs, filename_base_for_orig_outputs
 ):
     img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
     prompt = ImaginePrompt(
@@ -77,11 +77,10 @@ def test_img2img_beach_to_sunset(
         prompt_strength=15,
         mask_prompt="(sky|clouds) AND !(buildings|trees)",
         mask_mode="replace",
-        width=512,
-        height=512,
+        size=512,
         steps=40 * 2,
         seed=1,
-        sampler_type=sampler_type,
+        solver_type=solver_type,
     )

     result = next(imagine(prompt))
@@ -91,7 +90,7 @@ def test_img2img_beach_to_sunset(


 def test_img_to_img_from_url_cats(
-    sampler_type,
+    solver_type,
     filename_base_for_outputs,
     mocked_responses,
     filename_base_for_orig_outputs,
@@ -113,11 +112,10 @@ def test_img_to_img_from_url_cats(
         "dogs lying on a hot pink couch",
         init_image=img,
         init_image_strength=0.5,
-        width=512,
-        height=512,
+        size=512,
         steps=50,
         seed=1,
-        sampler_type=sampler_type,
+        solver_type=solver_type,
     )

     result = next(imagine(prompt))
@@ -130,7 +128,7 @@ def test_img_to_img_from_url_cats(

 def test_img2img_low_noise(
     filename_base_for_outputs,
-    sampler_type,
+    solver_type,
 ):
     fruit_path = os.path.join(TESTS_FOLDER, "data", "bowl_of_fruit.jpg")
     img = LazyLoadingImage(filepath=fruit_path)
@@ -144,17 +142,18 @@ def test_img2img_low_noise(
         mask_mode="replace",
         # steps=40,
         seed=1,
-        sampler_type=sampler_type,
+        solver_type=solver_type,
     )

     result = next(imagine(prompt))

     threshold_lookup = {
+        "dpmpp": 26000,
         "k_dpm_2_a": 26000,
         "k_euler_a": 18000,
         "k_dpm_adaptive": 13000,
     }
-    threshold = threshold_lookup.get(sampler_type, 14000)
+    threshold = threshold_lookup.get(solver_type, 14000)

     img_path = f"{filename_base_for_outputs}.png"
     assert_image_similar_to_expectation(
@@ -165,7 +164,7 @@ def test_img2img_low_noise(
 @pytest.mark.parametrize("init_strength", [0, 0.05, 0.2, 1])
 def test_img_to_img_fruit_2_gold(
     filename_base_for_outputs,
-    sampler_type,
+    solver_type,
     init_strength,
     filename_base_for_orig_outputs,
 ):
@@ -183,7 +182,7 @@ def test_img_to_img_fruit_2_gold(
         mask_mode="replace",
         steps=needed_steps,
         seed=1,
-        sampler_type=sampler_type,
+        solver_type=solver_type,
     )

     result = next(imagine(prompt))
@@ -194,7 +193,7 @@ def test_img_to_img_fruit_2_gold(
         "k_dpm_adaptive": 13000,
         "k_dpmpp_2s": 16000,
     }
-    threshold = threshold_lookup.get(sampler_type, 16000)
+    threshold = threshold_lookup.get(solver_type, 16000)

     pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}__orig.jpg")
     img_path = f"{filename_base_for_outputs}.png"
@@ -227,7 +226,7 @@ def test_img_to_img_fruit_2_gold_repeat():
     ]
     for result in imagine(prompts, debug_img_callback=None):
         result.img.save(
-            f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold_{result.prompt.sampler_type}_{get_device()}_run-{run_count:02}.jpg"
+            f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold_{result.prompt.solver_type}_{get_device()}_run-{run_count:02}.jpg"
         )
         run_count += 1
@@ -236,9 +235,8 @@ def test_img_to_img_fruit_2_gold_repeat():
 def test_img_to_file():
     prompt = ImaginePrompt(
         "an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo",
-        width=512 + 64,
-        height=512 - 64,
-        steps=20,
+        size=(512 + 64, 512 - 64),
+        steps=2,
         seed=2,
         upscale=True,
     )
@@ -254,8 +252,7 @@ def test_inpainting_bench(filename_base_for_outputs, filename_base_for_orig_outp
         init_image=img,
         init_image_strength=0.4,
         mask_image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2_mask.png"),
-        width=512,
-        height=512,
+        size=512,
         steps=40,
         seed=1,
     )
@@ -279,9 +276,8 @@ def test_cliptext_inpainting_pearl_doctor(
         init_image=img,
         init_image_strength=0.2,
         mask_prompt="face AND NOT (bandana OR hair OR blue fabric){*5}",
-        mask_mode=ImaginePrompt.MaskMode.KEEP,
-        width=512,
-        height=512,
+        mask_mode=MaskMode.KEEP,
+        size=512,
         steps=40,
         seed=181509347,
     )
@@ -297,8 +293,7 @@ def test_tile_mode(filename_base_for_outputs):
     prompt_text = "gold coins"
     prompt = ImaginePrompt(
         prompt_text,
-        width=400,
-        height=400,
+        size=400,
         steps=15,
         seed=1,
         tile_mode="xy",
@@ -317,7 +312,7 @@ control_modes = list(CONTROL_MODES.keys())
 def test_controlnet(filename_base_for_outputs, control_mode):
     prompt_text = "a photo of a woman sitting on a bench"
     img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png")
-    control_input = ControlNetInput(
+    control_input = ControlInput(
         mode=control_mode,
         image=img,
     )
@@ -327,30 +322,27 @@ def test_controlnet(filename_base_for_outputs, control_mode):
         prompt_text = "a wise old man"
         seed = 1
         mask_image = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2_mask.png")
-        control_input = ControlNetInput(
+        control_input = ControlInput(
             mode=control_mode,
             image=mask_image,
         )

     prompt = ImaginePrompt(
         prompt_text,
-        width=512,
-        height=512,
+        size=512,
         steps=45,
         seed=seed,
         init_image=img,
         init_image_strength=0,
         control_inputs=[control_input],
         fix_faces=True,
-        sampler="ddim",
+        solver_type="ddim",
     )
     prompt.steps = 1
-    prompt.width = 256
-    prompt.height = 256
+    prompt.size = 256
     result = next(imagine(prompt))
     prompt.steps = 15
-    prompt.width = 512
-    prompt.height = 512
+    prompt.size = 512
     result = next(imagine(prompt))

     img_path = f"{filename_base_for_outputs}.png"
@@ -365,8 +357,7 @@ def test_large_image(filename_base_for_outputs):
     prompt_text = "a stormy ocean. oil painting"
     prompt = ImaginePrompt(
         prompt_text,
-        width=1920,
-        height=1080,
+        size="1080p",
         steps=30,
         seed=0,
     )
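The recurring change in this file is that width=/height= collapse into a single size= argument and sampler_type= becomes solver_type=; a brief before/after sketch (not part of the diff):

    from imaginairy import ImaginePrompt

    # before: ImaginePrompt("a stormy ocean", width=1920, height=1080, sampler_type="ddim")
    # after: size accepts an int, a (width, height) tuple, or a named resolution
    prompt = ImaginePrompt("a stormy ocean", size="1080p", solver_type="ddim")
    prompt = ImaginePrompt("a stormy ocean", size=(1920, 1080), solver_type="ddim")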

@@ -72,8 +72,7 @@ def test_edit_demo(monkeypatch):
         ImaginePrompt(
             "",
             steps=1,
-            width=256,
-            height=256,
+            size=256,
             # model="empty",
         )
     ]
@@ -89,7 +88,7 @@ def test_edit_demo(monkeypatch):
             f"{TESTS_FOLDER}/test_output",
         ],
     )
-    assert result.exit_code == 0
+    assert result.exit_code == 0, result.stdout


 def test_upscale(monkeypatch):

@@ -1,6 +1,6 @@
-from imaginairy import config
-from imaginairy.samplers import SAMPLER_TYPE_OPTIONS
-
-
-def test_sampler_options():
-    assert set(config.SAMPLER_TYPE_OPTIONS) == set(SAMPLER_TYPE_OPTIONS)
+# from imaginairy import config
+# from imaginairy.samplers import SOLVER_TYPE_OPTIONS
+#
+#
+# def test_sampler_options():
+#     assert set(config.SOLVER_TYPE_NAMES) == set(SOLVER_TYPE_OPTIONS)

@@ -58,7 +58,7 @@ def test_clip_masking(filename_base_for_outputs):
         upscale=False,
         fix_faces=True,
         seed=42,
-        # sampler_type="plms",
+        # solver_type="plms",
     )

     result = next(imagine(prompt))

@@ -1,24 +1,25 @@
 from imaginairy import config
-from imaginairy.model_manager import resolve_model_paths
+from imaginairy.model_manager import resolve_model_weights_config


 def test_resolved_paths():
     """Test that the resolved model path is correct."""
-    (
-        model_metadata,
-        weights_path,
-        config_path,
-        control_weights_path,
-    ) = resolve_model_paths()
-    assert model_metadata.short_name == config.DEFAULT_MODEL
-    assert model_metadata.config_path == config_path
-    default_config_path = config_path
+    model_weights_config = resolve_model_weights_config(config.DEFAULT_MODEL_WEIGHTS)
+    assert config.DEFAULT_MODEL_WEIGHTS.lower() in model_weights_config.aliases
+    assert (
+        config.DEFAULT_MODEL_ARCHITECTURE in model_weights_config.architecture.aliases
+    )

-    (
-        model_metadata,
-        weights_path,
-        config_path,
-        control_weights_path,
-    ) = resolve_model_paths(weights_path="foo.ckpt")
-    assert weights_path == "foo.ckpt"
-    assert config_path == default_config_path
+    model_weights_config = resolve_model_weights_config(
+        model_weights="foo.ckpt",
+        default_model_architecture="sd15",
+    )
+    print(model_weights_config)
+    assert model_weights_config.aliases == []
+    assert "sd15" in model_weights_config.architecture.aliases
+
+    model_weights_config = resolve_model_weights_config(
+        model_weights="foo.ckpt", default_model_architecture="sd15", for_inpainting=True
+    )
+    assert model_weights_config.aliases == []
+    assert "sd15-inpaint" in model_weights_config.architecture.aliases

@@ -2,7 +2,7 @@ import pytest
 from pydantic import ValidationError

 from imaginairy import LazyLoadingImage
-from imaginairy.schema import ControlNetInput
+from imaginairy.schema import ControlInput
 from tests import TESTS_FOLDER
@@ -12,29 +12,29 @@ def _lazy_img():
 def test_controlnetinput_basic(lazy_img):
-    ControlNetInput(mode="canny", image=lazy_img)
-    ControlNetInput(mode="canny", image_raw=lazy_img)
+    ControlInput(mode="canny", image=lazy_img)
+    ControlInput(mode="canny", image_raw=lazy_img)


 def test_controlnetinput_invalid_mode(lazy_img):
     with pytest.raises(ValueError, match=r".*Invalid controlnet mode.*"):
-        ControlNetInput(mode="pizza", image=lazy_img)
+        ControlInput(mode="pizza", image=lazy_img)


 def test_controlnetinput_both_images(lazy_img):
     with pytest.raises(ValueError, match=r".*cannot specify both.*"):
-        ControlNetInput(mode="canny", image=lazy_img, image_raw=lazy_img)
+        ControlInput(mode="canny", image=lazy_img, image_raw=lazy_img)


 def test_controlnetinput_filepath_input(lazy_img):
     """Test that we accept filepaths here."""
-    c = ControlNetInput(mode="canny", image=f"{TESTS_FOLDER}/data/red.png")
+    c = ControlInput(mode="canny", image=f"{TESTS_FOLDER}/data/red.png")
     c.image.convert("RGB")
-    c = ControlNetInput(mode="canny", image_raw=f"{TESTS_FOLDER}/data/red.png")
+    c = ControlInput(mode="canny", image_raw=f"{TESTS_FOLDER}/data/red.png")
     c.image_raw.convert("RGB")


 def test_controlnetinput_big(lazy_img):
-    ControlNetInput(mode="canny", strength=2)
+    ControlInput(mode="canny", strength=2)
     with pytest.raises(ValidationError, match=r".*float_type.*"):
-        ControlNetInput(mode="canny", strength=2**2048)
+        ControlInput(mode="canny", strength=2**2048)
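A short sketch (not part of the diff) of the renamed ControlInput in use with an ImaginePrompt, mirroring the tests above; the image path is illustrative:

    from imaginairy import ImaginePrompt, LazyLoadingImage
    from imaginairy.schema import ControlInput

    img = LazyLoadingImage(filepath="tests/data/bench2.png")
    prompt = ImaginePrompt(
        "a photo of a woman sitting on a bench",
        size=512,
        control_inputs=[ControlInput(mode="canny", image=img)],
    )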

@@ -2,13 +2,21 @@ import pytest
 from pydantic import ValidationError

 from imaginairy import LazyLoadingImage, config
-from imaginairy.schema import ControlNetInput, ImaginePrompt, WeightedPrompt
+from imaginairy.schema import ControlInput, ImaginePrompt, WeightedPrompt
 from imaginairy.utils.data_distorter import DataDistorter
 from tests import TESTS_FOLDER


+def test_imagine_prompt_default():
+    prompt = ImaginePrompt()
+    assert prompt.prompt == []
+    assert prompt.negative_prompt == [
+        WeightedPrompt(text=config.DEFAULT_NEGATIVE_PROMPT)
+    ]
+
+
 def test_imagine_prompt_has_default_negative():
-    prompt = ImaginePrompt("fruit salad", model="foobar")
+    prompt = ImaginePrompt("fruit salad", model_weights="foobar")
     assert isinstance(prompt.prompt[0], WeightedPrompt)
     assert isinstance(prompt.negative_prompt[0], WeightedPrompt)
@@ -21,10 +29,10 @@ def test_imagine_prompt_custom_negative_prompt():

 def test_imagine_prompt_model_specific_negative_prompt():
-    prompt = ImaginePrompt("fruit salad", model="openjourney-v1")
+    prompt = ImaginePrompt("fruit salad", model_weights="openjourney-v1")
     assert isinstance(prompt.prompt[0], WeightedPrompt)
     assert isinstance(prompt.negative_prompt[0], WeightedPrompt)
-    assert prompt.negative_prompt[0].text == ""
+    assert prompt.negative_prompt[0].text == "poor quality"


 def test_imagine_prompt_weighted_prompts():
@@ -84,7 +92,7 @@ def test_imagine_prompt_control_inputs():
     prompt = ImaginePrompt(
         "fruit",
         control_inputs=[
-            ControlNetInput(mode="depth", image=img),
+            ControlInput(mode="depth", image=img),
         ],
     )
     prompt.control_inputs[0].image.convert("RGB")
@@ -98,7 +106,7 @@ def test_imagine_prompt_control_inputs():
         "fruit",
         init_image=img,
         control_inputs=[
-            ControlNetInput(mode="depth"),
+            ControlInput(mode="depth"),
         ],
     )
     assert prompt.control_inputs[0].image is not None
@@ -107,7 +115,7 @@ def test_imagine_prompt_control_inputs():
     prompt = ImaginePrompt(
         "fruit",
         control_inputs=[
-            ControlNetInput(mode="depth"),
+            ControlInput(mode="depth"),
         ],
     )
     assert prompt.control_inputs[0].image is None
@@ -136,8 +144,8 @@ def test_imagine_prompt_mask_params():

 def test_imagine_prompt_default_model():
-    prompt = ImaginePrompt("fruit", model=None)
-    assert prompt.model == config.DEFAULT_MODEL
+    prompt = ImaginePrompt("fruit", model_weights=None)
+    assert prompt.model_weights == config.DEFAULT_MODEL_WEIGHTS


 def test_imagine_prompt_default_negative():
@@ -152,7 +160,7 @@ def test_imagine_prompt_fix_faces_fidelity():

 def test_imagine_prompt_init_strength_zero():
     lazy_img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png")
     prompt = ImaginePrompt(
-        "fruit", control_inputs=[ControlNetInput(mode="depth", image=lazy_img)]
+        "fruit", control_inputs=[ControlInput(mode="depth", image=lazy_img)]
     )
     assert prompt.init_image_strength == 0.0
@@ -171,12 +179,12 @@ def test_distorted_prompts():
         init_image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"),
         init_image_strength=0.5,
         control_inputs=[
-            ControlNetInput(
+            ControlInput(
                 mode="details",
                 image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"),
                 strength=2,
             ),
-            ControlNetInput(
+            ControlInput(
                 mode="depth",
                 image_raw=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"),
                 strength=3,
@@ -187,13 +195,11 @@ def test_distorted_prompts():
         mask_mode="replace",
         mask_modify_original=False,
         outpaint="all5,up0,down20",
-        model=config.DEFAULT_MODEL,
-        model_config_path=None,
-        sampler_type=config.DEFAULT_SAMPLER,
+        model_weights=config.DEFAULT_MODEL_WEIGHTS,
+        solver_type=config.DEFAULT_SOLVER,
         seed=42,
         steps=10,
-        height=256,
-        width=256,
+        size=256,
         upscale=True,
         fix_faces=True,
         fix_faces_fidelity=0.7,

@@ -1,6 +1,6 @@
 import pytest

-from imaginairy.utils.named_resolutions import get_named_resolution
+from imaginairy.utils.named_resolutions import normalize_image_size

 valid_cases = [
     ("HD", (1280, 720)),
@@ -12,11 +12,25 @@ valid_cases = [
     ("1920x1080", (1920, 1080)),
     ("1280x720", (1280, 720)),
     ("1024x768", (1024, 768)),
+    ("1024,768", (1024, 768)),
+    ("1024*768", (1024, 768)),
+    ("1024, 768", (1024, 768)),
     ("800", (800, 800)),
     ("1024", (1024, 1024)),
+    ("1080p", (1920, 1080)),
+    ("1080P", (1920, 1080)),
+    (512, (512, 512)),
+    ((512, 512), (512, 512)),
+    ("1x1", (1, 1)),
 ]
 invalid_cases = [
+    None,
+    3.14,
+    (3.14, 3.14),
+    "",
+    " ",
     "abc",
+    "-512",
     "1920xABC",
     "1920x1080x1234",
     "x1920",
@@ -30,10 +44,10 @@ invalid_cases = [
 @pytest.mark.parametrize(("named_resolution", "expected"), valid_cases)
 def test_named_resolutions(named_resolution, expected):
-    assert get_named_resolution(named_resolution) == expected
+    assert normalize_image_size(named_resolution) == expected


 @pytest.mark.parametrize("named_resolution", invalid_cases)
 def test_invalid_inputs(named_resolution):
-    with pytest.raises(ValueError, match="Unknown resolution"):
-        get_named_resolution(named_resolution)
+    with pytest.raises(ValueError, match="Invalid resolution"):
+        normalize_image_size(named_resolution)
