feature: large refactor

- add type hints
- size parameter
- ControlNetInput => ControlInput
- simplify ImagineResult (see usage sketch below)
pull/411/head^2
Bryce 5 months ago committed by Bryce Drennan
parent db85f0898a
commit 2bd6cb264b
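A minimal sketch of how the renamed pieces fit together after this refactor. The import paths, alias strings, and output directory are assumptions drawn from this diff rather than a verified release.

from imaginairy import ImaginePrompt, imagine_image_files
from imaginairy.schema import ControlInput, LazyLoadingImage

control = ControlInput(
    mode="canny",                      # any alias registered in CONTROL_CONFIG_SHORTCUTS
    image=LazyLoadingImage(filepath="sketch.png"),
    strength=1.0,
)
prompt = ImaginePrompt(
    "a watercolor landscape",
    size="512x512",                    # replaces the separate height=/width= arguments
    solver_type="ddim",                # renamed from sampler_type=
    control_inputs=[control],
    model_weights="sd15",              # renamed from model=
)
imagine_image_files(prompt, outdir="./outputs")  # a single prompt is wrapped into a list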

@ -787,7 +787,7 @@ Use with `--model SD-2.1` or `--model SD-2.0-v`
**6.1.0**
- feature: use different default steps and image sizes depending on sampler and model selected
- fix: #110 use proper version in image metadata
- refactor: samplers all have their own class that inherits from ImageSampler
- refactor: solvers all have their own class that inherits from ImageSolver
- feature: 🎉🎉🎉 Stable Diffusion 2.0
- `--model SD-2.0` to use (it makes worse images than 1.5 though...)
- Tested on macOS and Linux

@ -11,10 +11,11 @@
- allow selection of output video format
- chain multiple operations together imggen => videogen
- make sure terminal output on windows doesn't suck
- add karras schedule to refiners
- add interface for loading diffusers weights
- add method to show cache size
- add method to clear model cache
- add method to clear cached items not recently used (does diffusers have one?)
- https://huggingface.co/playgroundai/playground-v2-1024px-aesthetic
### Old Todo

@ -1,35 +1,37 @@
import logging
import os
import re
from typing import TYPE_CHECKING, Callable
from imaginairy.schema import ControlNetInput, SafetyMode
if TYPE_CHECKING:
from imaginairy.schema import ImaginePrompt
logger = logging.getLogger(__name__)
# leave undocumented. I'd ask that no one publicize this flag. Just want a
# slight barrier to entry. Please don't use this in any way that's gonna cause
# the media or politicians to freak out about AI...
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.STRICT)
IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", "strict")
if IMAGINAIRY_SAFETY_MODE in {"disabled", "classify"}:
IMAGINAIRY_SAFETY_MODE = SafetyMode.RELAXED
IMAGINAIRY_SAFETY_MODE = "relaxed"
elif IMAGINAIRY_SAFETY_MODE == "filter":
IMAGINAIRY_SAFETY_MODE = SafetyMode.STRICT
IMAGINAIRY_SAFETY_MODE = "strict"
# we put this in the global scope so it can be used in the interactive shell
_most_recent_result = None
def imagine_image_files(
prompts,
outdir,
precision="autocast",
record_step_images=False,
output_file_extension="jpg",
print_caption=False,
make_gif=False,
make_compare_gif=False,
return_filename_type="generated",
videogen=False,
prompts: "list[ImaginePrompt] | ImaginePrompt",
outdir: str,
precision: str = "autocast",
record_step_images: bool = False,
output_file_extension: str = "jpg",
print_caption: bool = False,
make_gif: bool = False,
make_compare_gif: bool = False,
return_filename_type: str = "generated",
videogen: bool = False,
):
from PIL import ImageDraw
@ -46,6 +48,9 @@ def imagine_image_files(
if output_file_extension not in {"jpg", "png"}:
raise ValueError("Must output a png or jpg")
if not isinstance(prompts, list):
prompts = [prompts]
def _record_step(img, description, image_count, step_count, prompt):
steps_path = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}")
os.makedirs(steps_path, exist_ok=True)
@ -74,7 +79,7 @@ def imagine_image_files(
if prompt.init_image:
img_str = f"_img2img-{prompt.init_image_strength}"
basefilename = (
f"{base_count:06}_{prompt.seed}_{prompt.sampler_type.replace('_', '')}{prompt.steps}_"
f"{base_count:06}_{prompt.seed}_{prompt.solver_type.replace('_', '')}{prompt.steps}_"
f"PS{prompt.prompt_strength}{img_str}_{prompt_normalized(prompt.prompt_text)}"
)
@ -139,15 +144,15 @@ def imagine_image_files(
def imagine(
prompts,
precision="autocast",
debug_img_callback=None,
progress_img_callback=None,
progress_img_interval_steps=3,
prompts: "list[ImaginePrompt] | str | ImaginePrompt",
precision: str = "autocast",
debug_img_callback: Callable | None = None,
progress_img_callback: Callable | None = None,
progress_img_interval_steps: int = 3,
progress_img_interval_min_s=0.1,
half_mode=None,
add_caption=False,
unsafe_retry_count=1,
add_caption: bool = False,
unsafe_retry_count: int = 1,
):
import torch.nn
@ -209,7 +214,7 @@ def imagine(
def _generate_single_image_compvis(
prompt,
prompt: "ImaginePrompt",
debug_img_callback=None,
progress_img_callback=None,
progress_img_interval_steps=3,
@ -248,9 +253,9 @@ def _generate_single_image_compvis(
from imaginairy.modules.midas.api import torch_image_to_depth_map
from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint
from imaginairy.safety import create_safety_score
from imaginairy.samplers import SAMPLER_LOOKUP
from imaginairy.samplers import SOLVER_LOOKUP
from imaginairy.samplers.editing import CFGEditingDenoiser
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.schema import ControlInput, ImagineResult, MaskMode
from imaginairy.utils import get_device, randn_seeded
latent_channels = 4
@ -326,8 +331,8 @@ def _generate_single_image_compvis(
prompt.height // downsampling_factor,
prompt.width // downsampling_factor,
]
SamplerCls = SAMPLER_LOOKUP[prompt.sampler_type.lower()]
sampler = SamplerCls(model)
SolverCls = SOLVER_LOOKUP[prompt.solver_type.lower()]
solver = SolverCls(model)
mask_latent = mask_image = mask_image_orig = mask_grayscale = None
t_enc = init_latent = control_image = None
starting_image = None
@ -385,7 +390,7 @@ def _generate_single_image_compvis(
log_img(mask_image, "init mask")
if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
if prompt.mask_mode == MaskMode.REPLACE:
mask_image = ImageOps.invert(mask_image)
mask_image_orig = mask_image
@ -396,7 +401,7 @@ def _generate_single_image_compvis(
if inpaint_method == "controlnet":
result_images["control-inpaint"] = mask_image
control_inputs.append(
ControlNetInput(mode="inpaint", image=mask_image)
ControlInput(mode="inpaint", image=mask_image)
)
seed_everything(prompt.seed)
@ -543,7 +548,7 @@ def _generate_single_image_compvis(
prompt=prompt,
target_height=init_image.height,
target_width=init_image.width,
cutoff=get_model_default_image_size(prompt.model),
cutoff=get_model_default_image_size(prompt.model_architecture),
)
else:
comp_image = _generate_composition_image(
@ -563,7 +568,7 @@ def _generate_single_image_compvis(
model.encode_first_stage(comp_image_t)
)
with lc.timing("sampling"):
samples = sampler.sample(
samples = solver.sample(
num_steps=prompt.steps,
positive_conditioning=positive_conditioning,
neutral_conditioning=neutral_conditioning,
@ -711,8 +716,10 @@ def _generate_composition_image(
composition_prompt = prompt.full_copy(
deep=True,
update={
"width": int(prompt.width * shrink_scale),
"height": int(prompt.height * shrink_scale),
"size": (
int(prompt.width * shrink_scale),
int(prompt.height * shrink_scale),
),
"steps": None,
"upscale": False,
"fix_faces": False,

@ -1,15 +1,16 @@
import logging
from typing import List, Optional
from imaginairy import WeightedPrompt
from imaginairy.config import CONTROLNET_CONFIG_SHORTCUTS
from imaginairy import ImaginePrompt, WeightedPrompt
from imaginairy.config import CONTROL_CONFIG_SHORTCUTS
from imaginairy.model_manager import load_controlnet_adapter
from imaginairy.schema import MaskMode
logger = logging.getLogger(__name__)
def _generate_single_image(
prompt,
prompt: ImaginePrompt,
debug_img_callback=None,
progress_img_callback=None,
progress_img_interval_steps=3,
@ -55,7 +56,7 @@ def _generate_single_image(
)
from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint
from imaginairy.safety import create_safety_score
from imaginairy.samplers import SamplerName
from imaginairy.samplers import SolverName
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import get_device, randn_seeded
@ -76,8 +77,8 @@ def _generate_single_image(
control_modes = [c.mode for c in prompt.control_inputs]
sd = get_diffusion_model_refiners(
weights_location=prompt.model,
config_path=prompt.model_config_path,
weights_location=prompt.model_weights,
model_architecture=prompt.model_architecture,
control_weights_locations=tuple(control_modes),
dtype=dtype,
for_inpainting=for_inpainting and inpaint_method == "finetune",
@ -90,7 +91,7 @@ def _generate_single_image(
mask_image = None
mask_image_orig = None
prompt = prompt.make_concrete_copy()
prompt: ImaginePrompt = prompt.make_concrete_copy()
def latent_logger(latents):
progress_latents.append(latents)
@ -171,7 +172,7 @@ def _generate_single_image(
log_img(mask_image, "init mask")
if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
if prompt.mask_mode == MaskMode.REPLACE:
mask_image = ImageOps.invert(mask_image)
mask_image_orig = mask_image
@ -182,7 +183,7 @@ def _generate_single_image(
# if inpaint_method == "controlnet":
# result_images["control-inpaint"] = mask_image
# control_inputs.append(
# ControlNetInput(mode="inpaint", image=mask_image)
# ControlInput(mode="inpaint", image=mask_image)
# )
seed_everything(prompt.seed)
@ -194,7 +195,6 @@ def _generate_single_image(
controlnets = []
if control_modes:
control_strengths = []
from imaginairy.img_processors.control_modes import CONTROL_MODES
for control_input in control_inputs:
@ -231,10 +231,10 @@ def _generate_single_image(
log_img(control_image_disp, "control_image")
if len(control_image_t.shape) == 3:
raise RuntimeError("Control image must be 4D")
raise ValueError("Control image must be 4D")
if control_image_t.shape[1] != 3:
raise RuntimeError("Control image must have 3 channels")
raise ValueError("Control image must have 3 channels")
if (
control_input.mode != "inpaint"
@ -242,21 +242,20 @@ def _generate_single_image(
or control_image_t.max() > 1
):
msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}"
raise RuntimeError(msg)
raise ValueError(msg)
if control_image_t.max() == control_image_t.min():
msg = f"No control signal found in control image {control_input.mode}."
raise RuntimeError(msg)
control_strengths.append(control_input.strength)
raise ValueError(msg)
control_weights_path = CONTROLNET_CONFIG_SHORTCUTS.get(
control_input.mode, None
).weights_url
control_config = CONTROL_CONFIG_SHORTCUTS.get(control_input.mode, None)
if not control_config:
msg = f"Unknown control mode: {control_input.mode}"
raise ValueError(msg)
controlnet = load_controlnet_adapter(
name=control_input.mode,
control_weights_location=control_weights_path,
control_weights_location=control_config.weights_location,
target_unet=sd.unet,
scale=control_input.strength,
)
@ -268,7 +267,7 @@ def _generate_single_image(
prompt=prompt,
target_height=init_image.height,
target_width=init_image.width,
cutoff=get_model_default_image_size(prompt.model),
cutoff=get_model_default_image_size(prompt.model_architecture),
dtype=dtype,
)
else:
@ -276,7 +275,7 @@ def _generate_single_image(
prompt=prompt,
target_height=prompt.height,
target_width=prompt.width,
cutoff=get_model_default_image_size(prompt.model),
cutoff=get_model_default_image_size(prompt.model_architecture),
dtype=dtype,
)
if comp_image is not None:
@ -296,12 +295,12 @@ def _generate_single_image(
control_image_t.to(device=sd.device, dtype=sd.dtype)
)
controlnet.inject()
if prompt.sampler_type.lower() == SamplerName.K_DPMPP_2M:
if prompt.solver_type.lower() == SolverName.DPMPP:
sd.scheduler = DPMSolver(num_inference_steps=prompt.steps)
elif prompt.sampler_type.lower() == SamplerName.DDIM:
elif prompt.solver_type.lower() == SolverName.DDIM:
sd.scheduler = DDIM(num_inference_steps=prompt.steps)
else:
msg = f"Unknown sampler type: {prompt.sampler_type}"
msg = f"Unknown solver type: {prompt.solver_type}"
raise ValueError(msg)
sd.scheduler.to(device=sd.device, dtype=sd.dtype)
sd.set_num_inference_steps(prompt.steps)
@ -414,18 +413,24 @@ def _generate_single_image(
caption_text = prompt.caption_text.format(prompt=prompt.prompt_text)
add_caption_to_image(gen_img, caption_text)
# todo: do something smarter
result_images.update(
{
"upscaled": upscaled_img,
"modified_original": rebuilt_orig_img,
"mask_binary": mask_image_orig,
"mask_grayscale": mask_grayscale,
}
)
result = ImagineResult(
img=gen_img,
prompt=prompt,
upscaled_img=upscaled_img,
is_nsfw=safety_score.is_nsfw,
safety_score=safety_score,
modified_original=rebuilt_orig_img,
mask_binary=mask_image_orig,
mask_grayscale=mask_grayscale,
result_images=result_images,
timings={},
progress_latents=[],
timings={}, # todo
progress_latents=[], # todo
)
_most_recent_result = result
@ -441,6 +446,9 @@ def _generate_single_image(
def _prompts_to_embeddings(prompts, text_encoder):
import torch
if not prompts:
prompts = [WeightedPrompt(text="")]
total_weight = sum(wp.weight for wp in prompts)
if str(text_encoder.device) == "cpu":
text_encoder = text_encoder.to(dtype=torch.float32)

@ -31,7 +31,7 @@ remove_option(edit_options, "allow_compose_phase")
@click.option(
"--model-weights-path",
"--model",
help=f"Model to use. Should be one of {', '.join(config.MODEL_SHORT_NAMES)}, or a path to custom weights.",
help=f"Model to use. Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.",
show_default=True,
default="SD-1.5",
)
@ -53,15 +53,13 @@ def edit_cmd(
outdir,
output_file_extension,
repeats,
height,
width,
size,
steps,
seed,
upscale,
fix_faces,
fix_faces_fidelity,
sampler_type,
solver,
log_level,
quiet,
show_work,
@ -76,7 +74,7 @@ def edit_cmd(
caption,
precision,
model_weights_path,
model_config_path,
model_architecture,
prompt_library_path,
version,
make_gif,
@ -95,11 +93,11 @@ def edit_cmd(
Same as calling `aimg imagine --model edit --init-image my-dog.jpg --init-image-strength 1` except this command
can batch edit images.
"""
from imaginairy.schema import ControlNetInput
from imaginairy.schema import ControlInput
allow_compose_phase = False
control_inputs = [
ControlNetInput(
ControlInput(
image=None,
image_raw=None,
mode="edit",
@ -116,15 +114,13 @@ def edit_cmd(
outdir,
output_file_extension,
repeats,
height,
width,
size,
steps,
seed,
upscale,
fix_faces,
fix_faces_fidelity,
sampler_type,
solver,
log_level,
quiet,
show_work,
@ -140,7 +136,7 @@ def edit_cmd(
caption,
precision,
model_weights_path,
model_config_path,
model_architecture,
prompt_library_path,
version,
make_gif,

@ -77,15 +77,13 @@ def imagine_cmd(
outdir,
output_file_extension,
repeats,
height,
width,
size,
steps,
seed,
upscale,
fix_faces,
fix_faces_fidelity,
sampler_type,
solver,
log_level,
quiet,
show_work,
@ -101,7 +99,7 @@ def imagine_cmd(
caption,
precision,
model_weights_path,
model_config_path,
model_architecture,
prompt_library_path,
version,
make_gif,
@ -120,7 +118,7 @@ def imagine_cmd(
Can be invoked via either `aimg imagine` or just `imagine`.
"""
from imaginairy.schema import ControlNetInput, LazyLoadingImage
from imaginairy.schema import ControlInput, LazyLoadingImage
# hacky method of getting order of control images (mixing raw and normal images)
control_images = [
@ -128,13 +126,18 @@ def imagine_cmd(
for o, path in ImagineColorsCommand._option_order
if o.name in ("control_image", "control_image_raw")
]
control_strengths = [
strength
for o, strength in ImagineColorsCommand._option_order
if o.name == "control_strength"
]
control_inputs = []
if control_mode:
for i, cm in enumerate(control_mode):
try:
option = control_images[i]
except IndexError:
option = None
option = index_default(control_images, i, None)
control_strength = index_default(control_strengths, i, 1.0)
if option is None:
control_image = None
control_image_raw = None
@ -149,10 +152,10 @@ def imagine_cmd(
if control_image_raw and control_image_raw.startswith("http"):
control_image_raw = LazyLoadingImage(url=control_image_raw)
control_inputs.append(
ControlNetInput(
ControlInput(
image=control_image,
image_raw=control_image_raw,
strength=float(control_strength[i]),
strength=float(control_strength),
mode=cm,
)
)
@ -167,15 +170,13 @@ def imagine_cmd(
outdir,
output_file_extension,
repeats,
height,
width,
size,
steps,
seed,
upscale,
fix_faces,
fix_faces_fidelity,
sampler_type,
solver,
log_level,
quiet,
show_work,
@ -191,7 +192,7 @@ def imagine_cmd(
caption,
precision,
model_weights_path,
model_config_path,
model_architecture,
prompt_library_path,
version,
make_gif,
@ -204,5 +205,12 @@ def imagine_cmd(
)
def index_default(items, index, default):
try:
return items[index]
except IndexError:
return default
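For illustration, a hypothetical use of the index_default helper defined above, mirroring how the command pairs control modes with image and strength lists that may be shorter:

# Hypothetical inputs; missing entries fall back to the supplied default.
control_modes = ["canny", "depth"]
control_images = ["edges.png"]   # only one image supplied for two modes
control_strengths = []           # no explicit strengths supplied

for i, mode in enumerate(control_modes):
    image = index_default(control_images, i, None)
    strength = index_default(control_strengths, i, 1.0)
    print(mode, image, strength)
# canny edges.png 1.0
# depth None 1.0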
if __name__ == "__main__":
imagine_cmd()

@ -45,8 +45,6 @@ aimg.add_command(describe_cmd, name="describe")
aimg.add_command(edit_cmd, name="edit")
aimg.add_command(edit_demo_cmd, name="edit-demo")
aimg.add_command(imagine_cmd, name="imagine")
# aimg.add_command(prep_images_cmd, name="prep-images")
# aimg.add_command(prune_ckpt_cmd, name="prune-ckpt")
aimg.add_command(upscale_cmd, name="upscale")
aimg.add_command(run_server_cmd, name="server")
aimg.add_command(videogen_cmd, name="videogen")
@ -85,14 +83,14 @@ def model_list_cmd():
from imaginairy import config
print(f"{'ALIAS': <10} {'NAME': <18} {'DESCRIPTION'}")
for model_config in config.MODEL_CONFIGS:
for model_config in config.MODEL_WEIGHT_CONFIGS:
print(
f"{model_config.alias: <10} {model_config.short_name: <18} {model_config.description}"
)
print("\nCONTROL MODES:")
print(f"{'ALIAS': <10} {'NAME': <18} {'CONTROL TYPE'}")
for control_mode in config.CONTROLNET_CONFIGS:
for control_mode in config.CONTROL_CONFIGS:
print(
f"{control_mode.alias: <10} {control_mode.short_name: <18} {control_mode.control_type}"
)

@ -34,15 +34,13 @@ def _imagine_cmd(
outdir,
output_file_extension,
repeats,
height,
width,
size,
steps,
seed,
upscale,
fix_faces,
fix_faces_fidelity,
sampler_type,
solver,
log_level,
quiet,
show_work,
@ -58,7 +56,7 @@ def _imagine_cmd(
caption,
precision,
model_weights_path,
model_config_path,
model_architecture,
prompt_library_path,
version=False,
make_gif=False,
@ -96,15 +94,6 @@ def _imagine_cmd(
configure_logging(log_level)
if (height is not None or width is not None) and size is not None:
msg = "You cannot specify both --size and --height/--width. Please choose one."
raise ValueError(msg)
if size is not None:
from imaginairy.utils.named_resolutions import get_named_resolution
width, height = get_named_resolution(size)
init_images = [init_image] if isinstance(init_image, str) else init_image
from imaginairy.utils import glob_expand_paths
@ -171,10 +160,9 @@ def _imagine_cmd(
init_image_strength=init_image_strength,
control_inputs=control_inputs,
seed=seed,
sampler_type=sampler_type,
solver_type=solver,
steps=steps,
height=height,
width=width,
size=size,
mask_image=mask_image,
mask_prompt=mask_prompt,
mask_mode=mask_mode,
@ -185,8 +173,8 @@ def _imagine_cmd(
fix_faces_fidelity=fix_faces_fidelity,
tile_mode=_tile_mode,
allow_compose_phase=allow_compose_phase,
model=model_weights_path,
model_config_path=model_config_path,
model_weights=model_weights_path,
model_architecture=model_architecture,
caption_text=caption_text,
)
from imaginairy.prompt_schedules import (
@ -318,28 +306,12 @@ common_options = [
type=int,
help="How many times to repeat the renders. If you provide two prompts and --repeat=3 then six images will be generated.",
),
click.option(
"-h",
"--height",
default=None,
show_default=True,
type=int,
help="Image height. Should be multiple of 8.",
),
click.option(
"-w",
"--width",
default=None,
show_default=True,
type=int,
help="Image width. Should be multiple of 8.",
),
click.option(
"--size",
default=None,
show_default=True,
type=str,
help="Image size as a string. Can be a named size or WIDTHxHEIGHT format. Should be multiple of 8. Examples: 512x512, 4k, UHD, 8k, ",
help="Image size as a string. Can be a named size, WIDTHxHEIGHT, or single integer. Should be multiple of 8. Examples: 512x512, 4k, UHD, 8k, 512, 1080p",
),
click.option(
"--steps",
@ -363,18 +335,18 @@ common_options = [
help="How faithful to the original should face enhancement be. 1 = best fidelity, 0 = best looking face.",
),
click.option(
"--sampler-type",
"--solver",
"--sampler",
default=config.DEFAULT_SAMPLER,
default=config.DEFAULT_SOLVER,
show_default=True,
type=click.Choice(config.SAMPLER_TYPE_OPTIONS),
help="What sampling strategy to use.",
type=click.Choice(config.SOLVER_TYPE_NAMES, case_sensitive=False),
help="Solver algorithm to generate the image with. (AKA 'Sampler' or 'Scheduler' in other libraries.",
),
click.option(
"--log-level",
default="INFO",
show_default=True,
type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"]),
type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"], case_sensitive=False),
help="What level of logs to show.",
),
click.option(
@ -429,7 +401,7 @@ common_options = [
"--mask-mode",
default="replace",
show_default=True,
type=click.Choice(["keep", "replace"]),
type=click.Choice(["keep", "replace"], case_sensitive=False),
help="Should we replace the masked area or keep it?",
),
click.option(
@ -458,20 +430,20 @@ common_options = [
click.option(
"--precision",
help="Evaluate at this precision.",
type=click.Choice(["full", "autocast"]),
type=click.Choice(["full", "autocast"], case_sensitive=False),
default="autocast",
show_default=True,
),
click.option(
"--model-weights-path",
"--model",
help=f"Model to use. Should be one of {', '.join(config.MODEL_SHORT_NAMES)}, or a path to custom weights.",
help=f"Model to use. Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.",
show_default=True,
default=config.DEFAULT_MODEL,
default=config.DEFAULT_MODEL_WEIGHTS,
),
click.option(
"--model-config-path",
help="Model config file to use. If a model name is specified, the appropriate config will be used.",
"--model-architecture",
help="Model architecture. When specifying custom weights the model architecture must be specified. (sd15, sdxl, etc).",
show_default=True,
default=None,
),

@ -44,9 +44,9 @@ logger = logging.getLogger(__name__)
"--model-weights-path",
"--model",
"model",
help=f"Model to use. Should be one of {', '.join(config.MODEL_SHORT_NAMES)}, or a path to custom weights.",
help=f"Model to use. Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.",
show_default=True,
default=config.DEFAULT_MODEL,
default=config.DEFAULT_MODEL_WEIGHTS,
)
@click.option(
"--person",

@ -4,7 +4,7 @@ from PIL import Image, ImageEnhance, ImageStat
from imaginairy import ImaginePrompt, imagine
from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.schema import ControlNetInput
from imaginairy.schema import ControlInput
logger = logging.getLogger(__name__)
@ -23,7 +23,7 @@ def colorize_img(img, max_width=1024, max_height=1024, caption=None):
caption = caption.replace(" old ", " ")
logger.info(caption)
control_inputs = [
ControlNetInput(mode="colorize", image=img, strength=2),
ControlInput(mode="colorize", image=img, strength=2),
]
prompt_add = ". color photo, sharp-focus, highly detailed, intricate, Canon 5D"
prompt = ImaginePrompt(

@ -1,7 +1,9 @@
from dataclasses import dataclass
from typing import Any, List
DEFAULT_MODEL = "SD-1.5"
DEFAULT_SAMPLER = "ddim"
DEFAULT_MODEL_WEIGHTS = "sd15"
DEFAULT_MODEL_ARCHITECTURE = "sd15"
DEFAULT_SOLVER = "ddim"
DEFAULT_NEGATIVE_PROMPT = (
"Ugly, duplication, duplicates, mutilation, deformed, mutilated, mutation, twisted body, disfigured, bad anatomy, "
@ -12,229 +14,306 @@ DEFAULT_NEGATIVE_PROMPT = (
"grainy, blurred, blurry, writing, calligraphy, signature, text, watermark, bad art,"
)
SPLITMEM_ENABLED = False
midas_url = "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt"
@dataclass
class ModelConfig:
description: str
short_name: str
config_path: str
weights_url: str
default_image_size: int
forced_attn_precision: str = "default"
default_negative_prompt: str = DEFAULT_NEGATIVE_PROMPT
alias: str = None
class ModelArchitecture:
name: str
aliases: List[str]
output_modality: str
defaults: dict[str, Any]
config_path: str | None = None
midas_url = "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt"
@dataclass
class ModelWeightsConfig:
name: str
aliases: List[str]
architecture: ModelArchitecture
defaults: dict[str, Any]
weights_location: str
MODEL_CONFIGS = [
ModelConfig(
description="Stable Diffusion 1.5",
short_name="SD-1.5",
MODEL_ARCHITECTURES = [
ModelArchitecture(
name="Stable Diffusion 1.5",
aliases=["sd15", "sd-15", "sd1.5", "sd-1.5"],
output_modality="image",
defaults={"size": "512"},
config_path="configs/stable-diffusion-v1.yaml",
weights_url="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt",
default_image_size=512,
alias="sd15",
),
ModelConfig(
description="Stable Diffusion 1.5 - Inpainting",
short_name="SD-1.5-inpaint",
ModelArchitecture(
name="Stable Diffusion 1.5 - Inpainting",
aliases=[
"sd15inpaint",
"sd15-inpaint",
"sd-15-inpaint",
"sd1.5inpaint",
"sd1.5-inpaint",
"sd-1.5-inpaint",
],
output_modality="image",
defaults={"size": "512"},
config_path="configs/stable-diffusion-v1-inpaint.yaml",
weights_url="https://huggingface.co/julienacquaviva/inpainting/resolve/2155ff7fe38b55f4c0d99c2f1ab9b561f8311ca7/sd-v1-5-inpainting.ckpt",
default_image_size=512,
alias="sd15in",
),
# ModelConfig(
# description="Instruct Pix2Pix - Photo Editing",
# short_name="instruct-pix2pix",
# config_path="configs/instruct-pix2pix.yaml",
# weights_url="https://huggingface.co/imaginairy/instruct-pix2pix/resolve/ea0009b3d0d4888f410a40bd06d69516d0b5a577/instruct-pix2pix-00-22000-pruned.ckpt",
# default_image_size=512,
# default_negative_prompt="",
# alias="edit",
# ),
ModelConfig(
description="OpenJourney V1",
short_name="openjourney-v1",
config_path="configs/stable-diffusion-v1.yaml",
weights_url="https://huggingface.co/prompthero/openjourney/resolve/7428477dad893424c92f6ea1cc29d45f6d1448c1/mdjrny-v4.safetensors",
default_image_size=512,
default_negative_prompt="",
alias="oj1",
),
ModelConfig(
description="OpenJourney V2",
short_name="openjourney-v2",
config_path="configs/stable-diffusion-v1.yaml",
weights_url="https://huggingface.co/prompthero/openjourney-v2/resolve/47257274a40e93dab7fbc0cd2cfd5f5704cfeb60/openjourney-v2.ckpt",
default_image_size=512,
default_negative_prompt="",
alias="oj2",
),
ModelConfig(
description="OpenJourney V4",
short_name="openjourney-v4",
config_path="configs/stable-diffusion-v1.yaml",
weights_url="https://huggingface.co/prompthero/openjourney/resolve/e291118e93d5423dc88ac1ed93c02362b17d698f/mdjrny-v4.safetensors",
default_image_size=512,
default_negative_prompt="",
alias="oj4",
),
ModelArchitecture(
name="Stable Diffusion XL",
aliases=["sdxl", "sd-xl"],
output_modality="image",
defaults={"size": "512"},
),
ModelArchitecture(
name="Stable Video Diffusion",
aliases=["svd", "stablevideo"],
output_modality="video",
defaults={"size": "1024x576"},
config_path="configs/svd.yaml",
),
ModelArchitecture(
name="Stable Video Diffusion - Image Decoder",
aliases=["svd-image-decoder", "svd-imdec"],
output_modality="video",
defaults={"size": "1024x576"},
config_path="configs/svd_image_decoder.yaml",
),
ModelArchitecture(
name="Stable Video Diffusion - XT",
aliases=["svd-xt", "svd25f", "svd-25f", "stablevideoxt", "svdxt"],
output_modality="video",
defaults={"size": "1024x576"},
config_path="configs/svd_xt.yaml",
),
ModelArchitecture(
name="Stable Video Diffusion - XT - Image Decoder",
aliases=[
"svd-xt-image-decoder",
"svd-xt-imdec",
"svd-25f-imdec",
"svdxt-imdec",
"svdxtimdec",
"svd25fimdec",
"svdxtimdec",
],
output_modality="video",
defaults={"size": "1024x576"},
config_path="configs/svd_xt_image_decoder.yaml",
),
]
MODEL_ARCHITECTURE_LOOKUP = {}
for m in MODEL_ARCHITECTURES:
for a in m.aliases:
MODEL_ARCHITECTURE_LOOKUP[a] = m
video_models = [
{
"short_name": "svd",
"description": "Stable Video Diffusion",
"default_frames": 14,
"default_steps": 25,
"config_path": "configs/svd.yaml",
"weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd.fp16.safetensors",
},
{
"short_name": "svd_image_decoder",
"description": "Stable Video Diffusion - Image Decoder",
"default_frames": 14,
"default_steps": 25,
"config_path": "configs/svd_image_decoder.yaml",
"weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_image_decoder.fp16.safetensors",
},
{
"short_name": "svd_xt",
"description": "Stable Video Diffusion - XT",
"default_frames": 25,
"default_steps": 30,
"config_path": "configs/svd_xt.yaml",
"weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt.fp16.safetensors",
},
{
"short_name": "svd_xt_image_decoder",
"description": "Stable Video Diffusion - XT - Image Decoder",
"default_frames": 25,
"default_steps": 30,
"config_path": "configs/svd_xt_image_decoder.yaml",
"weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt_image_decoder.fp16.safetensors",
},
MODEL_WEIGHT_CONFIGS = [
ModelWeightsConfig(
name="Stable Diffusion 1.5",
aliases=MODEL_ARCHITECTURE_LOOKUP["sd15"].aliases,
architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
weights_location="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt",
),
ModelWeightsConfig(
name="Stable Diffusion 1.5 - Inpainting",
aliases=MODEL_ARCHITECTURE_LOOKUP["sd15inpaint"].aliases,
architecture=MODEL_ARCHITECTURE_LOOKUP["sd15inpaint"],
defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
weights_location="https://huggingface.co/julienacquaviva/inpainting/resolve/2155ff7fe38b55f4c0d99c2f1ab9b561f8311ca7/sd-v1-5-inpainting.ckpt",
),
ModelWeightsConfig(
name="OpenJourney V1",
aliases=["openjourney-v1", "oj1", "ojv1", "openjourney1"],
architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
defaults={"negative_prompt": "poor quality"},
weights_location="https://huggingface.co/prompthero/openjourney/resolve/7428477dad893424c92f6ea1cc29d45f6d1448c1/mdjrny-v4.safetensors",
),
ModelWeightsConfig(
name="OpenJourney V2",
aliases=["openjourney-v2", "oj2", "ojv2", "openjourney2"],
architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
weights_location="https://huggingface.co/prompthero/openjourney-v2/resolve/47257274a40e93dab7fbc0cd2cfd5f5704cfeb60/openjourney-v2.ckpt",
defaults={"negative_prompt": "poor quality"},
),
ModelWeightsConfig(
name="OpenJourney V4",
aliases=["openjourney-v4", "oj4", "ojv4", "openjourney4", "openjourney", "oj"],
architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
weights_location="https://huggingface.co/prompthero/openjourney/resolve/e291118e93d5423dc88ac1ed93c02362b17d698f/mdjrny-v4.safetensors",
defaults={"negative_prompt": "poor quality"},
),
# Video Weights
ModelWeightsConfig(
name="Stable Video Diffusion",
aliases=MODEL_ARCHITECTURE_LOOKUP["svd"].aliases,
architecture=MODEL_ARCHITECTURE_LOOKUP["svd"],
weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd.fp16.safetensors",
defaults={"frames": 14, "steps": 25},
),
ModelWeightsConfig(
name="Stable Video Diffusion - Image Decoder",
aliases=MODEL_ARCHITECTURE_LOOKUP["svd-image-decoder"].aliases,
architecture=MODEL_ARCHITECTURE_LOOKUP["svd-image-decoder"],
weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_image_decoder.fp16.safetensors",
defaults={"frames": 14, "steps": 25},
),
ModelWeightsConfig(
name="Stable Video Diffusion - XT",
aliases=MODEL_ARCHITECTURE_LOOKUP["svdxt"].aliases,
architecture=MODEL_ARCHITECTURE_LOOKUP["svdxt"],
weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt.fp16.safetensors",
defaults={"frames": 25, "steps": 30},
),
ModelWeightsConfig(
name="Stable Video Diffusion - XT - Image Decoder",
aliases=MODEL_ARCHITECTURE_LOOKUP["svd-xt-image-decoder"].aliases,
architecture=MODEL_ARCHITECTURE_LOOKUP["svd-xt-image-decoder"],
weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt_image_decoder.fp16.safetensors",
defaults={"frames": 25, "steps": 30},
),
]
video_models = {m["short_name"]: m for m in video_models}
MODEL_CONFIG_SHORTCUTS = {m.short_name: m for m in MODEL_CONFIGS}
for m in MODEL_CONFIGS:
if m.alias:
MODEL_CONFIG_SHORTCUTS[m.alias] = m
MODEL_WEIGHT_CONFIG_LOOKUP = {}
for m in MODEL_WEIGHT_CONFIGS:
for a in m.aliases:
MODEL_WEIGHT_CONFIG_LOOKUP[a] = m
MODEL_CONFIG_SHORTCUTS["openjourney"] = MODEL_CONFIG_SHORTCUTS["openjourney-v2"]
MODEL_CONFIG_SHORTCUTS["oj"] = MODEL_CONFIG_SHORTCUTS["openjourney-v2"]
MODEL_SHORT_NAMES = sorted(MODEL_CONFIG_SHORTCUTS.keys())
IMAGE_WEIGHTS_SHORT_NAMES = [
k
for k, mw in MODEL_WEIGHT_CONFIG_LOOKUP.items()
if mw.architecture.output_modality == "image"
]
IMAGE_WEIGHTS_SHORT_NAMES.sort()
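A short sketch of resolving weights by alias with the new lookup tables; the alias, names, and default values in the comments come from the ModelWeightsConfig entries above, and the import path is an assumption:

from imaginairy import config

weights = config.MODEL_WEIGHT_CONFIG_LOOKUP["oj4"]  # an OpenJourney V4 alias
print(weights.name)                                 # "OpenJourney V4"
print(weights.architecture.name)                    # "Stable Diffusion 1.5"
print(weights.architecture.defaults["size"])        # "512"
print(weights.defaults["negative_prompt"])          # "poor quality"
print(config.IMAGE_WEIGHTS_SHORT_NAMES[:3])         # first few image-weight aliases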
@dataclass
class ControlNetConfig:
short_name: str
class ControlConfig:
name: str
aliases: List[str]
control_type: str
config_path: str
weights_url: str
alias: str = None
weights_location: str
CONTROLNET_CONFIGS = [
ControlNetConfig(
short_name="canny15",
CONTROL_CONFIGS = [
ControlConfig(
name="Canny Edge Control",
aliases=["canny", "canny15"],
control_type="canny",
config_path="configs/control-net-v15.yaml",
weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/115a470d547982438f70198e353a921996e2e819/diffusion_pytorch_model.fp16.safetensors",
alias="canny",
weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/115a470d547982438f70198e353a921996e2e819/diffusion_pytorch_model.fp16.safetensors",
),
ControlNetConfig(
short_name="depth15",
ControlConfig(
name="Depth Control",
aliases=["depth", "depth15"],
control_type="depth",
config_path="configs/control-net-v15.yaml",
weights_url="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/539f99181d33db39cf1af2e517cd8056785f0a87/diffusion_pytorch_model.fp16.safetensors",
alias="depth",
weights_location="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/539f99181d33db39cf1af2e517cd8056785f0a87/diffusion_pytorch_model.fp16.safetensors",
),
ControlNetConfig(
short_name="normal15",
ControlConfig(
name="Normal Map Control",
aliases=["normal", "normal15"],
control_type="normal",
config_path="configs/control-net-v15.yaml",
weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/cb7296e6587a219068e9d65864e38729cd862aa8/diffusion_pytorch_model.fp16.safetensors",
alias="normal",
weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/cb7296e6587a219068e9d65864e38729cd862aa8/diffusion_pytorch_model.fp16.safetensors",
),
ControlNetConfig(
short_name="hed15",
ControlConfig(
name="Soft Edge Control (HED)",
aliases=["hed", "hed15"],
control_type="hed",
config_path="configs/control-net-v15.yaml",
weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/b5bcad0c48e9b12f091968cf5eadbb89402d6bc9/diffusion_pytorch_model.fp16.safetensors",
alias="hed",
weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/b5bcad0c48e9b12f091968cf5eadbb89402d6bc9/diffusion_pytorch_model.fp16.safetensors",
),
ControlNetConfig(
short_name="openpose15",
ControlConfig(
name="Pose Control",
control_type="openpose",
config_path="configs/control-net-v15.yaml",
weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/9ae9f970358db89e211b87c915f9535c6686d5ba/diffusion_pytorch_model.fp16.safetensors",
alias="openpose",
weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/9ae9f970358db89e211b87c915f9535c6686d5ba/diffusion_pytorch_model.fp16.safetensors",
aliases=["openpose", "pose", "pose15", "openpose15"],
),
ControlNetConfig(
short_name="shuffle15",
ControlConfig(
name="Shuffle Control",
control_type="shuffle",
config_path="configs/control-net-v15-pool.yaml",
weights_url="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/8cf275970f984acf5cc0fdfa537db8be098936a3/diffusion_pytorch_model.fp16.safetensors",
alias="shuffle",
weights_location="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/8cf275970f984acf5cc0fdfa537db8be098936a3/diffusion_pytorch_model.fp16.safetensors",
aliases=["shuffle", "shuffle15"],
),
# "instruct pix2pix"
ControlNetConfig(
short_name="edit15",
ControlConfig(
name="Edit Prompt Control",
aliases=["edit", "edit15"],
control_type="edit",
config_path="configs/control-net-v15.yaml",
weights_url="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/1fed6ebb905c61929a60514830eb05b039969d6d/diffusion_pytorch_model.fp16.safetensors",
alias="edit",
weights_location="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/1fed6ebb905c61929a60514830eb05b039969d6d/diffusion_pytorch_model.fp16.safetensors",
),
ControlNetConfig(
short_name="inpaint15",
ControlConfig(
name="Inpaint Control",
aliases=["inpaint", "inpaint15"],
control_type="inpaint",
config_path="configs/control-net-v15.yaml",
weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/c96e03a807e64135568ba8aecb66b3a306ec73bd/diffusion_pytorch_model.fp16.safetensors",
alias="inpaint",
weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/c96e03a807e64135568ba8aecb66b3a306ec73bd/diffusion_pytorch_model.fp16.safetensors",
),
ControlNetConfig(
short_name="details15",
ControlConfig(
name="Details Control (Upscale Tile)",
aliases=["details", "details15"],
control_type="details",
config_path="configs/control-net-v15.yaml",
weights_url="https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/3f877705c37010b7221c3d10743307d6b5b6efac/diffusion_pytorch_model.bin",
alias="details",
weights_location="https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/3f877705c37010b7221c3d10743307d6b5b6efac/diffusion_pytorch_model.bin",
),
ControlNetConfig(
short_name="colorize15",
ControlConfig(
name="Brightness Control (Colorize)",
aliases=["colorize", "colorize15"],
control_type="colorize",
config_path="configs/control-net-v15.yaml",
weights_url="https://huggingface.co/ioclab/control_v1p_sd15_brightness/resolve/8509361eb1ba89c03839040ed8c75e5f11bbd9c5/diffusion_pytorch_model.safetensors",
alias="colorize",
weights_location="https://huggingface.co/ioclab/control_v1p_sd15_brightness/resolve/8509361eb1ba89c03839040ed8c75e5f11bbd9c5/diffusion_pytorch_model.safetensors",
),
]
CONTROLNET_CONFIG_SHORTCUTS = {}
for m in CONTROLNET_CONFIGS:
if m.alias:
CONTROLNET_CONFIG_SHORTCUTS[m.alias] = m
for m in CONTROLNET_CONFIGS:
CONTROLNET_CONFIG_SHORTCUTS[m.short_name] = m
SAMPLER_TYPE_OPTIONS = [
# "plms",
"ddim",
"k_dpmpp_2m"
# "k_dpm_fast",
# "k_dpm_adaptive",
# "k_lms",
# "k_dpm_2",
# "k_dpm_2_a",
# "k_dpmpp_2m",
# "k_dpmpp_2s_a",
# "k_euler",
# "k_euler_a",
# "k_heun",
CONTROL_CONFIG_SHORTCUTS: dict[str, ControlConfig] = {}
for m in CONTROL_CONFIGS:
for a in m.aliases:
CONTROL_CONFIG_SHORTCUTS[a] = m
@dataclass
class SolverConfig:
name: str
short_name: str
aliases: List[str]
papers: List[str]
implementations: List[str]
SOLVER_CONFIGS = [
SolverConfig(
name="DDIM",
short_name="DDIM",
aliases=["ddim"],
papers=["https://arxiv.org/abs/2010.02502"],
implementations=[
"https://github.com/ermongroup/ddim",
"https://github.com/Stability-AI/stablediffusion/blob/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/models/diffusion/ddim.py#L10",
"https://github.com/huggingface/diffusers/blob/76c645d3a641c879384afcb43496f0b7db8cc5cb/src/diffusers/schedulers/scheduling_ddim.py#L131",
],
),
SolverConfig(
name="DPM-Solver++",
short_name="DPMPP",
aliases=["dpmpp", "dpm++", "dpmsolver"],
papers=["https://arxiv.org/abs/2211.01095"],
implementations=[
"https://github.com/LuChengTHU/dpm-solver/blob/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/dpm_solver_pytorch.py#L337",
"https://github.com/apple/ml-stable-diffusion/blob/7449ce46a4b23c94413b714704202e4ea4c55080/swift/StableDiffusion/pipeline/DPMSolverMultistepScheduler.swift#L27",
"https://github.com/crowsonkb/k-diffusion/blob/045515774882014cc14c1ba2668ab5bad9cbf7c0/k_diffusion/sampling.py#L509",
],
),
]
SOLVER_TYPE_NAMES = [s.aliases[0] for s in SOLVER_CONFIGS]
SOLVER_LOOKUP = {}
for s in SOLVER_CONFIGS:
for a in s.aliases:
SOLVER_LOOKUP[a.lower()] = s
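And a sketch of the new solver registry; the values reflect the two SolverConfig entries above, while how downstream code consumes SOLVER_LOOKUP is not shown in this diff and is assumed:

from imaginairy import config

solver = config.SOLVER_LOOKUP["dpm++"]  # any alias from SolverConfig.aliases, lowercased
print(solver.short_name)                # "DPMPP"
print(config.SOLVER_TYPE_NAMES)         # ["ddim", "dpmpp"]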

@ -26,7 +26,7 @@ class StableStudioStyle(BaseModel):
image: Optional[HttpUrl] = None
class StableStudioSampler(BaseModel):
class StableStudioSolver(BaseModel):
id: str
name: Optional[str] = None
@ -55,7 +55,7 @@ class StableStudioInput(BaseModel, extra=Extra.forbid):
style: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
sampler: Optional[StableStudioSampler] = None
solver: Optional[StableStudioSolver] = None
cfg_scale: Optional[float] = Field(None, alias="cfgScale")
steps: Optional[int] = None
seed: Optional[int] = None
@ -88,14 +88,14 @@ class StableStudioInput(BaseModel, extra=Extra.forbid):
mask_image = self.mask_image.blob if self.mask_image else None
sampler_type = self.sampler.id if self.sampler else None
solver_type = self.solver.id if self.solver else None
return ImaginePrompt(
prompt=positive_prompt,
prompt_strength=self.cfg_scale,
negative_prompt=negative_prompt,
model=self.model,
sampler_type=sampler_type,
solver_type=solver_type,
seed=self.seed,
steps=self.steps,
height=self.height,

@ -8,7 +8,7 @@ from imaginairy.http_app.stablestudio.models import (
StableStudioBatchResponse,
StableStudioImage,
StableStudioModel,
StableStudioSampler,
StableStudioSolver,
)
from imaginairy.http_app.utils import generate_image_b64
@ -37,11 +37,14 @@ async def generate(studio_request: StableStudioBatchRequest):
@router.get("/samplers")
async def list_samplers():
from imaginairy.config import SAMPLER_TYPE_OPTIONS
from imaginairy.config import SOLVER_CONFIGS
sampler_objs = []
for sampler_type in SAMPLER_TYPE_OPTIONS:
sampler_obj = StableStudioSampler(id=sampler_type, name=sampler_type)
for solver_config in SOLVER_CONFIGS:
sampler_obj = StableStudioSolver(
id=solver_config.aliases[0], name=solver_config.aliases[0]
)
sampler_objs.append(sampler_obj)
return sampler_objs
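Roughly, the endpoint above now reports solver aliases rather than the old SAMPLER_TYPE_OPTIONS. A sketch of the serialized payload, assuming default FastAPI/pydantic serialization of StableStudioSolver:

# Approximate GET /samplers response after this change (ids from SOLVER_CONFIGS):
expected = [
    {"id": "ddim", "name": "ddim"},
    {"id": "dpmpp", "name": "dpmpp"},
]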
@ -49,10 +52,10 @@ async def list_samplers():
@router.get("/models")
async def list_models():
from imaginairy.config import MODEL_CONFIGS
from imaginairy.config import MODEL_WEIGHT_CONFIGS
model_objs = []
for model_config in MODEL_CONFIGS:
for model_config in MODEL_WEIGHT_CONFIGS:
if "inpaint" in model_config.description.lower():
continue
model_obj = StableStudioModel(

@ -17,7 +17,7 @@ from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter, SD1UNe
from safetensors.torch import load_file
from imaginairy import config as iconfig
from imaginairy.config import MODEL_SHORT_NAMES
from imaginairy.config import IMAGE_WEIGHTS_SHORT_NAMES, ModelArchitecture
from imaginairy.modules import attention
from imaginairy.paths import PKG_ROOT
from imaginairy.utils import get_device, instantiate_from_config
@ -66,7 +66,7 @@ def load_state_dict(weights_location, half_mode=False, device=None):
except FileNotFoundError as e:
if e.errno == 2:
logger.error(
f'Error: "{ckpt_path}" not a valid path to model weights.\nPreconfigured models you can use: {MODEL_SHORT_NAMES}.'
f'Error: "{ckpt_path}" not a valid path to model weights.\nPreconfigured models you can use: {IMAGE_WEIGHTS_SHORT_NAMES}.'
)
sys.exit(1)
raise
@ -149,7 +149,7 @@ def add_controlnet(base_state_dict, controlnet_state_dict):
def get_diffusion_model(
weights_location=iconfig.DEFAULT_MODEL,
weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
config_path="configs/stable-diffusion-v1.yaml",
control_weights_locations=None,
half_mode=None,
@ -174,7 +174,7 @@ def get_diffusion_model(
f"Failed to load inpainting model. Attempting to fall-back to standard model. {e!s}"
)
return _get_diffusion_model(
iconfig.DEFAULT_MODEL,
iconfig.DEFAULT_MODEL_WEIGHTS,
config_path,
half_mode,
for_inpainting=False,
@ -184,8 +184,8 @@ def get_diffusion_model(
def _get_diffusion_model(
weights_location=iconfig.DEFAULT_MODEL,
config_path="configs/stable-diffusion-v1.yaml",
weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
model_architecture="configs/stable-diffusion-v1.yaml",
half_mode=None,
for_inpainting=False,
control_weights_locations=None,
@ -197,24 +197,20 @@ def _get_diffusion_model(
"""
global MOST_RECENTLY_LOADED_MODEL
(
model_config,
weights_location,
config_path,
control_weights_locations,
) = resolve_model_paths(
weights_path=weights_location,
config_path=config_path,
control_weights_paths=control_weights_locations,
model_weights_config = resolve_model_weights_config(
model_weights=weights_location,
default_model_architecture=model_architecture,
for_inpainting=for_inpainting,
)
# some models need the attention calculated in float32
if model_config is not None:
attention.ATTENTION_PRECISION_OVERRIDE = model_config.forced_attn_precision
if model_weights_config is not None:
attention.ATTENTION_PRECISION_OVERRIDE = (
model_weights_config.forced_attn_precision
)
else:
attention.ATTENTION_PRECISION_OVERRIDE = "default"
diffusion_model = _load_diffusion_model(
config_path=config_path,
config_path=model_weights_config.architecture.config_path,
weights_location=weights_location,
half_mode=half_mode,
)
@ -229,8 +225,8 @@ def _get_diffusion_model(
def get_diffusion_model_refiners(
weights_location=iconfig.DEFAULT_MODEL,
config_path="configs/stable-diffusion-v1.yaml",
weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
model_architecture=None,
control_weights_locations=None,
dtype=None,
for_inpainting=False,
@ -243,8 +239,8 @@ def get_diffusion_model_refiners(
try:
return _get_diffusion_model_refiners(
weights_location,
config_path,
for_inpainting,
model_architecture=model_architecture,
for_inpainting=for_inpainting,
dtype=dtype,
control_weights_locations=control_weights_locations,
)
@ -254,8 +250,8 @@ def get_diffusion_model_refiners(
f"Failed to load inpainting model. Attempting to fall-back to standard model. {e!s}"
)
return _get_diffusion_model_refiners(
iconfig.DEFAULT_MODEL,
config_path,
iconfig.DEFAULT_MODEL_WEIGHTS,
model_architecture=model_architecture,
dtype=dtype,
for_inpainting=False,
control_weights_locations=control_weights_locations,
@ -264,8 +260,8 @@ def get_diffusion_model_refiners(
def _get_diffusion_model_refiners(
weights_location=iconfig.DEFAULT_MODEL,
config_path="configs/stable-diffusion-v1.yaml",
weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
model_architecture=None,
for_inpainting=False,
control_weights_locations=None,
device=None,
@ -279,7 +275,7 @@ def _get_diffusion_model_refiners(
sd = _get_diffusion_model_refiners_only(
weights_location=weights_location,
config_path=config_path,
model_architecture=model_architecture,
for_inpainting=for_inpainting,
device=device,
dtype=dtype,
@ -290,10 +286,9 @@ def _get_diffusion_model_refiners(
@lru_cache(maxsize=1)
def _get_diffusion_model_refiners_only(
weights_location=iconfig.DEFAULT_MODEL,
config_path="configs/stable-diffusion-v1.yaml",
weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
model_architecture=None,
for_inpainting=False,
control_weights_locations=None,
device=None,
dtype=torch.float16,
):
@ -312,28 +307,17 @@ def _get_diffusion_model_refiners_only(
device = device or get_device()
(
model_config,
weights_location,
config_path,
control_weights_locations,
) = resolve_model_paths(
weights_path=weights_location,
config_path=config_path,
control_weights_paths=control_weights_locations,
model_weights_config = resolve_model_weights_config(
model_weights=weights_location,
default_model_architecture=model_architecture,
for_inpainting=for_inpainting,
)
# some models need the attention calculated in float32
if model_config is not None:
attention.ATTENTION_PRECISION_OVERRIDE = model_config.forced_attn_precision
else:
attention.ATTENTION_PRECISION_OVERRIDE = "default"
(
vae_weights,
unet_weights,
text_encoder_weights,
) = load_stable_diffusion_compvis_weights(weights_location)
) = load_stable_diffusion_compvis_weights(model_weights_config.weights_location)
if for_inpainting:
unet = SD1UNet(in_channels=9)
@ -422,58 +406,69 @@ def load_controlnet(control_weights_location, half_mode):
return controlnet
def resolve_model_paths(
weights_path=iconfig.DEFAULT_MODEL,
config_path=None,
control_weights_paths=None,
for_inpainting=False,
):
def resolve_model_weights_config(
model_weights: str,
default_model_architecture: str | None = None,
for_inpainting: bool = False,
) -> iconfig.ModelWeightsConfig:
"""Resolve weight and config path if they happen to be shortcuts."""
model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get(weights_path, None)
model_metadata_c = iconfig.MODEL_CONFIG_SHORTCUTS.get(config_path, None)
control_weights_paths = control_weights_paths or []
control_net_metadatas = [
iconfig.CONTROLNET_CONFIG_SHORTCUTS.get(control_weights_path, None)
for control_weights_path in control_weights_paths
]
if not control_net_metadatas and for_inpainting:
model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get(
f"{weights_path}-inpaint", model_metadata_w
if for_inpainting:
model_weights_config = iconfig.MODEL_WEIGHT_CONFIG_LOOKUP.get(
f"{model_weights.lower()}-inpaint", None
)
if model_weights_config:
return model_weights_config
model_weights_config = iconfig.MODEL_WEIGHT_CONFIG_LOOKUP.get(
model_weights.lower(), None
)
if model_weights_config:
return model_weights_config
if not default_model_architecture:
msg = "You must specify the model architecture when loading custom weights."
raise ValueError(msg)
default_model_architecture = default_model_architecture.lower()
model_architecture_config = None
if for_inpainting:
model_architecture_config = iconfig.MODEL_ARCHITECTURE_LOOKUP.get(
f"{default_model_architecture}-inpaint", None
)
model_metadata_c = iconfig.MODEL_CONFIG_SHORTCUTS.get(
f"{config_path}-inpaint", model_metadata_c
if not model_architecture_config:
model_architecture_config = iconfig.MODEL_ARCHITECTURE_LOOKUP.get(
default_model_architecture, None
)
if model_architecture_config is None:
msg = f"Invalid model architecture: {default_model_architecture}"
raise ValueError(msg)
model_weights_config = iconfig.ModelWeightsConfig(
name="Custom Loaded",
aliases=[],
architecture=model_architecture_config,
weights_location=model_weights,
defaults={},
)
return model_weights_config
def get_model_default_image_size(model_architecture: str | ModelArchitecture):
if isinstance(model_architecture, str):
model_architecture = iconfig.MODEL_WEIGHT_CONFIG_LOOKUP.get(
model_architecture, None
)
default_size = None
if model_architecture:
default_size = model_architecture.defaults.get("size")
if model_metadata_w:
if config_path is None:
config_path = model_metadata_w.config_path
weights_path = model_metadata_w.weights_url
if model_metadata_c:
config_path = model_metadata_c.config_path
if config_path is None:
config_path = iconfig.MODEL_CONFIG_SHORTCUTS[iconfig.DEFAULT_MODEL].config_path
if control_net_metadatas:
if "stable-diffusion-v1" not in config_path:
msg = "Control net is only supported for stable diffusion v1. Please use a different model."
raise ValueError(msg)
control_weights_paths = [cnm.weights_url for cnm in control_net_metadatas]
config_path = control_net_metadatas[0].config_path
model_metadata = model_metadata_w or model_metadata_c
logger.debug(f"Loading model weights from: {weights_path}")
logger.debug(f"Loading model config from: {config_path}")
return model_metadata, weights_path, config_path, control_weights_paths
def get_model_default_image_size(weights_location):
model_config = iconfig.MODEL_CONFIG_SHORTCUTS.get(weights_location, None)
if model_config:
return model_config.default_image_size
return 512
if default_size is None:
default_size = 512
return default_size
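A sketch of the new resolution behavior; the checkpoint path is a placeholder and the import path is assumed from this diff:

from imaginairy.model_manager import resolve_model_weights_config

# Known alias: returns the preconfigured ModelWeightsConfig.
oj = resolve_model_weights_config("openjourney-v2")

# Custom checkpoint: an architecture alias is now required, otherwise ValueError.
custom = resolve_model_weights_config(
    "/weights/my-finetune.ckpt",
    default_model_architecture="sd15",
)
print(custom.architecture.config_path)  # "configs/stable-diffusion-v1.yaml"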
def get_current_diffusion_model():

@ -1,21 +1,20 @@
from imaginairy.samplers import kdiff
from imaginairy.samplers.base import SamplerName # noqa
from imaginairy.samplers.ddim import DDIMSampler
from imaginairy.samplers.base import SolverName # noqa
from imaginairy.samplers.ddim import DDIMSolver
SAMPLERS = [
# PLMSSampler,
DDIMSampler,
SOLVERS = [
# PLMSSolver,
DDIMSolver,
# kdiff.DPMFastSampler,
# kdiff.DPMAdaptiveSampler,
# kdiff.LMSSampler,
# kdiff.DPM2Sampler,
# kdiff.DPM2AncestralSampler,
kdiff.DPMPP2MSampler,
# kdiff.DPMPP2MSampler,
# kdiff.DPMPP2SAncestralSampler,
# kdiff.EulerSampler,
# kdiff.EulerAncestralSampler,
# kdiff.HeunSampler,
]
SAMPLER_LOOKUP = {sampler.short_name: sampler for sampler in SAMPLERS}
SAMPLER_TYPE_OPTIONS = [sampler.short_name for sampler in SAMPLERS]
SOLVER_LOOKUP = {s.short_name: s for s in SOLVERS}
SOLVER_TYPE_OPTIONS = [s.short_name for s in SOLVERS]
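For reference, the legacy compvis path still resolves solver classes through this module-level lookup (as in _generate_single_image_compvis earlier in this diff); after this change only DDIM remains enabled here:

from imaginairy.samplers import SOLVER_LOOKUP, SOLVER_TYPE_OPTIONS

print(SOLVER_TYPE_OPTIONS)         # ["ddim"], since DPMPP2MSampler is commented out above
SolverCls = SOLVER_LOOKUP["ddim"]  # -> DDIMSolver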

@ -16,9 +16,10 @@ from imaginairy.utils import get_device
logger = logging.getLogger(__name__)
class SamplerName:
class SolverName:
PLMS = "plms"
DDIM = "ddim"
DPMPP = "dpmpp"
K_DPM_FAST = "k_dpm_fast"
K_DPM_ADAPTIVE = "k_dpm_adaptive"
K_LMS = "k_lms"
@ -31,7 +32,7 @@ class SamplerName:
K_HEUN = "k_heun"
class ImageSampler(ABC):
class ImageSolver(ABC):
short_name: str
name: str
default_steps: int

@ -8,9 +8,9 @@ from tqdm import tqdm
from imaginairy.log_utils import increment_step, log_latent
from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
from imaginairy.samplers.base import (
ImageSampler,
ImageSolver,
NoiseSchedule,
SamplerName,
SolverName,
get_noise_prediction,
mask_blend,
)
@ -19,14 +19,14 @@ from imaginairy.utils import get_device
logger = logging.getLogger(__name__)
class DDIMSampler(ImageSampler):
class DDIMSolver(ImageSolver):
"""
Denoising Diffusion Implicit Models.
https://arxiv.org/abs/2010.02502
"""
short_name = SamplerName.DDIM
short_name = SolverName.DDIM
name = "Denoising Diffusion Implicit Models"
default_steps = 50

@ -6,8 +6,8 @@ from torch import nn
from imaginairy.log_utils import increment_step, log_latent
from imaginairy.samplers.base import (
ImageSampler,
SamplerName,
ImageSolver,
SolverName,
get_noise_prediction,
mask_blend,
)
@ -57,7 +57,7 @@ def sample_dpm_fast(model, x, sigmas, extra_args=None, disable=False, callback=N
)
class KDiffusionSampler(ImageSampler, ABC):
class KDiffusionSolver(ImageSolver, ABC):
sampler_func: callable
def __init__(self, model):
@ -98,9 +98,9 @@ class KDiffusionSampler(ImageSampler, ABC):
# see https://github.com/crowsonkb/k-diffusion/issues/43#issuecomment-1305195666
if self.short_name in (
SamplerName.K_DPM_2,
SamplerName.K_DPMPP_2M,
SamplerName.K_DPM_2_ANCESTRAL,
SolverName.K_DPM_2,
SolverName.K_DPMPP_2M,
SolverName.K_DPM_2_ANCESTRAL,
):
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
@ -152,73 +152,73 @@ class KDiffusionSampler(ImageSampler, ABC):
#
# class DPMFastSampler(KDiffusionSampler):
# short_name = SamplerName.K_DPM_FAST
# class DPMFastSampler(KDiffusionSolver):
# short_name = SolverName.K_DPM_FAST
# name = "Diffusion probabilistic models - fast"
# default_steps = 15
# sampler_func = staticmethod(sample_dpm_fast)
#
#
# class DPMAdaptiveSampler(KDiffusionSampler):
# short_name = SamplerName.K_DPM_ADAPTIVE
# class DPMAdaptiveSampler(KDiffusionSolver):
# short_name = SolverName.K_DPM_ADAPTIVE
# name = "Diffusion probabilistic models - adaptive"
# default_steps = 40
# sampler_func = staticmethod(sample_dpm_adaptive)
#
#
# class DPM2Sampler(KDiffusionSampler):
# short_name = SamplerName.K_DPM_2
# class DPM2Sampler(KDiffusionSolver):
# short_name = SolverName.K_DPM_2
# name = "Diffusion probabilistic models - 2"
# default_steps = 40
# sampler_func = staticmethod(k_sampling.sample_dpm_2)
#
#
# class DPM2AncestralSampler(KDiffusionSampler):
# short_name = SamplerName.K_DPM_2_ANCESTRAL
# class DPM2AncestralSampler(KDiffusionSolver):
# short_name = SolverName.K_DPM_2_ANCESTRAL
# name = "Diffusion probabilistic models - 2 ancestral"
# default_steps = 40
# sampler_func = staticmethod(k_sampling.sample_dpm_2_ancestral)
#
class DPMPP2MSampler(KDiffusionSampler):
short_name = SamplerName.K_DPMPP_2M
class DPMPP2MSampler(KDiffusionSolver):
short_name = SolverName.K_DPMPP_2M
name = "Diffusion probabilistic models - 2m"
default_steps = 15
sampler_func = staticmethod(k_sampling.sample_dpmpp_2m)
#
# class DPMPP2SAncestralSampler(KDiffusionSampler):
# short_name = SamplerName.K_DPMPP_2S_ANCESTRAL
# class DPMPP2SAncestralSampler(KDiffusionSolver):
# short_name = SolverName.K_DPMPP_2S_ANCESTRAL
# name = "Ancestral sampling with DPM-Solver++(2S) second-order steps."
# default_steps = 15
# sampler_func = staticmethod(k_sampling.sample_dpmpp_2s_ancestral)
#
#
# class EulerSampler(KDiffusionSampler):
# short_name = SamplerName.K_EULER
# class EulerSampler(KDiffusionSolver):
# short_name = SolverName.K_EULER
# name = "Algorithm 2 (Euler steps) from Karras et al. (2022)"
# default_steps = 40
# sampler_func = staticmethod(k_sampling.sample_euler)
#
#
# class EulerAncestralSampler(KDiffusionSampler):
# short_name = SamplerName.K_EULER_ANCESTRAL
# class EulerAncestralSampler(KDiffusionSolver):
# short_name = SolverName.K_EULER_ANCESTRAL
# name = "Euler ancestral"
# default_steps = 40
# sampler_func = staticmethod(k_sampling.sample_euler_ancestral)
#
#
# class HeunSampler(KDiffusionSampler):
# short_name = SamplerName.K_HEUN
# class HeunSampler(KDiffusionSolver):
# short_name = SolverName.K_HEUN
# name = "Algorithm 2 (Heun steps) from Karras et al. (2022)."
# default_steps = 40
# sampler_func = staticmethod(k_sampling.sample_heun)
#
#
# class LMSSampler(KDiffusionSampler):
# short_name = SamplerName.K_LMS
# class LMSSampler(KDiffusionSolver):
# short_name = SolverName.K_LMS
# name = "LMS"
# default_steps = 40
# sampler_func = staticmethod(k_sampling.sample_lms)

@ -8,9 +8,9 @@ from tqdm import tqdm
from imaginairy.log_utils import increment_step, log_latent
from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
from imaginairy.samplers.base import (
ImageSampler,
ImageSolver,
NoiseSchedule,
SamplerName,
SolverName,
get_noise_prediction,
mask_blend,
)
@ -19,7 +19,7 @@ from imaginairy.utils import get_device
logger = logging.getLogger(__name__)
class PLMSSampler(ImageSampler):
class PLMSSolver(ImageSolver):
"""
probabilistic least-mean-squares.
@ -29,7 +29,7 @@ class PLMSSampler(ImageSampler):
https://github.com/luping-liu/PNDM
"""
short_name = SamplerName.PLMS
short_name = SolverName.PLMS
name = "probabilistic least-mean-squares sampler"
default_steps = 40

@ -7,43 +7,32 @@ import logging
import os.path
import random
from datetime import datetime, timezone
from enum import Enum
from io import BytesIO
from typing import TYPE_CHECKING, Any, List, Literal, Optional
from typing import TYPE_CHECKING, Any, List
from pydantic import (
BaseModel,
ConfigDict,
Field,
GetCoreSchemaHandler,
field_validator,
model_validator,
)
from pydantic_core import core_schema
from typing_extensions import Self
from imaginairy import config
if TYPE_CHECKING:
from pathlib import Path
from PIL import Image
else:
Image = Any
logger = logging.getLogger(__name__)
def save_image_as_base64(image: "Image.Image") -> str:
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_bytes = buffered.getvalue()
return base64.b64encode(img_bytes).decode()
def load_image_from_base64(image_str: str) -> "Image.Image":
from PIL import Image
img_bytes = base64.b64decode(image_str)
return Image.open(io.BytesIO(img_bytes))
class InvalidUrlError(ValueError):
pass
@ -52,7 +41,12 @@ class LazyLoadingImage:
"""Image file encoded as base64 string."""
def __init__(
self, *, filepath=None, url=None, img: Image = None, b64: Optional[str] = None
self,
*,
filepath=None,
url=None,
img: "Image.Image" = None,
b64: str | None = None,
):
if not filepath and not url and not img and not b64:
msg = "You must specify a url or filepath or img or base64 string"
@ -208,10 +202,10 @@ class LazyLoadingImage:
return f"<LazyLoadingImage RENDER EXCEPTION*{e}*>"
class ControlNetInput(BaseModel):
class ControlInput(BaseModel):
mode: str
image: Optional[LazyLoadingImage] = None
image_raw: Optional[LazyLoadingImage] = None
image: LazyLoadingImage | None = None
image_raw: LazyLoadingImage | None = None
strength: float = Field(1, ge=0, le=1000)
# @field_validator("image", "image_raw", mode="before")
@ -233,8 +227,8 @@ class ControlNetInput(BaseModel):
@field_validator("mode")
def mode_validate(cls, v):
if v not in config.CONTROLNET_CONFIG_SHORTCUTS:
valid_modes = list(config.CONTROLNET_CONFIG_SHORTCUTS.keys())
if v not in config.CONTROL_CONFIG_SHORTCUTS:
valid_modes = list(config.CONTROL_CONFIG_SHORTCUTS.keys())
valid_modes = ", ".join(valid_modes)
msg = f"Invalid controlnet mode: '{v}'. Valid modes are: {valid_modes}"
raise ValueError(msg)
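For orientation, a minimal usage sketch of the renamed ControlInput model (mode and file path are illustrative, taken from the tests further down; supplying both image and image_raw raises a validation error):

    from imaginairy import LazyLoadingImage
    from imaginairy.schema import ControlInput

    # a canny edge-detection hint, weighted 2x
    hint = ControlInput(
        mode="canny",
        image=LazyLoadingImage(filepath="tests/data/bench2.png"),
        strength=2,
    )
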
@ -249,43 +243,51 @@ class WeightedPrompt(BaseModel):
return f"{self.weight}*({self.text})"
class MaskMode(str, Enum):
REPLACE = "replace"
KEEP = "keep"
MaskInput = MaskMode | str
PromptInput = str | WeightedPrompt | list[WeightedPrompt] | list[str] | None
class ImaginePrompt(BaseModel, protected_namespaces=()):
prompt: Optional[List[WeightedPrompt]] = Field(default=None, validate_default=True)
negative_prompt: Optional[List[WeightedPrompt]] = Field(
default=None, validate_default=True
)
prompt_strength: Optional[float] = Field(
default=7.5, le=10_000, ge=-10_000, validate_default=True
)
init_image: Optional[LazyLoadingImage] = Field(
model_config = ConfigDict(extra="forbid", validate_assignment=True)
prompt: List[WeightedPrompt] = Field(default=None, validate_default=True)
negative_prompt: List[WeightedPrompt] = Field(default=None, validate_default=True)
prompt_strength: float = Field(default=7.5, le=50, ge=-50, validate_default=True)
init_image: LazyLoadingImage | None = Field(
None, description="base64 encoded image", validate_default=True
)
init_image_strength: Optional[float] = Field(
init_image_strength: float | None = Field(
ge=0, le=1, default=None, validate_default=True
)
control_inputs: List[ControlNetInput] = Field(
control_inputs: List[ControlInput] = Field(
default_factory=list, validate_default=True
)
mask_prompt: Optional[str] = Field(
mask_prompt: str | None = Field(
default=None,
description="text description of the things to be masked",
validate_default=True,
)
mask_image: Optional[LazyLoadingImage] = Field(default=None, validate_default=True)
mask_mode: Optional[Literal["keep", "replace"]] = "replace"
mask_image: LazyLoadingImage | None = Field(default=None, validate_default=True)
mask_mode: MaskMode = MaskMode.REPLACE
mask_modify_original: bool = True
outpaint: Optional[str] = ""
model: str = Field(default=config.DEFAULT_MODEL, validate_default=True)
model_config_path: Optional[str] = None
sampler_type: str = Field(default=config.DEFAULT_SAMPLER, validate_default=True)
seed: Optional[int] = Field(default=None, validate_default=True)
steps: Optional[int] = Field(default=None, validate_default=True)
height: Optional[int] = Field(None, ge=1, le=100_000, validate_default=True)
width: Optional[int] = Field(None, ge=1, le=100_000, validate_default=True)
outpaint: str | None = ""
model_architecture: str | None = None
model_weights: str = Field(
default=config.DEFAULT_MODEL_WEIGHTS, validate_default=True
)
solver_type: str = Field(default=config.DEFAULT_SOLVER, validate_default=True)
seed: int | None = Field(default=None, validate_default=True)
steps: int | None = Field(default=None, validate_default=True)
size: tuple[int, int] | None = Field(default=None, validate_default=True)
upscale: bool = False
fix_faces: bool = False
fix_faces_fidelity: Optional[float] = Field(0.2, ge=0, le=1, validate_default=True)
conditioning: Optional[str] = None
fix_faces_fidelity: float | None = Field(0.2, ge=0, le=1, validate_default=True)
conditioning: str | None = None
tile_mode: str = ""
allow_compose_phase: bool = True
is_intermediate: bool = False
@ -294,22 +296,87 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
"", description="text to be overlaid on the image", validate_default=True
)
class MaskMode:
REPLACE = "replace"
KEEP = "keep"
def __init__(self, prompt=None, **kwargs):
# allows `prompt` to be positional
super().__init__(prompt=prompt, **kwargs)
def __init__(
self,
prompt: PromptInput = "",
*,
negative_prompt: PromptInput = None,
prompt_strength: float | None = 7.5,
init_image: LazyLoadingImage | None = None,
init_image_strength: float | None = None,
control_inputs: List[ControlInput] | None = None,
mask_prompt: str | None = None,
mask_image: LazyLoadingImage | None = None,
mask_mode: MaskInput = MaskMode.REPLACE,
mask_modify_original: bool = True,
outpaint: str | None = "",
model_architecture: str | None = None,
model_weights: str = config.DEFAULT_MODEL_WEIGHTS,
solver_type: str = config.DEFAULT_SOLVER,
seed: int | None = None,
steps: int | None = None,
size: int | str | tuple[int, int] | None = None,
upscale: bool = False,
fix_faces: bool = False,
fix_faces_fidelity: float | None = 0.2,
conditioning: str | None = None,
tile_mode: str = "",
allow_compose_phase: bool = True,
is_intermediate: bool = False,
collect_progress_latents: bool = False,
caption_text: str = "",
):
super().__init__(
prompt=prompt,
negative_prompt=negative_prompt,
prompt_strength=prompt_strength,
init_image=init_image,
init_image_strength=init_image_strength,
control_inputs=control_inputs,
mask_prompt=mask_prompt,
mask_image=mask_image,
mask_mode=mask_mode,
mask_modify_original=mask_modify_original,
outpaint=outpaint,
model_architecture=model_architecture,
model_weights=model_weights,
solver_type=solver_type,
seed=seed,
steps=steps,
size=size,
upscale=upscale,
fix_faces=fix_faces,
fix_faces_fidelity=fix_faces_fidelity,
conditioning=conditioning,
tile_mode=tile_mode,
allow_compose_phase=allow_compose_phase,
is_intermediate=is_intermediate,
collect_progress_latents=collect_progress_latents,
caption_text=caption_text,
)
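A sketch of constructing a prompt with the renamed keyword arguments (mirrors the conftest fixture near the end of this diff; values are illustrative):

    from imaginairy import ImaginePrompt

    prompt = ImaginePrompt(
        "dogs lying on a hot pink couch",
        size=64,             # replaces width=/height=; normalized to (64, 64)
        steps=2,
        seed=1,
        solver_type="ddim",  # was sampler_type
    )
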
@field_validator("prompt", "negative_prompt", mode="before")
@classmethod
def make_into_weighted_prompts(cls, v):
if isinstance(v, str):
v = [WeightedPrompt(text=v)]
elif isinstance(v, WeightedPrompt):
v = [v]
return v
def make_into_weighted_prompts(
cls,
value: PromptInput,
) -> list[WeightedPrompt]:
match value:
case None:
return []
case str():
if value:
return [WeightedPrompt(text=value)]
else:
return []
case WeightedPrompt():
return [value]
case list():
if all(isinstance(item, str) for item in value):
return [WeightedPrompt(text=p) for p in value]
elif all(isinstance(item, WeightedPrompt) for item in value):
return value
raise ValueError("Invalid prompt input")
@field_validator("prompt", "negative_prompt", mode="after")
@classmethod
@ -328,16 +395,20 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
@model_validator(mode="after")
def validate_negative_prompt(self):
if self.negative_prompt is None:
model_config = config.MODEL_CONFIG_SHORTCUTS.get(self.model, None)
if model_config:
self.negative_prompt = [
WeightedPrompt(text=model_config.default_negative_prompt)
]
else:
self.negative_prompt = [
WeightedPrompt(text=config.DEFAULT_NEGATIVE_PROMPT)
]
if (
self.negative_prompt == [WeightedPrompt(text="")]
or self.negative_prompt == []
):
model_weight_config = config.MODEL_WEIGHT_CONFIG_LOOKUP.get(
self.model_weights, None
)
default_negative_prompt = config.DEFAULT_NEGATIVE_PROMPT
if model_weight_config:
default_negative_prompt = model_weight_config.defaults.get(
"negative_prompt", default_negative_prompt
)
self.negative_prompt = [WeightedPrompt(text=default_negative_prompt)]
return self
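The defaults this produces, as exercised by the schema tests at the bottom of this diff (sketch):

    from imaginairy import ImaginePrompt, config

    ImaginePrompt("fruit salad").negative_prompt[0].text
    # config.DEFAULT_NEGATIVE_PROMPT
    ImaginePrompt("fruit salad", model_weights="openjourney-v1").negative_prompt[0].text
    # "poor quality" (the per-weights default for openjourney-v1)
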
@field_validator("prompt_strength")
@ -426,10 +497,10 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
raise ValueError(msg)
return v
@field_validator("model", mode="before")
@field_validator("model_weights", mode="before")
def set_default_diffusion_model(cls, v):
if v is None:
return config.DEFAULT_MODEL
return config.DEFAULT_MODEL_WEIGHTS
return v
@ -444,33 +515,32 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return v
@field_validator("sampler_type", mode="after")
def validate_sampler_type(cls, v, info: core_schema.FieldValidationInfo):
from imaginairy.samplers import SamplerName
@field_validator("solver_type", mode="after")
def validate_solver_type(cls, v, info: core_schema.FieldValidationInfo):
from imaginairy.samplers import SolverName
if v is None:
v = config.DEFAULT_SAMPLER
v = config.DEFAULT_SOLVER
v = v.lower()
if info.data.get("model") == "SD-2.0-v" and v == SamplerName.PLMS:
raise ValueError("PLMS sampler is not supported for SD-2.0-v model.")
if info.data.get("model") == "SD-2.0-v" and v == SolverName.PLMS:
raise ValueError("PLMS solvers is not supported for SD-2.0-v model.")
if info.data.get("model") == "edit" and v in (
SamplerName.PLMS,
SamplerName.DDIM,
SolverName.PLMS,
SolverName.DDIM,
):
msg = "PLMS and DDIM samplers are not supported for pix2pix edit model."
msg = "PLMS and DDIM solvers are not supported for pix2pix edit model."
raise ValueError(msg)
return v
@field_validator("steps")
def validate_steps(cls, v, info: core_schema.FieldValidationInfo):
from imaginairy.samplers import SAMPLER_LOOKUP
steps_lookup = {"ddim": 50, "dpmpp": 20}
if v is None:
SamplerCls = SAMPLER_LOOKUP[info.data["sampler_type"]]
v = SamplerCls.default_steps
v = steps_lookup[info.data["solver_type"]]
return int(v)
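In effect, the hard-coded lookup gives these per-solver defaults when steps is left unset (sketch):

    from imaginairy import ImaginePrompt

    ImaginePrompt("a cat", solver_type="ddim").steps   # 50
    ImaginePrompt("a cat", solver_type="dpmpp").steps  # 20
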
@ -486,14 +556,26 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return self
@field_validator("height", "width")
@field_validator("size", mode="before")
def validate_image_size(cls, v, info: core_schema.FieldValidationInfo):
from imaginairy.model_manager import get_model_default_image_size
from imaginairy.utils.named_resolutions import normalize_image_size
if v is None:
v = get_model_default_image_size(info.data["model"])
v = get_model_default_image_size(info.data["model_architecture"])
return v
width, height = normalize_image_size(v)
min_size = 8
max_size = 100_000
if not min_size <= width <= max_size:
msg = f"Width must be between {min_size} and {max_size}. Got: {width}"
raise ValueError(msg)
if not min_size <= height <= max_size:
msg = f"Height must be between {min_size} and {max_size}. Got: {height}"
raise ValueError(msg)
return width, height
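Combined with the width/height properties added further down, the new size field behaves roughly like this (sketch, using the "1080p" named resolution from the tests):

    from imaginairy import ImaginePrompt

    p = ImaginePrompt("a stormy ocean. oil painting", size="1080p")
    p.size    # (1920, 1080)
    p.width   # 1920
    p.height  # 1080
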
@field_validator("caption_text", mode="before")
def validate_caption_text(cls, v):
@ -507,7 +589,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return self.prompt
@property
def prompt_text(self):
def prompt_text(self) -> str:
if not self.prompt:
return ""
if len(self.prompt) == 1:
@ -515,18 +597,31 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
return "|".join(str(p) for p in self.prompt)
@property
def negative_prompt_text(self):
def negative_prompt_text(self) -> str:
if not self.negative_prompt:
return ""
if len(self.negative_prompt) == 1:
return self.negative_prompt[0].text
return "|".join(str(p) for p in self.negative_prompt)
@property
def width(self) -> int:
return self.size[0]
@property
def height(self) -> int:
return self.size[1]
def prompt_description(self):
return (
f'"{self.prompt_text}" {self.width}x{self.height}px '
f'negative-prompt:"{self.negative_prompt_text}" '
f"seed:{self.seed} prompt-strength:{self.prompt_strength} steps:{self.steps} sampler-type:{self.sampler_type} init-image-strength:{self.init_image_strength} model:{self.model}"
f"seed:{self.seed} "
f"prompt-strength:{self.prompt_strength} "
f"steps:{self.steps} solver-type:{self.solver_type} "
f"init-image-strength:{self.init_image_strength} "
f"arch:{self.model_architecture} "
f"weights: {self.model_weights}"
)
def logging_dict(self):
@ -547,7 +642,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
new_prompt = new_prompt.model_validate(dict(new_prompt))
return new_prompt
def make_concrete_copy(self):
def make_concrete_copy(self) -> Self:
seed = self.seed if self.seed is not None else random.randint(1, 1_000_000_000)
return self.full_copy(
deep=False,
@ -574,10 +669,6 @@ class ImagineResult:
prompt: ImaginePrompt,
is_nsfw,
safety_score,
upscaled_img=None,
modified_original=None,
mask_binary=None,
mask_grayscale=None,
result_images=None,
timings=None,
progress_latents=None,
@ -594,20 +685,10 @@ class ImagineResult:
self.images = {"generated": img}
if upscaled_img:
self.images["upscaled"] = upscaled_img
if modified_original:
self.images["modified_original"] = modified_original
if mask_binary:
self.images["mask_binary"] = mask_binary
if mask_grayscale:
self.images["mask_grayscale"] = mask_grayscale
if result_images:
for img_type, r_img in result_images.items():
if r_img is None:
continue
if isinstance(r_img, torch.Tensor):
if r_img.shape[1] == 4:
r_img = model_latent_to_pillow_img(r_img)
@ -620,7 +701,6 @@ class ImagineResult:
# for backward compat
self.img = img
self.upscaled_img = upscaled_img
self.is_nsfw = is_nsfw
self.safety_score = safety_score
@ -628,7 +708,7 @@ class ImagineResult:
self.torch_backend = get_device()
self.hardware_name = get_hardware_description(get_device())
def md5(self):
def md5(self) -> str:
return hashlib.md5(self.img.tobytes()).hexdigest()
def metadata_dict(self):
@ -636,19 +716,19 @@ class ImagineResult:
"prompt": self.prompt.logging_dict(),
}
def timings_str(self):
def timings_str(self) -> str:
if not self.timings:
return ""
return " ".join(f"{k}:{v:.2f}s" for k, v in self.timings.items())
def _exif(self):
def _exif(self) -> "Image.Exif":
from PIL import Image
exif = Image.Exif()
exif[ExifCodes.ImageDescription] = self.prompt.prompt_description()
exif[ExifCodes.UserComment] = json.dumps(self.metadata_dict())
# help future web scrapes not ingest AI generated art
sd_version = self.prompt.model
sd_version = self.prompt.model_weights
if len(sd_version) > 20:
sd_version = "custom weights"
exif[ExifCodes.Software] = f"Imaginairy / Stable Diffusion {sd_version}"
@ -656,7 +736,7 @@ class ImagineResult:
exif[ExifCodes.HostComputer] = f"{self.torch_backend}:{self.hardware_name}"
return exif
def save(self, save_path, image_type="generated"):
def save(self, save_path: "Path | str", image_type: str = "generated") -> None:
img = self.images.get(image_type, None)
if img is None:
msg = f"Image of type {image_type} not stored. Options are: {self.images.keys()}"
@ -665,6 +745,6 @@ class ImagineResult:
img.convert("RGB").save(save_path, exif=self._exif())
class SafetyMode:
class SafetyMode(str, Enum):
STRICT = "strict"
RELAXED = "relaxed"

@ -10,7 +10,7 @@ from imaginairy import ImaginePrompt, LazyLoadingImage, imagine_image_files
from imaginairy.animations import make_gif_animation
from imaginairy.enhancers.facecrop import detect_faces
from imaginairy.img_utils import add_caption_to_image, pillow_fit_image_within
from imaginairy.schema import ControlNetInput
from imaginairy.schema import ControlInput
preserve_head_kwargs = {
"mask_prompt": "head|face",
@ -142,7 +142,7 @@ def surprise_me_prompts(
for prompt_text, strength, kwargs in generic_prompts:
if use_controlnet:
strength = 5
control_input = ControlNetInput(mode="edit", strength=2)
control_input = ControlInput(mode="edit", strength=2)
prompts.append(
ImaginePrompt(
prompt_text,
@ -163,7 +163,7 @@ def surprise_me_prompts(
prompt_text,
init_image=img,
prompt_strength=strength,
model="edit",
model_weights="edit",
steps=steps,
width=width,
height=height,
@ -178,7 +178,7 @@ def surprise_me_prompts(
for prompt_subconfig in prompt_subconfigs:
prompt_text, strength, kwargs = prompt_subconfig
if use_controlnet:
control_input = ControlNetInput(
control_input = ControlInput(
mode="edit",
)
prompts.append(
@ -201,7 +201,7 @@ def surprise_me_prompts(
prompt_text,
init_image=img,
prompt_strength=strength,
model="edit",
model_weights="edit",
steps=steps,
width=width,
height=height,

@ -0,0 +1,533 @@
import datetime
import logging
import os
import signal
import time
from functools import partial
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor
try:
from pytorch_lightning.strategies import DDPStrategy
except ImportError:
# let's not break all of imaginairy just because a training import doesn't exist in an older version of PL
# Use >= 1.6.0 to make this work
DDPStrategy = None
import contextlib
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.utils.data import DataLoader, Dataset
from imaginairy import config
from imaginairy.model_manager import get_diffusion_model
from imaginairy.training_tools.single_concept import SingleConceptDataset
from imaginairy.utils import get_device, instantiate_from_config
mod_logger = logging.getLogger(__name__)
referenced_by_string = [LearningRateMonitor]
class WrappedDataset(Dataset):
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset."""
def __init__(self, dataset):
self.data = dataset
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def worker_init_fn(_):
worker_info = torch.utils.data.get_worker_info()
dataset = worker_info.dataset
worker_id = worker_info.id
if isinstance(dataset, SingleConceptDataset):
# split_size = dataset.num_records // worker_info.num_workers
# reset num_records to the true number to retain reliable length information
# dataset.sample_ids = dataset.valid_ids[
# worker_id * split_size : (worker_id + 1) * split_size
# ]
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
return np.random.seed(np.random.get_state()[1][0] + worker_id)
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(
self,
batch_size,
train=None,
validation=None,
test=None,
predict=None,
wrap=False,
num_workers=None,
shuffle_test_loader=False,
use_worker_init_fn=False,
shuffle_val_dataloader=False,
num_val_workers=0,
):
super().__init__()
self.batch_size = batch_size
self.dataset_configs = {}
self.num_workers = num_workers if num_workers is not None else batch_size * 2
if num_val_workers is None:
self.num_val_workers = self.num_workers
else:
self.num_val_workers = num_val_workers
self.use_worker_init_fn = use_worker_init_fn
if train is not None:
self.dataset_configs["train"] = train
self.train_dataloader = self._train_dataloader
if validation is not None:
self.dataset_configs["validation"] = validation
self.val_dataloader = partial(
self._val_dataloader, shuffle=shuffle_val_dataloader
)
if test is not None:
self.dataset_configs["test"] = test
self.test_dataloader = partial(
self._test_dataloader, shuffle=shuffle_test_loader
)
if predict is not None:
self.dataset_configs["predict"] = predict
self.predict_dataloader = self._predict_dataloader
self.wrap = wrap
self.datasets = None
def prepare_data(self):
for data_cfg in self.dataset_configs.values():
instantiate_from_config(data_cfg)
def setup(self, stage=None):
self.datasets = {
k: instantiate_from_config(c) for k, c in self.dataset_configs.items()
}
if self.wrap:
self.datasets = {k: WrappedDataset(v) for k, v in self.datasets.items()}
def _train_dataloader(self):
is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
if is_iterable_dataset or self.use_worker_init_fn:
pass
else:
pass
return DataLoader(
self.datasets["train"],
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
worker_init_fn=worker_init_fn,
)
def _val_dataloader(self, shuffle=False):
if (
isinstance(self.datasets["validation"], SingleConceptDataset)
or self.use_worker_init_fn
):
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets["validation"],
batch_size=self.batch_size,
num_workers=self.num_val_workers,
worker_init_fn=init_fn,
shuffle=shuffle,
)
def _test_dataloader(self, shuffle=False):
is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
is_iterable_dataset = False
# do not shuffle dataloader for iterable dataset
shuffle = shuffle and (not is_iterable_dataset)
return DataLoader(
self.datasets["test"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
shuffle=shuffle,
)
def _predict_dataloader(self, shuffle=False):
if (
isinstance(self.datasets["predict"], SingleConceptDataset)
or self.use_worker_init_fn
):
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets["predict"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
)
class SetupCallback(Callback):
def __init__(
self,
resume,
now,
logdir,
ckptdir,
cfgdir,
):
super().__init__()
self.resume = resume
self.now = now
self.logdir = logdir
self.ckptdir = ckptdir
self.cfgdir = cfgdir
def on_keyboard_interrupt(self, trainer, pl_module):
if trainer.global_rank == 0:
mod_logger.info("Stopping execution and saving final checkpoint.")
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def on_fit_start(self, trainer, pl_module):
if trainer.global_rank == 0:
# Create logdirs and save configs
os.makedirs(self.logdir, exist_ok=True)
os.makedirs(self.ckptdir, exist_ok=True)
os.makedirs(self.cfgdir, exist_ok=True)
else:
# ModelCheckpoint callback created log directory --- remove it
if not self.resume and os.path.exists(self.logdir):
dst, name = os.path.split(self.logdir)
dst = os.path.join(dst, "child_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
with contextlib.suppress(FileNotFoundError):
os.rename(self.logdir, dst)
class ImageLogger(Callback):
def __init__(
self,
batch_frequency,
max_images,
clamp=True,
increase_log_steps=True,
rescale=True,
disabled=False,
log_on_batch_idx=False,
log_first_step=False,
log_images_kwargs=None,
log_all_val=False,
concept_label=None,
):
super().__init__()
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = {}
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
self.disabled = disabled
self.log_on_batch_idx = log_on_batch_idx
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.log_first_step = log_first_step
self.log_all_val = log_all_val
self.concept_label = concept_label
@rank_zero_only
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "logs", "images", split)
for k in images:
grid = torchvision.utils.make_grid(images[k], nrow=4)
if self.rescale:
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = (
f"{k}_gs-{global_step:06}_e-{current_epoch:06}_b-{batch_idx:06}.png"
)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
def log_img(self, pl_module, batch, batch_idx, split="train"):
# always generate the concept label
batch["txt"][0] = self.concept_label
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
if self.log_all_val and split == "val":
should_log = True
else:
should_log = self.check_frequency(check_idx)
if (
should_log
and (batch_idx % self.batch_freq == 0)
and hasattr(pl_module, "log_images")
and callable(pl_module.log_images)
and self.max_images > 0
):
logger = type(pl_module.logger)
is_train = pl_module.training
if is_train:
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(
batch, split=split, **self.log_images_kwargs
)
for k in images:
N = min(images[k].shape[0], self.max_images)
images[k] = images[k][:N]
if isinstance(images[k], torch.Tensor):
images[k] = images[k].detach().cpu()
if self.clamp:
images[k] = torch.clamp(images[k], -1.0, 1.0)
self.log_local(
pl_module.logger.save_dir,
split,
images,
pl_module.global_step,
pl_module.current_epoch,
batch_idx,
)
logger_log_images = self.logger_log_images.get(
logger, lambda *args, **kwargs: None
)
logger_log_images(pl_module, images, pl_module.global_step, split)
if is_train:
pl_module.train()
def check_frequency(self, check_idx):
if (check_idx % self.batch_freq) == 0 and (
check_idx > 0 or self.log_first_step
):
return True
return False
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
self.log_img(pl_module, batch, batch_idx, split="train")
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
):
if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split="val")
if (
hasattr(pl_module, "calibrate_grad_norm")
and (pl_module.calibrate_grad_norm and batch_idx % 25 == 0)
and batch_idx > 0
):
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
class CUDACallback(Callback):
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
def on_train_epoch_start(self, trainer, pl_module):
# Reset the memory use counter
if "cuda" in get_device():
torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
torch.cuda.synchronize(trainer.strategy.root_device.index)
self.start_time = time.time()
def on_train_epoch_end(self, trainer, pl_module):
if "cuda" in get_device():
torch.cuda.synchronize(trainer.strategy.root_device.index)
max_memory = (
torch.cuda.max_memory_allocated(trainer.strategy.root_device.index)
/ 2**20
)
epoch_time = time.time() - self.start_time
try:
max_memory = trainer.training_type_plugin.reduce(max_memory)
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
except AttributeError:
pass
def train_diffusion_model(
concept_label,
concept_images_dir,
class_label,
class_images_dir,
weights_location=config.DEFAULT_MODEL_WEIGHTS,
logdir="logs",
learning_rate=1e-6,
accumulate_grad_batches=32,
resume=None,
):
"""
Train a diffusion model on a single concept.
accumulate_grad_batches used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf
"""
if DDPStrategy is None:
msg = "Please install pytorch-lightning>=1.6.0 to train a model"
raise ImportError(msg)
batch_size = 1
seed = 23
num_workers = 1
num_val_workers = 0
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # noqa: DTZ005
logdir = os.path.join(logdir, now)
ckpt_output_dir = os.path.join(logdir, "checkpoints")
cfg_output_dir = os.path.join(logdir, "configs")
seed_everything(seed)
model = get_diffusion_model(
weights_location=weights_location, half_mode=False, for_training=True
)._model
model.learning_rate = learning_rate * accumulate_grad_batches * batch_size
# add callback which sets up log directory
default_callbacks_cfg = {
"setup_callback": {
"target": "imaginairy.train.SetupCallback",
"params": {
"resume": False,
"now": now,
"logdir": logdir,
"ckptdir": ckpt_output_dir,
"cfgdir": cfg_output_dir,
},
},
"image_logger": {
"target": "imaginairy.train.ImageLogger",
"params": {
"batch_frequency": 10,
"max_images": 1,
"clamp": True,
"increase_log_steps": False,
"log_first_step": True,
"log_all_val": True,
"concept_label": concept_label,
"log_images_kwargs": {
"use_ema_scope": True,
"inpaint": False,
"plot_progressive_rows": False,
"plot_diffusion_rows": False,
"N": 1,
"unconditional_guidance_scale:": 7.5,
"unconditional_guidance_label": [""],
"ddim_steps": 20,
},
},
},
"learning_rate_logger": {
"target": "imaginairy.train.LearningRateMonitor",
"params": {
"logging_interval": "step",
# "log_momentum": True
},
},
"cuda_callback": {"target": "imaginairy.train.CUDACallback"},
}
default_modelckpt_cfg = {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckpt_output_dir,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
"every_n_train_steps": 50,
"save_top_k": -1,
"monitor": None,
},
}
modelckpt_cfg = OmegaConf.create(default_modelckpt_cfg)
default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
callbacks_cfg = OmegaConf.create(default_callbacks_cfg)
dataset_config = {
"concept_label": concept_label,
"concept_images_dir": concept_images_dir,
"class_label": class_label,
"class_images_dir": class_images_dir,
"image_transforms": [
{
"target": "torchvision.transforms.Resize",
"params": {"size": 512, "interpolation": 3},
},
{"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}},
],
}
data_module_config = {
"batch_size": batch_size,
"num_workers": num_workers,
"num_val_workers": num_val_workers,
"train": {
"target": "imaginairy.training_tools.single_concept.SingleConceptDataset",
"params": dataset_config,
},
}
trainer = Trainer(
benchmark=True,
num_sanity_val_steps=0,
accumulate_grad_batches=accumulate_grad_batches,
strategy=DDPStrategy(),
callbacks=[instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg],
gpus=1,
default_root_dir=".",
)
trainer.logdir = logdir
data = DataModuleFromConfig(**data_module_config)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
def melk(*args, **kwargs):
if trainer.global_rank == 0:
mod_logger.info("Summoning checkpoint.")
ckpt_path = os.path.join(ckpt_output_dir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
signal.signal(signal.SIGUSR1, melk)
try:
try:
trainer.fit(model, data)
except Exception:
melk()
raise
finally:
mod_logger.info(trainer.profiler.summary())

@ -170,7 +170,21 @@ def replace_value_at_path(data, path, new_value):
parent = get_path(data, path[:-1])
last_key = path[-1]
if new_value == NODE_DELETE:
del parent[last_key]
if isinstance(parent, tuple):
grandparent = get_path(data, path[:-2])
grandparent_key = path[-2]
new_parent = list(parent)
del new_parent[last_key]
grandparent[grandparent_key] = tuple(new_parent)
else:
del parent[last_key]
else:
parent[last_key] = new_value
if isinstance(parent, tuple):
grandparent = get_path(data, path[:-2])
grandparent_key = path[-2]
new_parent = list(parent)
new_parent[last_key] = new_value
grandparent[grandparent_key] = tuple(new_parent)
else:
parent[last_key] = new_value
return data
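Tuples are immutable, so the value cannot be assigned in place; the parent tuple is rebuilt and reattached to its grandparent. A rough illustration with hypothetical data (assuming this helper lives alongside DataDistorter in imaginairy.utils.data_distorter):

    from imaginairy.utils.data_distorter import replace_value_at_path

    data = {"prompt": {"size": (512, 512)}}
    replace_value_at_path(data, ["prompt", "size", 1], 256)
    # data is now {"prompt": {"size": (512, 256)}}
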

@ -43,27 +43,37 @@ _NAMED_RESOLUTIONS = {
"SVD": (1024, 576), # stable video diffusion
}
_NAMED_RESOLUTIONS = {k.upper(): v for k, v in _NAMED_RESOLUTIONS.items()}
def get_named_resolution(resolution: str):
resolution = resolution.upper()
size = _NAMED_RESOLUTIONS.get(resolution)
if size is None:
# is it WIDTHxHEIGHT format?
try:
width, height = resolution.split("X")
size = (int(width), int(height))
except ValueError:
pass
if size is None:
# is it just a single number?
with contextlib.suppress(ValueError):
size = (int(resolution), int(resolution))
if size is None:
msg = f"Unknown resolution: {resolution}"
def normalize_image_size(resolution: str | int | tuple[int, int]) -> tuple[int, int]:
match resolution:
case (int(), int()):
size = resolution
case int():
size = resolution, resolution
case str():
resolution = resolution.strip().upper()
resolution = resolution.replace(" ", "").replace("X", ",").replace("*", ",")
size = _NAMED_RESOLUTIONS.get(resolution.upper())
if size is None:
# is it WIDTH,HEIGHT format?
try:
width, height = resolution.split(",")
size = int(width), int(height)
except ValueError:
pass
if size is None:
# is it just a single number?
with contextlib.suppress(ValueError):
size = (int(resolution), int(resolution))
if size is None:
msg = f"Invalid resolution: '{resolution}'"
raise ValueError(msg)
case _:
msg = f"Invalid resolution: {resolution!r}"
raise ValueError(msg)
if size[0] <= 0 or size[1] <= 0:
msg = f"Invalid resolution: {resolution!r}"
raise ValueError(msg)
return size
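A few of the input forms the rewritten helper accepts, drawn from the test cases at the end of this diff (sketch):

    from imaginairy.utils.named_resolutions import normalize_image_size

    normalize_image_size(512)          # (512, 512)
    normalize_image_size((512, 768))   # (512, 768)
    normalize_image_size("1024x768")   # (1024, 768)
    normalize_image_size("1024, 768")  # (1024, 768)
    normalize_image_size("1080p")      # (1920, 1080)
    normalize_image_size("-512")       # raises ValueError: Invalid resolution
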

@ -87,7 +87,7 @@ def generate_video(
device="cpu",
num_frames=num_frames,
num_steps=num_steps,
weights_url=video_model_config["weights_url"],
weights_url=video_model_config["weights_location"],
)
torch.manual_seed(seed)

@ -3,7 +3,7 @@ import safetensors
from imaginairy.model_manager import (
get_cached_url_path,
open_weights,
resolve_model_paths,
resolve_model_weights_config,
)
from imaginairy.weight_management import utils
from imaginairy.weight_management.pattern_collapse import find_state_dict_key_patterns
@ -11,15 +11,12 @@ from imaginairy.weight_management.utils import save_model_info
def save_compvis_patterns():
(
model_metadata,
weights_url,
config_path,
control_weights_paths,
) = resolve_model_paths(
weights_path="openjourney-v1",
model_weights_config = resolve_model_weights_config(
model_weights="openjourney-v1",
)
weights_path = get_cached_url_path(
model_weights_config.weights_location, category="weights"
)
weights_path = get_cached_url_path(weights_url, category="weights")
with safetensors.safe_open(weights_path, "pytorch") as f:
weights_keys = f.keys()
@ -98,7 +95,7 @@ def save_weight_info(
model_name, component_name, format_name, weights_url=None, weights_keys=None
):
if weights_keys is None and weights_url is None:
msg = "Either weights_keys or weights_url must be provided"
msg = "Either weights_keys or weights_location must be provided"
raise ValueError(msg)
if weights_keys is None:

@ -12,7 +12,6 @@ from urllib3 import HTTPConnectionPool
from imaginairy import ImaginePrompt, api, imagine
from imaginairy.log_utils import configure_logging, suppress_annoying_logs_and_warnings
from imaginairy.samplers import SAMPLER_TYPE_OPTIONS
from imaginairy.utils import (
fix_torch_group_norm,
fix_torch_nn_layer_norm,
@ -26,13 +25,13 @@ if "pytest" in str(sys.argv):
logger = logging.getLogger(__name__)
SAMPLERS_FOR_TESTING = SAMPLER_TYPE_OPTIONS
if get_device() == "mps:0":
SAMPLERS_FOR_TESTING = ["plms", "k_euler_a"]
elif get_device() == "cpu":
SAMPLERS_FOR_TESTING = []
# SOLVERS_FOR_TESTING = SOLVER_TYPE_OPTIONS
# if get_device() == "mps:0":
# SOLVERS_FOR_TESTING = ["plms", "k_euler_a"]
# elif get_device() == "cpu":
# SOLVERS_FOR_TESTING = []
SAMPLERS_FOR_TESTING = ["ddim", "k_dpmpp_2m"]
SOLVERS_FOR_TESTING = ["ddim", "dpmpp"]
@pytest.fixture(scope="session", autouse=True)
@ -90,8 +89,8 @@ def filename_base_for_orig_outputs(request):
return filename_base
@pytest.fixture(params=SAMPLERS_FOR_TESTING)
def sampler_type(request):
@pytest.fixture(params=SOLVERS_FOR_TESTING)
def solver_type(request):
return request.param
@ -118,11 +117,10 @@ def default_model_loaded():
"""
prompt = ImaginePrompt(
"dogs lying on a hot pink couch",
width=64,
height=64,
size=64,
steps=2,
seed=1,
sampler_type="ddim",
solver_type="ddim",
)
next(imagine(prompt))

8 binary image files added (not shown): 569 KiB, 308 KiB, 248 KiB, 392 KiB, 281 KiB, 276 KiB, 278 KiB, 239 KiB.

@ -6,22 +6,22 @@ from imaginairy import LazyLoadingImage
from imaginairy.api import imagine, imagine_image_files
from imaginairy.img_processors.control_modes import CONTROL_MODES
from imaginairy.img_utils import pillow_fit_image_within
from imaginairy.schema import ControlNetInput, ImaginePrompt
from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode
from imaginairy.utils import get_device
from . import TESTS_FOLDER
from .utils import assert_image_similar_to_expectation
def test_imagine(sampler_type, filename_base_for_outputs):
def test_imagine(solver_type, filename_base_for_outputs):
prompt_text = "a scenic old-growth forest with diffuse light poking through the canopy. high resolution nature photography"
prompt = ImaginePrompt(
prompt_text, width=512, height=512, steps=20, seed=1, sampler_type=sampler_type
prompt_text, size=512, steps=20, seed=1, solver_type=solver_type
)
result = next(imagine(prompt))
threshold_lookup = {"k_dpm_2_a": 26000}
threshold = threshold_lookup.get(sampler_type, 10000)
threshold = threshold_lookup.get(solver_type, 10000)
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(
@ -49,25 +49,25 @@ def test_model_versions(filename_base_for_orig_outputs, model_version):
ImaginePrompt(
prompt_text,
seed=1,
model=model_version,
model_weights=model_version,
)
)
threshold = 35000
for i, result in enumerate(imagine(prompts)):
img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model}.png"
results = list(imagine(prompts))
for i, result in enumerate(results):
img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model_weights}.png"
result.img.save(img_path)
for i, result in enumerate(imagine(prompts)):
img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model}.png"
for i, result in enumerate(results):
img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model_weights}.png"
assert_image_similar_to_expectation(
result.img, img_path=img_path, threshold=threshold
)
def test_img2img_beach_to_sunset(
sampler_type, filename_base_for_outputs, filename_base_for_orig_outputs
solver_type, filename_base_for_outputs, filename_base_for_orig_outputs
):
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
prompt = ImaginePrompt(
@ -77,11 +77,10 @@ def test_img2img_beach_to_sunset(
prompt_strength=15,
mask_prompt="(sky|clouds) AND !(buildings|trees)",
mask_mode="replace",
width=512,
height=512,
size=512,
steps=40 * 2,
seed=1,
sampler_type=sampler_type,
solver_type=solver_type,
)
result = next(imagine(prompt))
@ -91,7 +90,7 @@ def test_img2img_beach_to_sunset(
def test_img_to_img_from_url_cats(
sampler_type,
solver_type,
filename_base_for_outputs,
mocked_responses,
filename_base_for_orig_outputs,
@ -113,11 +112,10 @@ def test_img_to_img_from_url_cats(
"dogs lying on a hot pink couch",
init_image=img,
init_image_strength=0.5,
width=512,
height=512,
size=512,
steps=50,
seed=1,
sampler_type=sampler_type,
solver_type=solver_type,
)
result = next(imagine(prompt))
@ -130,7 +128,7 @@ def test_img_to_img_from_url_cats(
def test_img2img_low_noise(
filename_base_for_outputs,
sampler_type,
solver_type,
):
fruit_path = os.path.join(TESTS_FOLDER, "data", "bowl_of_fruit.jpg")
img = LazyLoadingImage(filepath=fruit_path)
@ -144,17 +142,18 @@ def test_img2img_low_noise(
mask_mode="replace",
# steps=40,
seed=1,
sampler_type=sampler_type,
solver_type=solver_type,
)
result = next(imagine(prompt))
threshold_lookup = {
"dpmpp": 26000,
"k_dpm_2_a": 26000,
"k_euler_a": 18000,
"k_dpm_adaptive": 13000,
}
threshold = threshold_lookup.get(sampler_type, 14000)
threshold = threshold_lookup.get(solver_type, 14000)
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(
@ -165,7 +164,7 @@ def test_img2img_low_noise(
@pytest.mark.parametrize("init_strength", [0, 0.05, 0.2, 1])
def test_img_to_img_fruit_2_gold(
filename_base_for_outputs,
sampler_type,
solver_type,
init_strength,
filename_base_for_orig_outputs,
):
@ -183,7 +182,7 @@ def test_img_to_img_fruit_2_gold(
mask_mode="replace",
steps=needed_steps,
seed=1,
sampler_type=sampler_type,
solver_type=solver_type,
)
result = next(imagine(prompt))
@ -194,7 +193,7 @@ def test_img_to_img_fruit_2_gold(
"k_dpm_adaptive": 13000,
"k_dpmpp_2s": 16000,
}
threshold = threshold_lookup.get(sampler_type, 16000)
threshold = threshold_lookup.get(solver_type, 16000)
pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}__orig.jpg")
img_path = f"{filename_base_for_outputs}.png"
@ -227,7 +226,7 @@ def test_img_to_img_fruit_2_gold_repeat():
]
for result in imagine(prompts, debug_img_callback=None):
result.img.save(
f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold_{result.prompt.sampler_type}_{get_device()}_run-{run_count:02}.jpg"
f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold_{result.prompt.solver_type}_{get_device()}_run-{run_count:02}.jpg"
)
run_count += 1
@ -236,9 +235,8 @@ def test_img_to_img_fruit_2_gold_repeat():
def test_img_to_file():
prompt = ImaginePrompt(
"an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo",
width=512 + 64,
height=512 - 64,
steps=20,
size=(512 + 64, 512 - 64),
steps=2,
seed=2,
upscale=True,
)
@ -254,8 +252,7 @@ def test_inpainting_bench(filename_base_for_outputs, filename_base_for_orig_outp
init_image=img,
init_image_strength=0.4,
mask_image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2_mask.png"),
width=512,
height=512,
size=512,
steps=40,
seed=1,
)
@ -279,9 +276,8 @@ def test_cliptext_inpainting_pearl_doctor(
init_image=img,
init_image_strength=0.2,
mask_prompt="face AND NOT (bandana OR hair OR blue fabric){*5}",
mask_mode=ImaginePrompt.MaskMode.KEEP,
width=512,
height=512,
mask_mode=MaskMode.KEEP,
size=512,
steps=40,
seed=181509347,
)
@ -297,8 +293,7 @@ def test_tile_mode(filename_base_for_outputs):
prompt_text = "gold coins"
prompt = ImaginePrompt(
prompt_text,
width=400,
height=400,
size=400,
steps=15,
seed=1,
tile_mode="xy",
@ -317,7 +312,7 @@ control_modes = list(CONTROL_MODES.keys())
def test_controlnet(filename_base_for_outputs, control_mode):
prompt_text = "a photo of a woman sitting on a bench"
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png")
control_input = ControlNetInput(
control_input = ControlInput(
mode=control_mode,
image=img,
)
@ -327,30 +322,27 @@ def test_controlnet(filename_base_for_outputs, control_mode):
prompt_text = "a wise old man"
seed = 1
mask_image = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2_mask.png")
control_input = ControlNetInput(
control_input = ControlInput(
mode=control_mode,
image=mask_image,
)
prompt = ImaginePrompt(
prompt_text,
width=512,
height=512,
size=512,
steps=45,
seed=seed,
init_image=img,
init_image_strength=0,
control_inputs=[control_input],
fix_faces=True,
sampler="ddim",
solver_type="ddim",
)
prompt.steps = 1
prompt.width = 256
prompt.height = 256
prompt.size = 256
result = next(imagine(prompt))
prompt.steps = 15
prompt.width = 512
prompt.height = 512
prompt.size = 512
result = next(imagine(prompt))
img_path = f"{filename_base_for_outputs}.png"
@ -365,8 +357,7 @@ def test_large_image(filename_base_for_outputs):
prompt_text = "a stormy ocean. oil painting"
prompt = ImaginePrompt(
prompt_text,
width=1920,
height=1080,
size="1080p",
steps=30,
seed=0,
)

@ -72,8 +72,7 @@ def test_edit_demo(monkeypatch):
ImaginePrompt(
"",
steps=1,
width=256,
height=256,
size=256,
# model="empty",
)
]
@ -89,7 +88,7 @@ def test_edit_demo(monkeypatch):
f"{TESTS_FOLDER}/test_output",
],
)
assert result.exit_code == 0
assert result.exit_code == 0, result.stdout
def test_upscale(monkeypatch):

@ -1,6 +1,6 @@
from imaginairy import config
from imaginairy.samplers import SAMPLER_TYPE_OPTIONS
def test_sampler_options():
assert set(config.SAMPLER_TYPE_OPTIONS) == set(SAMPLER_TYPE_OPTIONS)
# from imaginairy import config
# from imaginairy.samplers import SOLVER_TYPE_OPTIONS
#
#
# def test_sampler_options():
# assert set(config.SOLVER_TYPE_NAMES) == set(SOLVER_TYPE_OPTIONS)

@ -58,7 +58,7 @@ def test_clip_masking(filename_base_for_outputs):
upscale=False,
fix_faces=True,
seed=42,
# sampler_type="plms",
# solver_type="plms",
)
result = next(imagine(prompt))

@ -1,24 +1,25 @@
from imaginairy import config
from imaginairy.model_manager import resolve_model_paths
from imaginairy.model_manager import resolve_model_weights_config
def test_resolved_paths():
"""Test that the resolved model path is correct."""
(
model_metadata,
weights_path,
config_path,
control_weights_path,
) = resolve_model_paths()
assert model_metadata.short_name == config.DEFAULT_MODEL
assert model_metadata.config_path == config_path
default_config_path = config_path
model_weights_config = resolve_model_weights_config(config.DEFAULT_MODEL_WEIGHTS)
assert config.DEFAULT_MODEL_WEIGHTS.lower() in model_weights_config.aliases
assert (
config.DEFAULT_MODEL_ARCHITECTURE in model_weights_config.architecture.aliases
)
(
model_metadata,
weights_path,
config_path,
control_weights_path,
) = resolve_model_paths(weights_path="foo.ckpt")
assert weights_path == "foo.ckpt"
assert config_path == default_config_path
model_weights_config = resolve_model_weights_config(
model_weights="foo.ckpt",
default_model_architecture="sd15",
)
print(model_weights_config)
assert model_weights_config.aliases == []
assert "sd15" in model_weights_config.architecture.aliases
model_weights_config = resolve_model_weights_config(
model_weights="foo.ckpt", default_model_architecture="sd15", for_inpainting=True
)
assert model_weights_config.aliases == []
assert "sd15-inpaint" in model_weights_config.architecture.aliases

@ -2,7 +2,7 @@ import pytest
from pydantic import ValidationError
from imaginairy import LazyLoadingImage
from imaginairy.schema import ControlNetInput
from imaginairy.schema import ControlInput
from tests import TESTS_FOLDER
@ -12,29 +12,29 @@ def _lazy_img():
def test_controlnetinput_basic(lazy_img):
ControlNetInput(mode="canny", image=lazy_img)
ControlNetInput(mode="canny", image_raw=lazy_img)
ControlInput(mode="canny", image=lazy_img)
ControlInput(mode="canny", image_raw=lazy_img)
def test_controlnetinput_invalid_mode(lazy_img):
with pytest.raises(ValueError, match=r".*Invalid controlnet mode.*"):
ControlNetInput(mode="pizza", image=lazy_img)
ControlInput(mode="pizza", image=lazy_img)
def test_controlnetinput_both_images(lazy_img):
with pytest.raises(ValueError, match=r".*cannot specify both.*"):
ControlNetInput(mode="canny", image=lazy_img, image_raw=lazy_img)
ControlInput(mode="canny", image=lazy_img, image_raw=lazy_img)
def test_controlnetinput_filepath_input(lazy_img):
"""Test that we accept filepaths here."""
c = ControlNetInput(mode="canny", image=f"{TESTS_FOLDER}/data/red.png")
c = ControlInput(mode="canny", image=f"{TESTS_FOLDER}/data/red.png")
c.image.convert("RGB")
c = ControlNetInput(mode="canny", image_raw=f"{TESTS_FOLDER}/data/red.png")
c = ControlInput(mode="canny", image_raw=f"{TESTS_FOLDER}/data/red.png")
c.image_raw.convert("RGB")
def test_controlnetinput_big(lazy_img):
ControlNetInput(mode="canny", strength=2)
ControlInput(mode="canny", strength=2)
with pytest.raises(ValidationError, match=r".*float_type.*"):
ControlNetInput(mode="canny", strength=2**2048)
ControlInput(mode="canny", strength=2**2048)

@ -2,13 +2,21 @@ import pytest
from pydantic import ValidationError
from imaginairy import LazyLoadingImage, config
from imaginairy.schema import ControlNetInput, ImaginePrompt, WeightedPrompt
from imaginairy.schema import ControlInput, ImaginePrompt, WeightedPrompt
from imaginairy.utils.data_distorter import DataDistorter
from tests import TESTS_FOLDER
def test_imagine_prompt_default():
prompt = ImaginePrompt()
assert prompt.prompt == []
assert prompt.negative_prompt == [
WeightedPrompt(text=config.DEFAULT_NEGATIVE_PROMPT)
]
def test_imagine_prompt_has_default_negative():
prompt = ImaginePrompt("fruit salad", model="foobar")
prompt = ImaginePrompt("fruit salad", model_weights="foobar")
assert isinstance(prompt.prompt[0], WeightedPrompt)
assert isinstance(prompt.negative_prompt[0], WeightedPrompt)
@ -21,10 +29,10 @@ def test_imagine_prompt_custom_negative_prompt():
def test_imagine_prompt_model_specific_negative_prompt():
prompt = ImaginePrompt("fruit salad", model="openjourney-v1")
prompt = ImaginePrompt("fruit salad", model_weights="openjourney-v1")
assert isinstance(prompt.prompt[0], WeightedPrompt)
assert isinstance(prompt.negative_prompt[0], WeightedPrompt)
assert prompt.negative_prompt[0].text == ""
assert prompt.negative_prompt[0].text == "poor quality"
def test_imagine_prompt_weighted_prompts():
@ -84,7 +92,7 @@ def test_imagine_prompt_control_inputs():
prompt = ImaginePrompt(
"fruit",
control_inputs=[
ControlNetInput(mode="depth", image=img),
ControlInput(mode="depth", image=img),
],
)
prompt.control_inputs[0].image.convert("RGB")
@ -98,7 +106,7 @@ def test_imagine_prompt_control_inputs():
"fruit",
init_image=img,
control_inputs=[
ControlNetInput(mode="depth"),
ControlInput(mode="depth"),
],
)
assert prompt.control_inputs[0].image is not None
@ -107,7 +115,7 @@ def test_imagine_prompt_control_inputs():
prompt = ImaginePrompt(
"fruit",
control_inputs=[
ControlNetInput(mode="depth"),
ControlInput(mode="depth"),
],
)
assert prompt.control_inputs[0].image is None
@ -136,8 +144,8 @@ def test_imagine_prompt_mask_params():
def test_imagine_prompt_default_model():
prompt = ImaginePrompt("fruit", model=None)
assert prompt.model == config.DEFAULT_MODEL
prompt = ImaginePrompt("fruit", model_weights=None)
assert prompt.model_weights == config.DEFAULT_MODEL_WEIGHTS
def test_imagine_prompt_default_negative():
@ -152,7 +160,7 @@ def test_imagine_prompt_fix_faces_fidelity():
def test_imagine_prompt_init_strength_zero():
lazy_img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png")
prompt = ImaginePrompt(
"fruit", control_inputs=[ControlNetInput(mode="depth", image=lazy_img)]
"fruit", control_inputs=[ControlInput(mode="depth", image=lazy_img)]
)
assert prompt.init_image_strength == 0.0
@ -171,12 +179,12 @@ def test_distorted_prompts():
init_image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"),
init_image_strength=0.5,
control_inputs=[
ControlNetInput(
ControlInput(
mode="details",
image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"),
strength=2,
),
ControlNetInput(
ControlInput(
mode="depth",
image_raw=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"),
strength=3,
@ -187,13 +195,11 @@ def test_distorted_prompts():
mask_mode="replace",
mask_modify_original=False,
outpaint="all5,up0,down20",
model=config.DEFAULT_MODEL,
model_config_path=None,
sampler_type=config.DEFAULT_SAMPLER,
model_weights=config.DEFAULT_MODEL_WEIGHTS,
solver_type=config.DEFAULT_SOLVER,
seed=42,
steps=10,
height=256,
width=256,
size=256,
upscale=True,
fix_faces=True,
fix_faces_fidelity=0.7,

@ -1,6 +1,6 @@
import pytest
from imaginairy.utils.named_resolutions import get_named_resolution
from imaginairy.utils.named_resolutions import normalize_image_size
valid_cases = [
("HD", (1280, 720)),
@ -12,11 +12,25 @@ valid_cases = [
("1920x1080", (1920, 1080)),
("1280x720", (1280, 720)),
("1024x768", (1024, 768)),
("1024,768", (1024, 768)),
("1024*768", (1024, 768)),
("1024, 768", (1024, 768)),
("800", (800, 800)),
("1024", (1024, 1024)),
("1080p", (1920, 1080)),
("1080P", (1920, 1080)),
(512, (512, 512)),
((512, 512), (512, 512)),
("1x1", (1, 1)),
]
invalid_cases = [
None,
3.14,
(3.14, 3.14),
"",
" ",
"abc",
"-512",
"1920xABC",
"1920x1080x1234",
"x1920",
@ -30,10 +44,10 @@ invalid_cases = [
@pytest.mark.parametrize(("named_resolution", "expected"), valid_cases)
def test_named_resolutions(named_resolution, expected):
assert get_named_resolution(named_resolution) == expected
assert normalize_image_size(named_resolution) == expected
@pytest.mark.parametrize("named_resolution", invalid_cases)
def test_invalid_inputs(named_resolution):
with pytest.raises(ValueError, match="Unknown resolution"):
get_named_resolution(named_resolution)
with pytest.raises(ValueError, match="Invalid resolution"):
normalize_image_size(named_resolution)
