diff --git a/README.md b/README.md index 71b901c..166b56c 100644 --- a/README.md +++ b/README.md @@ -787,7 +787,7 @@ Use with `--model SD-2.1` or `--model SD-2.0-v` **6.1.0** - feature: use different default steps and image sizes depending on sampler and model selected - fix: #110 use proper version in image metadata -- refactor: samplers all have their own class that inherits from ImageSampler +- refactor: solvers all have their own class that inherits from ImageSolver - feature: 🎉🎉🎉 Stable Diffusion 2.0 - `--model SD-2.0` to use (it makes worse images than 1.5 though...) - Tested on macOS and Linux diff --git a/docs/todo.md b/docs/todo.md index e95bd8f..bf59996 100644 --- a/docs/todo.md +++ b/docs/todo.md @@ -11,10 +11,11 @@ - allow selection of output video format - chain multiple operations together imggen => videogen - make sure terminal output on windows doesn't suck - - add karras schedule to refiners + - add interface for loading diffusers weights - add method to show cache size - add method to clear model cache - add method to clear cached items not recently used (does diffusers have one?) + - https://huggingface.co/playgroundai/playground-v2-1024px-aesthetic ### Old Todo diff --git a/imaginairy/api.py b/imaginairy/api.py index b7f2be6..7237c17 100755 --- a/imaginairy/api.py +++ b/imaginairy/api.py @@ -1,35 +1,37 @@ import logging import os import re +from typing import TYPE_CHECKING, Callable -from imaginairy.schema import ControlNetInput, SafetyMode +if TYPE_CHECKING: + from imaginairy.schema import ImaginePrompt logger = logging.getLogger(__name__) # leave undocumented. I'd ask that no one publicize this flag. Just want a # slight barrier to entry. Please don't use this in any way that's gonna cause # the media or politicians to freak out about AI... 
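The `TYPE_CHECKING` guard added above exists so `ImaginePrompt` can be referenced in annotations without importing the schema module at runtime. A minimal sketch of the pattern as used on the signatures below (the function body here is illustrative only, not the real implementation):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated only by static type checkers, never at runtime
    from imaginairy.schema import ImaginePrompt


def imagine_image_files(prompts: "list[ImaginePrompt] | ImaginePrompt", outdir: str) -> None:
    # quoted annotations are resolved lazily, so no runtime import of the schema is needed
    ...
```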
-IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.STRICT) +IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", "strict") if IMAGINAIRY_SAFETY_MODE in {"disabled", "classify"}: - IMAGINAIRY_SAFETY_MODE = SafetyMode.RELAXED + IMAGINAIRY_SAFETY_MODE = "relaxed" elif IMAGINAIRY_SAFETY_MODE == "filter": - IMAGINAIRY_SAFETY_MODE = SafetyMode.STRICT + IMAGINAIRY_SAFETY_MODE = "strict" # we put this in the global scope so it can be used in the interactive shell _most_recent_result = None def imagine_image_files( - prompts, - outdir, - precision="autocast", - record_step_images=False, - output_file_extension="jpg", - print_caption=False, - make_gif=False, - make_compare_gif=False, - return_filename_type="generated", - videogen=False, + prompts: "list[ImaginePrompt] | ImaginePrompt", + outdir: str, + precision: str = "autocast", + record_step_images: bool = False, + output_file_extension: str = "jpg", + print_caption: bool = False, + make_gif: bool = False, + make_compare_gif: bool = False, + return_filename_type: str = "generated", + videogen: bool = False, ): from PIL import ImageDraw @@ -46,6 +48,9 @@ def imagine_image_files( if output_file_extension not in {"jpg", "png"}: raise ValueError("Must output a png or jpg") + if not isinstance(prompts, list): + prompts = [prompts] + def _record_step(img, description, image_count, step_count, prompt): steps_path = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}") os.makedirs(steps_path, exist_ok=True) @@ -74,7 +79,7 @@ def imagine_image_files( if prompt.init_image: img_str = f"_img2img-{prompt.init_image_strength}" basefilename = ( - f"{base_count:06}_{prompt.seed}_{prompt.sampler_type.replace('_', '')}{prompt.steps}_" + f"{base_count:06}_{prompt.seed}_{prompt.solver_type.replace('_', '')}{prompt.steps}_" f"PS{prompt.prompt_strength}{img_str}_{prompt_normalized(prompt.prompt_text)}" ) @@ -139,15 +144,15 @@ def imagine_image_files( def imagine( - prompts, - precision="autocast", - debug_img_callback=None, - progress_img_callback=None, - progress_img_interval_steps=3, + prompts: "list[ImaginePrompt] | str | ImaginePrompt", + precision: str = "autocast", + debug_img_callback: Callable | None = None, + progress_img_callback: Callable | None = None, + progress_img_interval_steps: int = 3, progress_img_interval_min_s=0.1, half_mode=None, - add_caption=False, - unsafe_retry_count=1, + add_caption: bool = False, + unsafe_retry_count: int = 1, ): import torch.nn @@ -209,7 +214,7 @@ def imagine( def _generate_single_image_compvis( - prompt, + prompt: "ImaginePrompt", debug_img_callback=None, progress_img_callback=None, progress_img_interval_steps=3, @@ -248,9 +253,9 @@ def _generate_single_image_compvis( from imaginairy.modules.midas.api import torch_image_to_depth_map from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint from imaginairy.safety import create_safety_score - from imaginairy.samplers import SAMPLER_LOOKUP + from imaginairy.samplers import SOLVER_LOOKUP from imaginairy.samplers.editing import CFGEditingDenoiser - from imaginairy.schema import ImaginePrompt, ImagineResult + from imaginairy.schema import ControlInput, ImagineResult, MaskMode from imaginairy.utils import get_device, randn_seeded latent_channels = 4 @@ -326,8 +331,8 @@ def _generate_single_image_compvis( prompt.height // downsampling_factor, prompt.width // downsampling_factor, ] - SamplerCls = SAMPLER_LOOKUP[prompt.sampler_type.lower()] - sampler = SamplerCls(model) + SolverCls = 
SOLVER_LOOKUP[prompt.solver_type.lower()] + solver = SolverCls(model) mask_latent = mask_image = mask_image_orig = mask_grayscale = None t_enc = init_latent = control_image = None starting_image = None @@ -385,7 +390,7 @@ def _generate_single_image_compvis( log_img(mask_image, "init mask") - if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE: + if prompt.mask_mode == MaskMode.REPLACE: mask_image = ImageOps.invert(mask_image) mask_image_orig = mask_image @@ -396,7 +401,7 @@ def _generate_single_image_compvis( if inpaint_method == "controlnet": result_images["control-inpaint"] = mask_image control_inputs.append( - ControlNetInput(mode="inpaint", image=mask_image) + ControlInput(mode="inpaint", image=mask_image) ) seed_everything(prompt.seed) @@ -543,7 +548,7 @@ def _generate_single_image_compvis( prompt=prompt, target_height=init_image.height, target_width=init_image.width, - cutoff=get_model_default_image_size(prompt.model), + cutoff=get_model_default_image_size(prompt.model_architecture), ) else: comp_image = _generate_composition_image( @@ -563,7 +568,7 @@ def _generate_single_image_compvis( model.encode_first_stage(comp_image_t) ) with lc.timing("sampling"): - samples = sampler.sample( + samples = solver.sample( num_steps=prompt.steps, positive_conditioning=positive_conditioning, neutral_conditioning=neutral_conditioning, @@ -711,8 +716,10 @@ def _generate_composition_image( composition_prompt = prompt.full_copy( deep=True, update={ - "width": int(prompt.width * shrink_scale), - "height": int(prompt.height * shrink_scale), + "size": ( + int(prompt.width * shrink_scale), + int(prompt.height * shrink_scale), + ), "steps": None, "upscale": False, "fix_faces": False, diff --git a/imaginairy/api_refiners.py b/imaginairy/api_refiners.py index 5731a23..76124bb 100644 --- a/imaginairy/api_refiners.py +++ b/imaginairy/api_refiners.py @@ -1,15 +1,16 @@ import logging from typing import List, Optional -from imaginairy import WeightedPrompt -from imaginairy.config import CONTROLNET_CONFIG_SHORTCUTS +from imaginairy import ImaginePrompt, WeightedPrompt +from imaginairy.config import CONTROL_CONFIG_SHORTCUTS from imaginairy.model_manager import load_controlnet_adapter +from imaginairy.schema import MaskMode logger = logging.getLogger(__name__) def _generate_single_image( - prompt, + prompt: ImaginePrompt, debug_img_callback=None, progress_img_callback=None, progress_img_interval_steps=3, @@ -55,7 +56,7 @@ def _generate_single_image( ) from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint from imaginairy.safety import create_safety_score - from imaginairy.samplers import SamplerName + from imaginairy.samplers import SolverName from imaginairy.schema import ImaginePrompt, ImagineResult from imaginairy.utils import get_device, randn_seeded @@ -76,8 +77,8 @@ def _generate_single_image( control_modes = [c.mode for c in prompt.control_inputs] sd = get_diffusion_model_refiners( - weights_location=prompt.model, - config_path=prompt.model_config_path, + weights_location=prompt.model_weights, + model_architecture=prompt.model_architecture, control_weights_locations=tuple(control_modes), dtype=dtype, for_inpainting=for_inpainting and inpaint_method == "finetune", @@ -90,7 +91,7 @@ def _generate_single_image( mask_image = None mask_image_orig = None - prompt = prompt.make_concrete_copy() + prompt: ImaginePrompt = prompt.make_concrete_copy() def latent_logger(latents): progress_latents.append(latents) @@ -171,7 +172,7 @@ def _generate_single_image( log_img(mask_image, "init mask") - 
if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE: + if prompt.mask_mode == MaskMode.REPLACE: mask_image = ImageOps.invert(mask_image) mask_image_orig = mask_image @@ -182,7 +183,7 @@ def _generate_single_image( # if inpaint_method == "controlnet": # result_images["control-inpaint"] = mask_image # control_inputs.append( - # ControlNetInput(mode="inpaint", image=mask_image) + # ControlInput(mode="inpaint", image=mask_image) # ) seed_everything(prompt.seed) @@ -194,7 +195,6 @@ def _generate_single_image( controlnets = [] if control_modes: - control_strengths = [] from imaginairy.img_processors.control_modes import CONTROL_MODES for control_input in control_inputs: @@ -231,10 +231,10 @@ def _generate_single_image( log_img(control_image_disp, "control_image") if len(control_image_t.shape) == 3: - raise RuntimeError("Control image must be 4D") + raise ValueError("Control image must be 4D") if control_image_t.shape[1] != 3: - raise RuntimeError("Control image must have 3 channels") + raise ValueError("Control image must have 3 channels") if ( control_input.mode != "inpaint" @@ -242,21 +242,20 @@ def _generate_single_image( or control_image_t.max() > 1 ): msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}" - raise RuntimeError(msg) + raise ValueError(msg) if control_image_t.max() == control_image_t.min(): msg = f"No control signal found in control image {control_input.mode}." - raise RuntimeError(msg) - - control_strengths.append(control_input.strength) + raise ValueError(msg) - control_weights_path = CONTROLNET_CONFIG_SHORTCUTS.get( - control_input.mode, None - ).weights_url + control_config = CONTROL_CONFIG_SHORTCUTS.get(control_input.mode, None) + if not control_config: + msg = f"Unknown control mode: {control_input.mode}" + raise ValueError(msg) controlnet = load_controlnet_adapter( name=control_input.mode, - control_weights_location=control_weights_path, + control_weights_location=control_config.weights_location, target_unet=sd.unet, scale=control_input.strength, ) @@ -268,7 +267,7 @@ def _generate_single_image( prompt=prompt, target_height=init_image.height, target_width=init_image.width, - cutoff=get_model_default_image_size(prompt.model), + cutoff=get_model_default_image_size(prompt.model_architecture), dtype=dtype, ) else: @@ -276,7 +275,7 @@ def _generate_single_image( prompt=prompt, target_height=prompt.height, target_width=prompt.width, - cutoff=get_model_default_image_size(prompt.model), + cutoff=get_model_default_image_size(prompt.model_architecture), dtype=dtype, ) if comp_image is not None: @@ -296,12 +295,12 @@ def _generate_single_image( control_image_t.to(device=sd.device, dtype=sd.dtype) ) controlnet.inject() - if prompt.sampler_type.lower() == SamplerName.K_DPMPP_2M: + if prompt.solver_type.lower() == SolverName.DPMPP: sd.scheduler = DPMSolver(num_inference_steps=prompt.steps) - elif prompt.sampler_type.lower() == SamplerName.DDIM: + elif prompt.solver_type.lower() == SolverName.DDIM: sd.scheduler = DDIM(num_inference_steps=prompt.steps) else: - msg = f"Unknown sampler type: {prompt.sampler_type}" + msg = f"Unknown solver type: {prompt.solver_type}" raise ValueError(msg) sd.scheduler.to(device=sd.device, dtype=sd.dtype) sd.set_num_inference_steps(prompt.steps) @@ -414,18 +413,24 @@ def _generate_single_image( caption_text = prompt.caption_text.format(prompt=prompt.prompt_text) add_caption_to_image(gen_img, caption_text) + # todo: do something smarter + result_images.update( + { + "upscaled": upscaled_img, + 
"modified_original": rebuilt_orig_img, + "mask_binary": mask_image_orig, + "mask_grayscale": mask_grayscale, + } + ) + result = ImagineResult( img=gen_img, prompt=prompt, - upscaled_img=upscaled_img, is_nsfw=safety_score.is_nsfw, safety_score=safety_score, - modified_original=rebuilt_orig_img, - mask_binary=mask_image_orig, - mask_grayscale=mask_grayscale, result_images=result_images, - timings={}, - progress_latents=[], + timings={}, # todo + progress_latents=[], # todo ) _most_recent_result = result @@ -441,6 +446,9 @@ def _generate_single_image( def _prompts_to_embeddings(prompts, text_encoder): import torch + if not prompts: + prompts = [WeightedPrompt(text="")] + total_weight = sum(wp.weight for wp in prompts) if str(text_encoder.device) == "cpu": text_encoder = text_encoder.to(dtype=torch.float32) diff --git a/imaginairy/cli/edit.py b/imaginairy/cli/edit.py index 62e7b9e..7befabe 100644 --- a/imaginairy/cli/edit.py +++ b/imaginairy/cli/edit.py @@ -31,7 +31,7 @@ remove_option(edit_options, "allow_compose_phase") @click.option( "--model-weights-path", "--model", - help=f"Model to use. Should be one of {', '.join(config.MODEL_SHORT_NAMES)}, or a path to custom weights.", + help=f"Model to use. Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.", show_default=True, default="SD-1.5", ) @@ -53,15 +53,13 @@ def edit_cmd( outdir, output_file_extension, repeats, - height, - width, size, steps, seed, upscale, fix_faces, fix_faces_fidelity, - sampler_type, + solver, log_level, quiet, show_work, @@ -76,7 +74,7 @@ def edit_cmd( caption, precision, model_weights_path, - model_config_path, + model_architecture, prompt_library_path, version, make_gif, @@ -95,11 +93,11 @@ def edit_cmd( Same as calling `aimg imagine --model edit --init-image my-dog.jpg --init-image-strength 1` except this command can batch edit images. """ - from imaginairy.schema import ControlNetInput + from imaginairy.schema import ControlInput allow_compose_phase = False control_inputs = [ - ControlNetInput( + ControlInput( image=None, image_raw=None, mode="edit", @@ -116,15 +114,13 @@ def edit_cmd( outdir, output_file_extension, repeats, - height, - width, size, steps, seed, upscale, fix_faces, fix_faces_fidelity, - sampler_type, + solver, log_level, quiet, show_work, @@ -140,7 +136,7 @@ def edit_cmd( caption, precision, model_weights_path, - model_config_path, + model_architecture, prompt_library_path, version, make_gif, diff --git a/imaginairy/cli/imagine.py b/imaginairy/cli/imagine.py index ada0d57..54ca195 100644 --- a/imaginairy/cli/imagine.py +++ b/imaginairy/cli/imagine.py @@ -77,15 +77,13 @@ def imagine_cmd( outdir, output_file_extension, repeats, - height, - width, size, steps, seed, upscale, fix_faces, fix_faces_fidelity, - sampler_type, + solver, log_level, quiet, show_work, @@ -101,7 +99,7 @@ def imagine_cmd( caption, precision, model_weights_path, - model_config_path, + model_architecture, prompt_library_path, version, make_gif, @@ -120,7 +118,7 @@ def imagine_cmd( Can be invoked via either `aimg imagine` or just `imagine`. 
""" - from imaginairy.schema import ControlNetInput, LazyLoadingImage + from imaginairy.schema import ControlInput, LazyLoadingImage # hacky method of getting order of control images (mixing raw and normal images) control_images = [ @@ -128,13 +126,18 @@ def imagine_cmd( for o, path in ImagineColorsCommand._option_order if o.name in ("control_image", "control_image_raw") ] + control_strengths = [ + strength + for o, strength in ImagineColorsCommand._option_order + if o.name == "control_strength" + ] + control_inputs = [] if control_mode: for i, cm in enumerate(control_mode): - try: - option = control_images[i] - except IndexError: - option = None + option = index_default(control_images, i, None) + control_strength = index_default(control_strengths, i, 1.0) + if option is None: control_image = None control_image_raw = None @@ -149,10 +152,10 @@ def imagine_cmd( if control_image_raw and control_image_raw.startswith("http"): control_image_raw = LazyLoadingImage(url=control_image_raw) control_inputs.append( - ControlNetInput( + ControlInput( image=control_image, image_raw=control_image_raw, - strength=float(control_strength[i]), + strength=float(control_strength), mode=cm, ) ) @@ -167,15 +170,13 @@ def imagine_cmd( outdir, output_file_extension, repeats, - height, - width, size, steps, seed, upscale, fix_faces, fix_faces_fidelity, - sampler_type, + solver, log_level, quiet, show_work, @@ -191,7 +192,7 @@ def imagine_cmd( caption, precision, model_weights_path, - model_config_path, + model_architecture, prompt_library_path, version, make_gif, @@ -204,5 +205,12 @@ def imagine_cmd( ) +def index_default(items, index, default): + try: + return items[index] + except IndexError: + return default + + if __name__ == "__main__": imagine_cmd() diff --git a/imaginairy/cli/main.py b/imaginairy/cli/main.py index 3722f5e..a866f5d 100644 --- a/imaginairy/cli/main.py +++ b/imaginairy/cli/main.py @@ -45,8 +45,6 @@ aimg.add_command(describe_cmd, name="describe") aimg.add_command(edit_cmd, name="edit") aimg.add_command(edit_demo_cmd, name="edit-demo") aimg.add_command(imagine_cmd, name="imagine") -# aimg.add_command(prep_images_cmd, name="prep-images") -# aimg.add_command(prune_ckpt_cmd, name="prune-ckpt") aimg.add_command(upscale_cmd, name="upscale") aimg.add_command(run_server_cmd, name="server") aimg.add_command(videogen_cmd, name="videogen") @@ -85,14 +83,14 @@ def model_list_cmd(): from imaginairy import config print(f"{'ALIAS': <10} {'NAME': <18} {'DESCRIPTION'}") - for model_config in config.MODEL_CONFIGS: + for model_config in config.MODEL_WEIGHT_CONFIGS: print( f"{model_config.alias: <10} {model_config.short_name: <18} {model_config.description}" ) print("\nCONTROL MODES:") print(f"{'ALIAS': <10} {'NAME': <18} {'CONTROL TYPE'}") - for control_mode in config.CONTROLNET_CONFIGS: + for control_mode in config.CONTROL_CONFIGS: print( f"{control_mode.alias: <10} {control_mode.short_name: <18} {control_mode.control_type}" ) diff --git a/imaginairy/cli/shared.py b/imaginairy/cli/shared.py index 62dae7e..ad0a559 100644 --- a/imaginairy/cli/shared.py +++ b/imaginairy/cli/shared.py @@ -34,15 +34,13 @@ def _imagine_cmd( outdir, output_file_extension, repeats, - height, - width, size, steps, seed, upscale, fix_faces, fix_faces_fidelity, - sampler_type, + solver, log_level, quiet, show_work, @@ -58,7 +56,7 @@ def _imagine_cmd( caption, precision, model_weights_path, - model_config_path, + model_architecture, prompt_library_path, version=False, make_gif=False, @@ -96,15 +94,6 @@ def _imagine_cmd( 
configure_logging(log_level) - if (height is not None or width is not None) and size is not None: - msg = "You cannot specify both --size and --height/--width. Please choose one." - raise ValueError(msg) - - if size is not None: - from imaginairy.utils.named_resolutions import get_named_resolution - - width, height = get_named_resolution(size) - init_images = [init_image] if isinstance(init_image, str) else init_image from imaginairy.utils import glob_expand_paths @@ -171,10 +160,9 @@ def _imagine_cmd( init_image_strength=init_image_strength, control_inputs=control_inputs, seed=seed, - sampler_type=sampler_type, + solver_type=solver, steps=steps, - height=height, - width=width, + size=size, mask_image=mask_image, mask_prompt=mask_prompt, mask_mode=mask_mode, @@ -185,8 +173,8 @@ fix_faces_fidelity=fix_faces_fidelity, tile_mode=_tile_mode, allow_compose_phase=allow_compose_phase, - model=model_weights_path, - model_config_path=model_config_path, + model_weights=model_weights_path, + model_architecture=model_architecture, caption_text=caption_text, ) from imaginairy.prompt_schedules import ( @@ -318,28 +306,12 @@ common_options = [ type=int, help="How many times to repeat the renders. If you provide two prompts and --repeat=3 then six images will be generated.", ), - click.option( - "-h", - "--height", - default=None, - show_default=True, - type=int, - help="Image height. Should be multiple of 8.", - ), - click.option( - "-w", - "--width", - default=None, - show_default=True, - type=int, - help="Image width. Should be multiple of 8.", - ), click.option( "--size", default=None, show_default=True, type=str, - help="Image size as a string. Can be a named size or WIDTHxHEIGHT format. Should be multiple of 8. Examples: 512x512, 4k, UHD, 8k, ", + help="Image size as a string. Can be a named size, WIDTHxHEIGHT, or single integer. Should be a multiple of 8. Examples: 512x512, 4k, UHD, 8k, 512, 1080p", ), click.option( "--steps", @@ -363,18 +335,18 @@ help="How faithful to the original should face enhancement be. 1 = best fidelity, 0 = best looking face.", ), click.option( - "--sampler-type", + "--solver", "--sampler", - default=config.DEFAULT_SAMPLER, + default=config.DEFAULT_SOLVER, show_default=True, - type=click.Choice(config.SAMPLER_TYPE_OPTIONS), - help="What sampling strategy to use.", + type=click.Choice(config.SOLVER_TYPE_NAMES, case_sensitive=False), + help="Solver algorithm to generate the image with. (AKA 'Sampler' or 'Scheduler' in other libraries).", ), click.option( "--log-level", default="INFO", show_default=True, - type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"]), + type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"], case_sensitive=False), help="What level of logs to show.", ), click.option( @@ -429,7 +401,7 @@ "--mask-mode", default="replace", show_default=True, - type=click.Choice(["keep", "replace"]), + type=click.Choice(["keep", "replace"], case_sensitive=False), help="Should we replace the masked area or keep it?", ), click.option( @@ -458,20 +430,20 @@ click.option( "--precision", help="Evaluate at this precision.", - type=click.Choice(["full", "autocast"]), + type=click.Choice(["full", "autocast"], case_sensitive=False), default="autocast", show_default=True, ), click.option( "--model-weights-path", "--model", - help=f"Model to use. Should be one of {', '.join(config.MODEL_SHORT_NAMES)}, or a path to custom weights.", + help=f"Model to use. 
Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.", show_default=True, - default=config.DEFAULT_MODEL, + default=config.DEFAULT_MODEL_WEIGHTS, ), click.option( - "--model-config-path", - help="Model config file to use. If a model name is specified, the appropriate config will be used.", + "--model-architecture", + help="Model architecture. When specifying custom weights the model architecture must be specified. (sd15, sdxl, etc).", show_default=True, default=None, ), diff --git a/imaginairy/cli/train.py b/imaginairy/cli/train.py index ec86d17..b952f14 100644 --- a/imaginairy/cli/train.py +++ b/imaginairy/cli/train.py @@ -44,9 +44,9 @@ logger = logging.getLogger(__name__) "--model-weights-path", "--model", "model", - help=f"Model to use. Should be one of {', '.join(config.MODEL_SHORT_NAMES)}, or a path to custom weights.", + help=f"Model to use. Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.", show_default=True, - default=config.DEFAULT_MODEL, + default=config.DEFAULT_MODEL_WEIGHTS, ) @click.option( "--person", diff --git a/imaginairy/colorize.py b/imaginairy/colorize.py index c3ec536..ca0ce33 100644 --- a/imaginairy/colorize.py +++ b/imaginairy/colorize.py @@ -4,7 +4,7 @@ from PIL import Image, ImageEnhance, ImageStat from imaginairy import ImaginePrompt, imagine from imaginairy.enhancers.describe_image_blip import generate_caption -from imaginairy.schema import ControlNetInput +from imaginairy.schema import ControlInput logger = logging.getLogger(__name__) @@ -23,7 +23,7 @@ def colorize_img(img, max_width=1024, max_height=1024, caption=None): caption = caption.replace(" old ", " ") logger.info(caption) control_inputs = [ - ControlNetInput(mode="colorize", image=img, strength=2), + ControlInput(mode="colorize", image=img, strength=2), ] prompt_add = ". 
color photo, sharp-focus, highly detailed, intricate, Canon 5D" prompt = ImaginePrompt( diff --git a/imaginairy/config.py b/imaginairy/config.py index 3e61fc5..02d76cf 100644 --- a/imaginairy/config.py +++ b/imaginairy/config.py @@ -1,7 +1,9 @@ from dataclasses import dataclass +from typing import Any, List -DEFAULT_MODEL = "SD-1.5" -DEFAULT_SAMPLER = "ddim" +DEFAULT_MODEL_WEIGHTS = "sd15" +DEFAULT_MODEL_ARCHITECTURE = "sd15" +DEFAULT_SOLVER = "ddim" DEFAULT_NEGATIVE_PROMPT = ( "Ugly, duplication, duplicates, mutilation, deformed, mutilated, mutation, twisted body, disfigured, bad anatomy, " @@ -12,229 +14,306 @@ DEFAULT_NEGATIVE_PROMPT = ( "grainy, blurred, blurry, writing, calligraphy, signature, text, watermark, bad art," ) -SPLITMEM_ENABLED = False +midas_url = "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt" @dataclass -class ModelConfig: - description: str - short_name: str - config_path: str - weights_url: str - default_image_size: int - forced_attn_precision: str = "default" - default_negative_prompt: str = DEFAULT_NEGATIVE_PROMPT - alias: str = None +class ModelArchitecture: + name: str + aliases: List[str] + output_modality: str + defaults: dict[str, Any] + config_path: str | None = None -midas_url = "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt" +@dataclass +class ModelWeightsConfig: + name: str + aliases: List[str] + architecture: ModelArchitecture + defaults: dict[str, Any] + weights_location: str -MODEL_CONFIGS = [ - ModelConfig( - description="Stable Diffusion 1.5", - short_name="SD-1.5", + +MODEL_ARCHITECTURES = [ + ModelArchitecture( + name="Stable Diffusion 1.5", + aliases=["sd15", "sd-15", "sd1.5", "sd-1.5"], + output_modality="image", + defaults={"size": "512"}, config_path="configs/stable-diffusion-v1.yaml", - weights_url="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt", - default_image_size=512, - alias="sd15", ), - ModelConfig( - description="Stable Diffusion 1.5 - Inpainting", - short_name="SD-1.5-inpaint", + ModelArchitecture( + name="Stable Diffusion 1.5 - Inpainting", + aliases=[ + "sd15inpaint", + "sd15-inpaint", + "sd-15-inpaint", + "sd1.5inpaint", + "sd1.5-inpaint", + "sd-1.5-inpaint", + ], + output_modality="image", + defaults={"size": "512"}, config_path="configs/stable-diffusion-v1-inpaint.yaml", - weights_url="https://huggingface.co/julienacquaviva/inpainting/resolve/2155ff7fe38b55f4c0d99c2f1ab9b561f8311ca7/sd-v1-5-inpainting.ckpt", - default_image_size=512, - alias="sd15in", - ), - # ModelConfig( - # description="Instruct Pix2Pix - Photo Editing", - # short_name="instruct-pix2pix", - # config_path="configs/instruct-pix2pix.yaml", - # weights_url="https://huggingface.co/imaginairy/instruct-pix2pix/resolve/ea0009b3d0d4888f410a40bd06d69516d0b5a577/instruct-pix2pix-00-22000-pruned.ckpt", - # default_image_size=512, - # default_negative_prompt="", - # alias="edit", - # ), - ModelConfig( - description="OpenJourney V1", - short_name="openjourney-v1", - config_path="configs/stable-diffusion-v1.yaml", - weights_url="https://huggingface.co/prompthero/openjourney/resolve/7428477dad893424c92f6ea1cc29d45f6d1448c1/mdjrny-v4.safetensors", - default_image_size=512, - default_negative_prompt="", - alias="oj1", - ), - ModelConfig( - description="OpenJourney V2", - short_name="openjourney-v2", - config_path="configs/stable-diffusion-v1.yaml", - 
weights_url="https://huggingface.co/prompthero/openjourney-v2/resolve/47257274a40e93dab7fbc0cd2cfd5f5704cfeb60/openjourney-v2.ckpt", - default_image_size=512, - default_negative_prompt="", - alias="oj2", - ), - ModelConfig( - description="OpenJourney V4", - short_name="openjourney-v4", - config_path="configs/stable-diffusion-v1.yaml", - weights_url="https://huggingface.co/prompthero/openjourney/resolve/e291118e93d5423dc88ac1ed93c02362b17d698f/mdjrny-v4.safetensors", - default_image_size=512, - default_negative_prompt="", - alias="oj4", + ), + ModelArchitecture( + name="Stable Diffusion XL", + aliases=["sdxl", "sd-xl"], + output_modality="image", + defaults={"size": "512"}, + ), + ModelArchitecture( + name="Stable Video Diffusion", + aliases=["svd", "stablevideo"], + output_modality="video", + defaults={"size": "1024x576"}, + config_path="configs/svd.yaml", + ), + ModelArchitecture( + name="Stable Video Diffusion - Image Decoder", + aliases=["svd-image-decoder", "svd-imdec"], + output_modality="video", + defaults={"size": "1024x576"}, + config_path="configs/svd_image_decoder.yaml", + ), + ModelArchitecture( + name="Stable Video Diffusion - XT", + aliases=["svd-xt", "svd25f", "svd-25f", "stablevideoxt", "svdxt"], + output_modality="video", + defaults={"size": "1024x576"}, + config_path="configs/svd_xt.yaml", + ), + ModelArchitecture( + name="Stable Video Diffusion - XT - Image Decoder", + aliases=[ + "svd-xt-image-decoder", + "svd-xt-imdec", + "svd-25f-imdec", + "svdxt-imdec", + "svdxtimdec", + "svd25fimdec", + "svdxtimdec", + ], + output_modality="video", + defaults={"size": "1024x576"}, + config_path="configs/svd_xt_image_decoder.yaml", ), ] +MODEL_ARCHITECTURE_LOOKUP = {} +for m in MODEL_ARCHITECTURES: + for a in m.aliases: + MODEL_ARCHITECTURE_LOOKUP[a] = m -video_models = [ - { - "short_name": "svd", - "description": "Stable Video Diffusion", - "default_frames": 14, - "default_steps": 25, - "config_path": "configs/svd.yaml", - "weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd.fp16.safetensors", - }, - { - "short_name": "svd_image_decoder", - "description": "Stable Video Diffusion - Image Decoder", - "default_frames": 14, - "default_steps": 25, - "config_path": "configs/svd_image_decoder.yaml", - "weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_image_decoder.fp16.safetensors", - }, - { - "short_name": "svd_xt", - "description": "Stable Video Diffusion - XT", - "default_frames": 25, - "default_steps": 30, - "config_path": "configs/svd_xt.yaml", - "weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt.fp16.safetensors", - }, - { - "short_name": "svd_xt_image_decoder", - "description": "Stable Video Diffusion - XT - Image Decoder", - "default_frames": 25, - "default_steps": 30, - "config_path": "configs/svd_xt_image_decoder.yaml", - "weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt_image_decoder.fp16.safetensors", - }, + +MODEL_WEIGHT_CONFIGS = [ + ModelWeightsConfig( + name="Stable Diffusion 1.5", + aliases=MODEL_ARCHITECTURE_LOOKUP["sd15"].aliases, + architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"], + defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT}, + 
weights_location="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt", + ), + ModelWeightsConfig( + name="Stable Diffusion 1.5 - Inpainting", + aliases=MODEL_ARCHITECTURE_LOOKUP["sd15inpaint"].aliases, + architecture=MODEL_ARCHITECTURE_LOOKUP["sd15inpaint"], + defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT}, + weights_location="https://huggingface.co/julienacquaviva/inpainting/resolve/2155ff7fe38b55f4c0d99c2f1ab9b561f8311ca7/sd-v1-5-inpainting.ckpt", + ), + ModelWeightsConfig( + name="OpenJourney V1", + aliases=["openjourney-v1", "oj1", "ojv1", "openjourney1"], + architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"], + defaults={"negative_prompt": "poor quality"}, + weights_location="https://huggingface.co/prompthero/openjourney/resolve/7428477dad893424c92f6ea1cc29d45f6d1448c1/mdjrny-v4.safetensors", + ), + ModelWeightsConfig( + name="OpenJourney V2", + aliases=["openjourney-v2", "oj2", "ojv2", "openjourney2"], + architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"], + weights_location="https://huggingface.co/prompthero/openjourney-v2/resolve/47257274a40e93dab7fbc0cd2cfd5f5704cfeb60/openjourney-v2.ckpt", + defaults={"negative_prompt": "poor quality"}, + ), + ModelWeightsConfig( + name="OpenJourney V4", + aliases=["openjourney-v4", "oj4", "ojv4", "openjourney4", "openjourney", "oj"], + architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"], + weights_location="https://huggingface.co/prompthero/openjourney/resolve/e291118e93d5423dc88ac1ed93c02362b17d698f/mdjrny-v4.safetensors", + defaults={"negative_prompt": "poor quality"}, + ), + # Video Weights + ModelWeightsConfig( + name="Stable Video Diffusion", + aliases=MODEL_ARCHITECTURE_LOOKUP["svd"].aliases, + architecture=MODEL_ARCHITECTURE_LOOKUP["svd"], + weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd.fp16.safetensors", + defaults={"frames": 14, "steps": 25}, + ), + ModelWeightsConfig( + name="Stable Video Diffusion - Image Decoder", + aliases=MODEL_ARCHITECTURE_LOOKUP["svd-image-decoder"].aliases, + architecture=MODEL_ARCHITECTURE_LOOKUP["svd-image-decoder"], + weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_image_decoder.fp16.safetensors", + defaults={"frames": 14, "steps": 25}, + ), + ModelWeightsConfig( + name="Stable Video Diffusion - XT", + aliases=MODEL_ARCHITECTURE_LOOKUP["svdxt"].aliases, + architecture=MODEL_ARCHITECTURE_LOOKUP["svdxt"], + weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt.fp16.safetensors", + defaults={"frames": 25, "steps": 30}, + ), + ModelWeightsConfig( + name="Stable Video Diffusion - XT - Image Decoder", + aliases=MODEL_ARCHITECTURE_LOOKUP["svd-xt-image-decoder"].aliases, + architecture=MODEL_ARCHITECTURE_LOOKUP["svd-xt-image-decoder"], + weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt_image_decoder.fp16.safetensors", + defaults={"frames": 25, "steps": 30}, + ), ] -video_models = {m["short_name"]: m for m in video_models} -MODEL_CONFIG_SHORTCUTS = {m.short_name: m for m in MODEL_CONFIGS} -for m in MODEL_CONFIGS: - if m.alias: - MODEL_CONFIG_SHORTCUTS[m.alias] = m +MODEL_WEIGHT_CONFIG_LOOKUP = {} +for m in MODEL_WEIGHT_CONFIGS: + for a in m.aliases: + MODEL_WEIGHT_CONFIG_LOOKUP[a] = m -MODEL_CONFIG_SHORTCUTS["openjourney"] = 
MODEL_CONFIG_SHORTCUTS["openjourney-v2"] -MODEL_CONFIG_SHORTCUTS["oj"] = MODEL_CONFIG_SHORTCUTS["openjourney-v2"] -MODEL_SHORT_NAMES = sorted(MODEL_CONFIG_SHORTCUTS.keys()) +IMAGE_WEIGHTS_SHORT_NAMES = [ + k + for k, mw in MODEL_WEIGHT_CONFIG_LOOKUP.items() + if mw.architecture.output_modality == "image" +] +IMAGE_WEIGHTS_SHORT_NAMES.sort() @dataclass -class ControlNetConfig: - short_name: str +class ControlConfig: + name: str + aliases: List[str] control_type: str config_path: str - weights_url: str - alias: str = None + weights_location: str -CONTROLNET_CONFIGS = [ - ControlNetConfig( - short_name="canny15", +CONTROL_CONFIGS = [ + ControlConfig( + name="Canny Edge Control", + aliases=["canny", "canny15"], control_type="canny", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/115a470d547982438f70198e353a921996e2e819/diffusion_pytorch_model.fp16.safetensors", - alias="canny", + weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/115a470d547982438f70198e353a921996e2e819/diffusion_pytorch_model.fp16.safetensors", ), - ControlNetConfig( - short_name="depth15", + ControlConfig( + name="Depth Control", + aliases=["depth", "depth15"], control_type="depth", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/539f99181d33db39cf1af2e517cd8056785f0a87/diffusion_pytorch_model.fp16.safetensors", - alias="depth", + weights_location="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/539f99181d33db39cf1af2e517cd8056785f0a87/diffusion_pytorch_model.fp16.safetensors", ), - ControlNetConfig( - short_name="normal15", + ControlConfig( + name="Normal Map Control", + aliases=["normal", "normal15"], control_type="normal", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/cb7296e6587a219068e9d65864e38729cd862aa8/diffusion_pytorch_model.fp16.safetensors", - alias="normal", + weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/cb7296e6587a219068e9d65864e38729cd862aa8/diffusion_pytorch_model.fp16.safetensors", ), - ControlNetConfig( - short_name="hed15", + ControlConfig( + name="Soft Edge Control (HED)", + aliases=["hed", "hed15"], control_type="hed", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/b5bcad0c48e9b12f091968cf5eadbb89402d6bc9/diffusion_pytorch_model.fp16.safetensors", - alias="hed", + weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/b5bcad0c48e9b12f091968cf5eadbb89402d6bc9/diffusion_pytorch_model.fp16.safetensors", ), - ControlNetConfig( - short_name="openpose15", + ControlConfig( + name="Pose Control", control_type="openpose", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/9ae9f970358db89e211b87c915f9535c6686d5ba/diffusion_pytorch_model.fp16.safetensors", - alias="openpose", + weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/9ae9f970358db89e211b87c915f9535c6686d5ba/diffusion_pytorch_model.fp16.safetensors", + aliases=["openpose", "pose", "pose15", "openpose15"], ), - ControlNetConfig( - short_name="shuffle15", + ControlConfig( + name="Shuffle Control", control_type="shuffle", config_path="configs/control-net-v15-pool.yaml", - 
weights_url="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/8cf275970f984acf5cc0fdfa537db8be098936a3/diffusion_pytorch_model.fp16.safetensors", - alias="shuffle", + weights_location="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/8cf275970f984acf5cc0fdfa537db8be098936a3/diffusion_pytorch_model.fp16.safetensors", + aliases=["shuffle", "shuffle15"], ), # "instruct pix2pix" - ControlNetConfig( - short_name="edit15", + ControlConfig( + name="Edit Prompt Control", + aliases=["edit", "edit15"], control_type="edit", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/1fed6ebb905c61929a60514830eb05b039969d6d/diffusion_pytorch_model.fp16.safetensors", - alias="edit", + weights_location="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/1fed6ebb905c61929a60514830eb05b039969d6d/diffusion_pytorch_model.fp16.safetensors", ), - ControlNetConfig( - short_name="inpaint15", + ControlConfig( + name="Inpaint Control", + aliases=["inpaint", "inpaint15"], control_type="inpaint", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/c96e03a807e64135568ba8aecb66b3a306ec73bd/diffusion_pytorch_model.fp16.safetensors", - alias="inpaint", + weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/c96e03a807e64135568ba8aecb66b3a306ec73bd/diffusion_pytorch_model.fp16.safetensors", ), - ControlNetConfig( - short_name="details15", + ControlConfig( + name="Details Control (Upscale Tile)", + aliases=["details", "details15"], control_type="details", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/3f877705c37010b7221c3d10743307d6b5b6efac/diffusion_pytorch_model.bin", - alias="details", + weights_location="https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/3f877705c37010b7221c3d10743307d6b5b6efac/diffusion_pytorch_model.bin", ), - ControlNetConfig( - short_name="colorize15", + ControlConfig( + name="Brightness Control (Colorize)", + aliases=["colorize", "colorize15"], control_type="colorize", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/ioclab/control_v1p_sd15_brightness/resolve/8509361eb1ba89c03839040ed8c75e5f11bbd9c5/diffusion_pytorch_model.safetensors", - alias="colorize", + weights_location="https://huggingface.co/ioclab/control_v1p_sd15_brightness/resolve/8509361eb1ba89c03839040ed8c75e5f11bbd9c5/diffusion_pytorch_model.safetensors", ), ] -CONTROLNET_CONFIG_SHORTCUTS = {} -for m in CONTROLNET_CONFIGS: - if m.alias: - CONTROLNET_CONFIG_SHORTCUTS[m.alias] = m - -for m in CONTROLNET_CONFIGS: - CONTROLNET_CONFIG_SHORTCUTS[m.short_name] = m - -SAMPLER_TYPE_OPTIONS = [ - # "plms", - "ddim", - "k_dpmpp_2m" - # "k_dpm_fast", - # "k_dpm_adaptive", - # "k_lms", - # "k_dpm_2", - # "k_dpm_2_a", - # "k_dpmpp_2m", - # "k_dpmpp_2s_a", - # "k_euler", - # "k_euler_a", - # "k_heun", +CONTROL_CONFIG_SHORTCUTS: dict[str, ControlConfig] = {} +for m in CONTROL_CONFIGS: + for a in m.aliases: + CONTROL_CONFIG_SHORTCUTS[a] = m + + +@dataclass +class SolverConfig: + name: str + short_name: str + aliases: List[str] + papers: List[str] + implementations: List[str] + + +SOLVER_CONFIGS = [ + SolverConfig( + name="DDIM", + short_name="DDIM", + aliases=["ddim"], + papers=["https://arxiv.org/abs/2010.02502"], + implementations=[ + "https://github.com/ermongroup/ddim", + 
"https://github.com/Stability-AI/stablediffusion/blob/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/models/diffusion/ddim.py#L10", + "https://github.com/huggingface/diffusers/blob/76c645d3a641c879384afcb43496f0b7db8cc5cb/src/diffusers/schedulers/scheduling_ddim.py#L131", + ], + ), + SolverConfig( + name="DPM-Solver++", + short_name="DPMPP", + aliases=["dpmpp", "dpm++", "dpmsolver"], + papers=["https://arxiv.org/abs/2211.01095"], + implementations=[ + "https://github.com/LuChengTHU/dpm-solver/blob/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/dpm_solver_pytorch.py#L337", + "https://github.com/apple/ml-stable-diffusion/blob/7449ce46a4b23c94413b714704202e4ea4c55080/swift/StableDiffusion/pipeline/DPMSolverMultistepScheduler.swift#L27", + "https://github.com/crowsonkb/k-diffusion/blob/045515774882014cc14c1ba2668ab5bad9cbf7c0/k_diffusion/sampling.py#L509", + ], + ), ] + +SOLVER_TYPE_NAMES = [s.aliases[0] for s in SOLVER_CONFIGS] + +SOLVER_LOOKUP = {} +for s in SOLVER_CONFIGS: + for a in s.aliases: + SOLVER_LOOKUP[a.lower()] = s diff --git a/imaginairy/http_app/stablestudio/models.py b/imaginairy/http_app/stablestudio/models.py index 77cb39e..c295892 100644 --- a/imaginairy/http_app/stablestudio/models.py +++ b/imaginairy/http_app/stablestudio/models.py @@ -26,7 +26,7 @@ class StableStudioStyle(BaseModel): image: Optional[HttpUrl] = None -class StableStudioSampler(BaseModel): +class StableStudioSolver(BaseModel): id: str name: Optional[str] = None @@ -55,7 +55,7 @@ class StableStudioInput(BaseModel, extra=Extra.forbid): style: Optional[str] = None width: Optional[int] = None height: Optional[int] = None - sampler: Optional[StableStudioSampler] = None + solver: Optional[StableStudioSolver] = None cfg_scale: Optional[float] = Field(None, alias="cfgScale") steps: Optional[int] = None seed: Optional[int] = None @@ -88,14 +88,14 @@ class StableStudioInput(BaseModel, extra=Extra.forbid): mask_image = self.mask_image.blob if self.mask_image else None - sampler_type = self.sampler.id if self.sampler else None + solver_type = self.solver.id if self.solver else None return ImaginePrompt( prompt=positive_prompt, prompt_strength=self.cfg_scale, negative_prompt=negative_prompt, model=self.model, - sampler_type=sampler_type, + solver_type=solver_type, seed=self.seed, steps=self.steps, height=self.height, diff --git a/imaginairy/http_app/stablestudio/routes.py b/imaginairy/http_app/stablestudio/routes.py index 8aa3e5c..6999097 100644 --- a/imaginairy/http_app/stablestudio/routes.py +++ b/imaginairy/http_app/stablestudio/routes.py @@ -8,7 +8,7 @@ from imaginairy.http_app.stablestudio.models import ( StableStudioBatchResponse, StableStudioImage, StableStudioModel, - StableStudioSampler, + StableStudioSolver, ) from imaginairy.http_app.utils import generate_image_b64 @@ -37,11 +37,14 @@ async def generate(studio_request: StableStudioBatchRequest): @router.get("/samplers") async def list_samplers(): - from imaginairy.config import SAMPLER_TYPE_OPTIONS + from imaginairy.config import SOLVER_CONFIGS sampler_objs = [] - for sampler_type in SAMPLER_TYPE_OPTIONS: - sampler_obj = StableStudioSampler(id=sampler_type, name=sampler_type) + + for solver_config in SOLVER_CONFIGS: + sampler_obj = StableStudioSolver( + id=solver_config.aliases[0], name=solver_config.aliases[0] + ) sampler_objs.append(sampler_obj) return sampler_objs @@ -49,10 +52,10 @@ async def list_samplers(): @router.get("/models") async def list_models(): - from imaginairy.config import MODEL_CONFIGS + from imaginairy.config import MODEL_WEIGHT_CONFIGS model_objs 
= [] - for model_config in MODEL_CONFIGS: + for model_config in MODEL_WEIGHT_CONFIGS: if "inpaint" in model_config.description.lower(): continue model_obj = StableStudioModel( diff --git a/imaginairy/model_manager.py b/imaginairy/model_manager.py index a2a18c0..b097082 100644 --- a/imaginairy/model_manager.py +++ b/imaginairy/model_manager.py @@ -17,7 +17,7 @@ from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter, SD1UNe from safetensors.torch import load_file from imaginairy import config as iconfig -from imaginairy.config import MODEL_SHORT_NAMES +from imaginairy.config import IMAGE_WEIGHTS_SHORT_NAMES, ModelArchitecture from imaginairy.modules import attention from imaginairy.paths import PKG_ROOT from imaginairy.utils import get_device, instantiate_from_config @@ -66,7 +66,7 @@ def load_state_dict(weights_location, half_mode=False, device=None): except FileNotFoundError as e: if e.errno == 2: logger.error( - f'Error: "{ckpt_path}" not a valid path to model weights.\nPreconfigured models you can use: {MODEL_SHORT_NAMES}.' + f'Error: "{ckpt_path}" not a valid path to model weights.\nPreconfigured models you can use: {IMAGE_WEIGHTS_SHORT_NAMES}.' ) sys.exit(1) raise @@ -149,7 +149,7 @@ def add_controlnet(base_state_dict, controlnet_state_dict): def get_diffusion_model( - weights_location=iconfig.DEFAULT_MODEL, + weights_location=iconfig.DEFAULT_MODEL_WEIGHTS, config_path="configs/stable-diffusion-v1.yaml", control_weights_locations=None, half_mode=None, @@ -174,7 +174,7 @@ def get_diffusion_model( f"Failed to load inpainting model. Attempting to fall-back to standard model. {e!s}" ) return _get_diffusion_model( - iconfig.DEFAULT_MODEL, + iconfig.DEFAULT_MODEL_WEIGHTS, config_path, half_mode, for_inpainting=False, @@ -184,8 +184,8 @@ def get_diffusion_model( def _get_diffusion_model( - weights_location=iconfig.DEFAULT_MODEL, - config_path="configs/stable-diffusion-v1.yaml", + weights_location=iconfig.DEFAULT_MODEL_WEIGHTS, + model_architecture="configs/stable-diffusion-v1.yaml", half_mode=None, for_inpainting=False, control_weights_locations=None, @@ -197,24 +197,20 @@ def _get_diffusion_model( """ global MOST_RECENTLY_LOADED_MODEL - ( - model_config, - weights_location, - config_path, - control_weights_locations, - ) = resolve_model_paths( - weights_path=weights_location, - config_path=config_path, - control_weights_paths=control_weights_locations, + model_weights_config = resolve_model_weights_config( + model_weights=weights_location, + default_model_architecture=model_architecture, for_inpainting=for_inpainting, ) # some models need the attention calculated in float32 - if model_config is not None: - attention.ATTENTION_PRECISION_OVERRIDE = model_config.forced_attn_precision + if model_weights_config is not None: + attention.ATTENTION_PRECISION_OVERRIDE = ( + model_weights_config.forced_attn_precision + ) else: attention.ATTENTION_PRECISION_OVERRIDE = "default" diffusion_model = _load_diffusion_model( - config_path=config_path, + config_path=model_weights_config.architecture.config_path, weights_location=weights_location, half_mode=half_mode, ) @@ -229,8 +225,8 @@ def _get_diffusion_model( def get_diffusion_model_refiners( - weights_location=iconfig.DEFAULT_MODEL, - config_path="configs/stable-diffusion-v1.yaml", + weights_location=iconfig.DEFAULT_MODEL_WEIGHTS, + model_architecture=None, control_weights_locations=None, dtype=None, for_inpainting=False, @@ -243,8 +239,8 @@ def get_diffusion_model_refiners( try: return _get_diffusion_model_refiners( weights_location, - 
config_path, - for_inpainting, + model_architecture=model_architecture, + for_inpainting=for_inpainting, dtype=dtype, control_weights_locations=control_weights_locations, ) @@ -254,8 +250,8 @@ def get_diffusion_model_refiners( f"Failed to load inpainting model. Attempting to fall-back to standard model. {e!s}" ) return _get_diffusion_model_refiners( - iconfig.DEFAULT_MODEL, - config_path, + iconfig.DEFAULT_MODEL_WEIGHTS, + model_architecture=model_architecture, dtype=dtype, for_inpainting=False, control_weights_locations=control_weights_locations, @@ -264,8 +260,8 @@ def get_diffusion_model_refiners( def _get_diffusion_model_refiners( - weights_location=iconfig.DEFAULT_MODEL, - config_path="configs/stable-diffusion-v1.yaml", + weights_location=iconfig.DEFAULT_MODEL_WEIGHTS, + model_architecture=None, for_inpainting=False, control_weights_locations=None, device=None, @@ -279,7 +275,7 @@ def _get_diffusion_model_refiners( sd = _get_diffusion_model_refiners_only( weights_location=weights_location, - config_path=config_path, + model_architecture=model_architecture, for_inpainting=for_inpainting, device=device, dtype=dtype, @@ -290,10 +286,9 @@ def _get_diffusion_model_refiners( @lru_cache(maxsize=1) def _get_diffusion_model_refiners_only( - weights_location=iconfig.DEFAULT_MODEL, - config_path="configs/stable-diffusion-v1.yaml", + weights_location=iconfig.DEFAULT_MODEL_WEIGHTS, + model_architecture=None, for_inpainting=False, - control_weights_locations=None, device=None, dtype=torch.float16, ): @@ -312,28 +307,17 @@ def _get_diffusion_model_refiners_only( device = device or get_device() - ( - model_config, - weights_location, - config_path, - control_weights_locations, - ) = resolve_model_paths( - weights_path=weights_location, - config_path=config_path, - control_weights_paths=control_weights_locations, + model_weights_config = resolve_model_weights_config( + model_weights=weights_location, + default_model_architecture=model_architecture, for_inpainting=for_inpainting, ) - # some models need the attention calculated in float32 - if model_config is not None: - attention.ATTENTION_PRECISION_OVERRIDE = model_config.forced_attn_precision - else: - attention.ATTENTION_PRECISION_OVERRIDE = "default" ( vae_weights, unet_weights, text_encoder_weights, - ) = load_stable_diffusion_compvis_weights(weights_location) + ) = load_stable_diffusion_compvis_weights(model_weights_config.weights_location) if for_inpainting: unet = SD1UNet(in_channels=9) @@ -422,58 +406,69 @@ def load_controlnet(control_weights_location, half_mode): return controlnet -def resolve_model_paths( - weights_path=iconfig.DEFAULT_MODEL, - config_path=None, - control_weights_paths=None, - for_inpainting=False, -): +def resolve_model_weights_config( + model_weights: str, + default_model_architecture: str | None = None, + for_inpainting: bool = False, +) -> iconfig.ModelWeightsConfig: """Resolve weight and config path if they happen to be shortcuts.""" - model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get(weights_path, None) - model_metadata_c = iconfig.MODEL_CONFIG_SHORTCUTS.get(config_path, None) - - control_weights_paths = control_weights_paths or [] - control_net_metadatas = [ - iconfig.CONTROLNET_CONFIG_SHORTCUTS.get(control_weights_path, None) - for control_weights_path in control_weights_paths - ] - - if not control_net_metadatas and for_inpainting: - model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get( - f"{weights_path}-inpaint", model_metadata_w + + if for_inpainting: + model_weights_config = 
iconfig.MODEL_WEIGHT_CONFIG_LOOKUP.get( + f"{model_weights.lower()}-inpaint", None + ) + if model_weights_config: + return model_weights_config + + model_weights_config = iconfig.MODEL_WEIGHT_CONFIG_LOOKUP.get( + model_weights.lower(), None + ) + if model_weights_config: + return model_weights_config + + if not default_model_architecture: + msg = "You must specify the model architecture when loading custom weights." + raise ValueError(msg) + + default_model_architecture = default_model_architecture.lower() + model_architecture_config = None + if for_inpainting: + model_architecture_config = iconfig.MODEL_ARCHITECTURE_LOOKUP.get( + f"{default_model_architecture}-inpaint", None ) - model_metadata_c = iconfig.MODEL_CONFIG_SHORTCUTS.get( - f"{config_path}-inpaint", model_metadata_c + + if not model_architecture_config: + model_architecture_config = iconfig.MODEL_ARCHITECTURE_LOOKUP.get( + default_model_architecture, None + ) + + if model_architecture_config is None: + msg = f"Invalid model architecture: {default_model_architecture}" + raise ValueError(msg) + + model_weights_config = iconfig.ModelWeightsConfig( + name="Custom Loaded", + aliases=[], + architecture=model_architecture_config, + weights_location=model_weights, + defaults={}, + ) + + return model_weights_config + + +def get_model_default_image_size(model_architecture: str | ModelArchitecture): + if isinstance(model_architecture, str): + model_architecture = iconfig.MODEL_WEIGHT_CONFIG_LOOKUP.get( + model_architecture, None ) + default_size = None + if model_architecture: + default_size = model_architecture.defaults.get("size") - if model_metadata_w: - if config_path is None: - config_path = model_metadata_w.config_path - - weights_path = model_metadata_w.weights_url - - if model_metadata_c: - config_path = model_metadata_c.config_path - - if config_path is None: - config_path = iconfig.MODEL_CONFIG_SHORTCUTS[iconfig.DEFAULT_MODEL].config_path - if control_net_metadatas: - if "stable-diffusion-v1" not in config_path: - msg = "Control net is only supported for stable diffusion v1. Please use a different model." 
- raise ValueError(msg) - control_weights_paths = [cnm.weights_url for cnm in control_net_metadatas] - config_path = control_net_metadatas[0].config_path - model_metadata = model_metadata_w or model_metadata_c - logger.debug(f"Loading model weights from: {weights_path}") - logger.debug(f"Loading model config from: {config_path}") - return model_metadata, weights_path, config_path, control_weights_paths - - -def get_model_default_image_size(weights_location): - model_config = iconfig.MODEL_CONFIG_SHORTCUTS.get(weights_location, None) - if model_config: - return model_config.default_image_size - return 512 + if default_size is None: + default_size = 512 + return default_size def get_current_diffusion_model(): diff --git a/imaginairy/samplers/__init__.py b/imaginairy/samplers/__init__.py index 229a5b3..825b6cd 100644 --- a/imaginairy/samplers/__init__.py +++ b/imaginairy/samplers/__init__.py @@ -1,21 +1,20 @@ -from imaginairy.samplers import kdiff -from imaginairy.samplers.base import SamplerName # noqa -from imaginairy.samplers.ddim import DDIMSampler +from imaginairy.samplers.base import SolverName # noqa +from imaginairy.samplers.ddim import DDIMSolver -SAMPLERS = [ - # PLMSSampler, - DDIMSampler, +SOLVERS = [ + # PLMSSolver, + DDIMSolver, # kdiff.DPMFastSampler, # kdiff.DPMAdaptiveSampler, # kdiff.LMSSampler, # kdiff.DPM2Sampler, # kdiff.DPM2AncestralSampler, - kdiff.DPMPP2MSampler, + # kdiff.DPMPP2MSampler, # kdiff.DPMPP2SAncestralSampler, # kdiff.EulerSampler, # kdiff.EulerAncestralSampler, # kdiff.HeunSampler, ] -SAMPLER_LOOKUP = {sampler.short_name: sampler for sampler in SAMPLERS} -SAMPLER_TYPE_OPTIONS = [sampler.short_name for sampler in SAMPLERS] +SOLVER_LOOKUP = {s.short_name: s for s in SOLVERS} +SOLVER_TYPE_OPTIONS = [s.short_name for s in SOLVERS] diff --git a/imaginairy/samplers/base.py b/imaginairy/samplers/base.py index 5f411df..e0203dc 100644 --- a/imaginairy/samplers/base.py +++ b/imaginairy/samplers/base.py @@ -16,9 +16,10 @@ from imaginairy.utils import get_device logger = logging.getLogger(__name__) -class SamplerName: +class SolverName: PLMS = "plms" DDIM = "ddim" + DPMPP = "dpmpp" K_DPM_FAST = "k_dpm_fast" K_DPM_ADAPTIVE = "k_dpm_adaptive" K_LMS = "k_lms" @@ -31,7 +32,7 @@ class SamplerName: K_HEUN = "k_heun" -class ImageSampler(ABC): +class ImageSolver(ABC): short_name: str name: str default_steps: int diff --git a/imaginairy/samplers/ddim.py b/imaginairy/samplers/ddim.py index 4405688..f3319e9 100644 --- a/imaginairy/samplers/ddim.py +++ b/imaginairy/samplers/ddim.py @@ -8,9 +8,9 @@ from tqdm import tqdm from imaginairy.log_utils import increment_step, log_latent from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like from imaginairy.samplers.base import ( - ImageSampler, + ImageSolver, NoiseSchedule, - SamplerName, + SolverName, get_noise_prediction, mask_blend, ) @@ -19,14 +19,14 @@ from imaginairy.utils import get_device logger = logging.getLogger(__name__) -class DDIMSampler(ImageSampler): +class DDIMSolver(ImageSolver): """ Denoising Diffusion Implicit Models. 
https://arxiv.org/abs/2010.02502 """ - short_name = SamplerName.DDIM + short_name = SolverName.DDIM name = "Denoising Diffusion Implicit Models" default_steps = 50 diff --git a/imaginairy/samplers/kdiff.py b/imaginairy/samplers/kdiff.py index bbe8e7c..f00c14b 100644 --- a/imaginairy/samplers/kdiff.py +++ b/imaginairy/samplers/kdiff.py @@ -6,8 +6,8 @@ from torch import nn from imaginairy.log_utils import increment_step, log_latent from imaginairy.samplers.base import ( - ImageSampler, - SamplerName, + ImageSolver, + SolverName, get_noise_prediction, mask_blend, ) @@ -57,7 +57,7 @@ def sample_dpm_fast(model, x, sigmas, extra_args=None, disable=False, callback=N ) -class KDiffusionSampler(ImageSampler, ABC): +class KDiffusionSolver(ImageSolver, ABC): sampler_func: callable def __init__(self, model): @@ -98,9 +98,9 @@ class KDiffusionSampler(ImageSampler, ABC): # see https://github.com/crowsonkb/k-diffusion/issues/43#issuecomment-1305195666 if self.short_name in ( - SamplerName.K_DPM_2, - SamplerName.K_DPMPP_2M, - SamplerName.K_DPM_2_ANCESTRAL, + SolverName.K_DPM_2, + SolverName.K_DPMPP_2M, + SolverName.K_DPM_2_ANCESTRAL, ): sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) @@ -152,73 +152,73 @@ class KDiffusionSampler(ImageSampler, ABC): # -# class DPMFastSampler(KDiffusionSampler): -# short_name = SamplerName.K_DPM_FAST +# class DPMFastSampler(KDiffusionSolver): +# short_name = SolverName.K_DPM_FAST # name = "Diffusion probabilistic models - fast" # default_steps = 15 # sampler_func = staticmethod(sample_dpm_fast) # # -# class DPMAdaptiveSampler(KDiffusionSampler): -# short_name = SamplerName.K_DPM_ADAPTIVE +# class DPMAdaptiveSampler(KDiffusionSolver): +# short_name = SolverName.K_DPM_ADAPTIVE # name = "Diffusion probabilistic models - adaptive" # default_steps = 40 # sampler_func = staticmethod(sample_dpm_adaptive) # # -# class DPM2Sampler(KDiffusionSampler): -# short_name = SamplerName.K_DPM_2 +# class DPM2Sampler(KDiffusionSolver): +# short_name = SolverName.K_DPM_2 # name = "Diffusion probabilistic models - 2" # default_steps = 40 # sampler_func = staticmethod(k_sampling.sample_dpm_2) # # -# class DPM2AncestralSampler(KDiffusionSampler): -# short_name = SamplerName.K_DPM_2_ANCESTRAL +# class DPM2AncestralSampler(KDiffusionSolver): +# short_name = SolverName.K_DPM_2_ANCESTRAL # name = "Diffusion probabilistic models - 2 ancestral" # default_steps = 40 # sampler_func = staticmethod(k_sampling.sample_dpm_2_ancestral) # -class DPMPP2MSampler(KDiffusionSampler): - short_name = SamplerName.K_DPMPP_2M +class DPMPP2MSampler(KDiffusionSolver): + short_name = SolverName.K_DPMPP_2M name = "Diffusion probabilistic models - 2m" default_steps = 15 sampler_func = staticmethod(k_sampling.sample_dpmpp_2m) # -# class DPMPP2SAncestralSampler(KDiffusionSampler): -# short_name = SamplerName.K_DPMPP_2S_ANCESTRAL +# class DPMPP2SAncestralSampler(KDiffusionSolver): +# short_name = SolverName.K_DPMPP_2S_ANCESTRAL # name = "Ancestral sampling with DPM-Solver++(2S) second-order steps." # default_steps = 15 # sampler_func = staticmethod(k_sampling.sample_dpmpp_2s_ancestral) # # -# class EulerSampler(KDiffusionSampler): -# short_name = SamplerName.K_EULER +# class EulerSampler(KDiffusionSolver): +# short_name = SolverName.K_EULER # name = "Algorithm 2 (Euler steps) from Karras et al. 
(2022)" # default_steps = 40 # sampler_func = staticmethod(k_sampling.sample_euler) # # -# class EulerAncestralSampler(KDiffusionSampler): -# short_name = SamplerName.K_EULER_ANCESTRAL +# class EulerAncestralSampler(KDiffusionSolver): +# short_name = SolverName.K_EULER_ANCESTRAL # name = "Euler ancestral" # default_steps = 40 # sampler_func = staticmethod(k_sampling.sample_euler_ancestral) # # -# class HeunSampler(KDiffusionSampler): -# short_name = SamplerName.K_HEUN +# class HeunSampler(KDiffusionSolver): +# short_name = SolverName.K_HEUN # name = "Algorithm 2 (Heun steps) from Karras et al. (2022)." # default_steps = 40 # sampler_func = staticmethod(k_sampling.sample_heun) # # -# class LMSSampler(KDiffusionSampler): -# short_name = SamplerName.K_LMS +# class LMSSampler(KDiffusionSolver): +# short_name = SolverName.K_LMS # name = "LMS" # default_steps = 40 # sampler_func = staticmethod(k_sampling.sample_lms) diff --git a/imaginairy/samplers/plms.py b/imaginairy/samplers/plms.py index acc0ad3..6406e6e 100644 --- a/imaginairy/samplers/plms.py +++ b/imaginairy/samplers/plms.py @@ -8,9 +8,9 @@ from tqdm import tqdm from imaginairy.log_utils import increment_step, log_latent from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like from imaginairy.samplers.base import ( - ImageSampler, + ImageSolver, NoiseSchedule, - SamplerName, + SolverName, get_noise_prediction, mask_blend, ) @@ -19,7 +19,7 @@ from imaginairy.utils import get_device logger = logging.getLogger(__name__) -class PLMSSampler(ImageSampler): +class PLMSSolver(ImageSolver): """ probabilistic least-mean-squares. @@ -29,7 +29,7 @@ class PLMSSampler(ImageSampler): https://github.com/luping-liu/PNDM """ - short_name = SamplerName.PLMS + short_name = SolverName.PLMS name = "probabilistic least-mean-squares sampler" default_steps = 40 diff --git a/imaginairy/schema.py b/imaginairy/schema.py index c871a69..64c93c2 100644 --- a/imaginairy/schema.py +++ b/imaginairy/schema.py @@ -7,43 +7,32 @@ import logging import os.path import random from datetime import datetime, timezone +from enum import Enum from io import BytesIO -from typing import TYPE_CHECKING, Any, List, Literal, Optional +from typing import TYPE_CHECKING, Any, List from pydantic import ( BaseModel, + ConfigDict, Field, GetCoreSchemaHandler, field_validator, model_validator, ) from pydantic_core import core_schema +from typing_extensions import Self from imaginairy import config if TYPE_CHECKING: + from pathlib import Path + from PIL import Image -else: - Image = Any logger = logging.getLogger(__name__) -def save_image_as_base64(image: "Image.Image") -> str: - buffered = io.BytesIO() - image.save(buffered, format="PNG") - img_bytes = buffered.getvalue() - return base64.b64encode(img_bytes).decode() - - -def load_image_from_base64(image_str: str) -> "Image.Image": - from PIL import Image - - img_bytes = base64.b64decode(image_str) - return Image.open(io.BytesIO(img_bytes)) - - class InvalidUrlError(ValueError): pass @@ -52,7 +41,12 @@ class LazyLoadingImage: """Image file encoded as base64 string.""" def __init__( - self, *, filepath=None, url=None, img: Image = None, b64: Optional[str] = None + self, + *, + filepath=None, + url=None, + img: "Image.Image" = None, + b64: str | None = None, ): if not filepath and not url and not img and not b64: msg = "You must specify a url or filepath or img or base64 string" @@ -208,10 +202,10 @@ class LazyLoadingImage: return f"" -class ControlNetInput(BaseModel): +class ControlInput(BaseModel): mode: str - image: 
Optional[LazyLoadingImage] = None - image_raw: Optional[LazyLoadingImage] = None + image: LazyLoadingImage | None = None + image_raw: LazyLoadingImage | None = None strength: float = Field(1, ge=0, le=1000) # @field_validator("image", "image_raw", mode="before") @@ -233,8 +227,8 @@ class ControlNetInput(BaseModel): @field_validator("mode") def mode_validate(cls, v): - if v not in config.CONTROLNET_CONFIG_SHORTCUTS: - valid_modes = list(config.CONTROLNET_CONFIG_SHORTCUTS.keys()) + if v not in config.CONTROL_CONFIG_SHORTCUTS: + valid_modes = list(config.CONTROL_CONFIG_SHORTCUTS.keys()) valid_modes = ", ".join(valid_modes) msg = f"Invalid controlnet mode: '{v}'. Valid modes are: {valid_modes}" raise ValueError(msg) @@ -249,43 +243,51 @@ class WeightedPrompt(BaseModel): return f"{self.weight}*({self.text})" +class MaskMode(str, Enum): + REPLACE = "replace" + KEEP = "keep" + + +MaskInput = MaskMode | str +PromptInput = str | WeightedPrompt | list[WeightedPrompt] | list[str] | None + + class ImaginePrompt(BaseModel, protected_namespaces=()): - prompt: Optional[List[WeightedPrompt]] = Field(default=None, validate_default=True) - negative_prompt: Optional[List[WeightedPrompt]] = Field( - default=None, validate_default=True - ) - prompt_strength: Optional[float] = Field( - default=7.5, le=10_000, ge=-10_000, validate_default=True - ) - init_image: Optional[LazyLoadingImage] = Field( + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + prompt: List[WeightedPrompt] = Field(default=None, validate_default=True) + negative_prompt: List[WeightedPrompt] = Field(default=None, validate_default=True) + prompt_strength: float = Field(default=7.5, le=50, ge=-50, validate_default=True) + init_image: LazyLoadingImage | None = Field( None, description="base64 encoded image", validate_default=True ) - init_image_strength: Optional[float] = Field( + init_image_strength: float | None = Field( ge=0, le=1, default=None, validate_default=True ) - control_inputs: List[ControlNetInput] = Field( + control_inputs: List[ControlInput] = Field( default_factory=list, validate_default=True ) - mask_prompt: Optional[str] = Field( + mask_prompt: str | None = Field( default=None, description="text description of the things to be masked", validate_default=True, ) - mask_image: Optional[LazyLoadingImage] = Field(default=None, validate_default=True) - mask_mode: Optional[Literal["keep", "replace"]] = "replace" + mask_image: LazyLoadingImage | None = Field(default=None, validate_default=True) + mask_mode: MaskMode = MaskMode.REPLACE mask_modify_original: bool = True - outpaint: Optional[str] = "" - model: str = Field(default=config.DEFAULT_MODEL, validate_default=True) - model_config_path: Optional[str] = None - sampler_type: str = Field(default=config.DEFAULT_SAMPLER, validate_default=True) - seed: Optional[int] = Field(default=None, validate_default=True) - steps: Optional[int] = Field(default=None, validate_default=True) - height: Optional[int] = Field(None, ge=1, le=100_000, validate_default=True) - width: Optional[int] = Field(None, ge=1, le=100_000, validate_default=True) + outpaint: str | None = "" + model_architecture: str | None = None + model_weights: str = Field( + default=config.DEFAULT_MODEL_WEIGHTS, validate_default=True + ) + solver_type: str = Field(default=config.DEFAULT_SOLVER, validate_default=True) + seed: int | None = Field(default=None, validate_default=True) + steps: int | None = Field(default=None, validate_default=True) + size: tuple[int, int] | None = Field(default=None, 
validate_default=True) upscale: bool = False fix_faces: bool = False - fix_faces_fidelity: Optional[float] = Field(0.2, ge=0, le=1, validate_default=True) - conditioning: Optional[str] = None + fix_faces_fidelity: float | None = Field(0.2, ge=0, le=1, validate_default=True) + conditioning: str | None = None tile_mode: str = "" allow_compose_phase: bool = True is_intermediate: bool = False @@ -294,22 +296,87 @@ class ImaginePrompt(BaseModel, protected_namespaces=()): "", description="text to be overlaid on the image", validate_default=True ) - class MaskMode: - REPLACE = "replace" - KEEP = "keep" - - def __init__(self, prompt=None, **kwargs): - # allows `prompt` to be positional - super().__init__(prompt=prompt, **kwargs) + def __init__( + self, + prompt: PromptInput = "", + *, + negative_prompt: PromptInput = None, + prompt_strength: float | None = 7.5, + init_image: LazyLoadingImage | None = None, + init_image_strength: float | None = None, + control_inputs: List[ControlInput] | None = None, + mask_prompt: str | None = None, + mask_image: LazyLoadingImage | None = None, + mask_mode: MaskInput = MaskMode.REPLACE, + mask_modify_original: bool = True, + outpaint: str | None = "", + model_architecture: str | None = None, + model_weights: str = config.DEFAULT_MODEL_WEIGHTS, + solver_type: str = config.DEFAULT_SOLVER, + seed: int | None = None, + steps: int | None = None, + size: int | str | tuple[int, int] | None = None, + upscale: bool = False, + fix_faces: bool = False, + fix_faces_fidelity: float | None = 0.2, + conditioning: str | None = None, + tile_mode: str = "", + allow_compose_phase: bool = True, + is_intermediate: bool = False, + collect_progress_latents: bool = False, + caption_text: str = "", + ): + super().__init__( + prompt=prompt, + negative_prompt=negative_prompt, + prompt_strength=prompt_strength, + init_image=init_image, + init_image_strength=init_image_strength, + control_inputs=control_inputs, + mask_prompt=mask_prompt, + mask_image=mask_image, + mask_mode=mask_mode, + mask_modify_original=mask_modify_original, + outpaint=outpaint, + model_architecture=model_architecture, + model_weights=model_weights, + solver_type=solver_type, + seed=seed, + steps=steps, + size=size, + upscale=upscale, + fix_faces=fix_faces, + fix_faces_fidelity=fix_faces_fidelity, + conditioning=conditioning, + tile_mode=tile_mode, + allow_compose_phase=allow_compose_phase, + is_intermediate=is_intermediate, + collect_progress_latents=collect_progress_latents, + caption_text=caption_text, + ) @field_validator("prompt", "negative_prompt", mode="before") - @classmethod - def make_into_weighted_prompts(cls, v): - if isinstance(v, str): - v = [WeightedPrompt(text=v)] - elif isinstance(v, WeightedPrompt): - v = [v] - return v + def make_into_weighted_prompts( + cls, + value: PromptInput, + ) -> list[WeightedPrompt]: + match value: + case None: + return [] + + case str(): + if value: + return [WeightedPrompt(text=value)] + else: + return [] + case WeightedPrompt(): + return [value] + case list(): + if all(isinstance(item, str) for item in value): + return [WeightedPrompt(text=p) for p in value] + elif all(isinstance(item, WeightedPrompt) for item in value): + return value + raise ValueError("Invalid prompt input") @field_validator("prompt", "negative_prompt", mode="after") @classmethod @@ -328,16 +395,20 @@ class ImaginePrompt(BaseModel, protected_namespaces=()): @model_validator(mode="after") def validate_negative_prompt(self): - if self.negative_prompt is None: - model_config = 
config.MODEL_CONFIG_SHORTCUTS.get(self.model, None) - if model_config: - self.negative_prompt = [ - WeightedPrompt(text=model_config.default_negative_prompt) - ] - else: - self.negative_prompt = [ - WeightedPrompt(text=config.DEFAULT_NEGATIVE_PROMPT) - ] + if ( + self.negative_prompt == [WeightedPrompt(text="")] + or self.negative_prompt == [] + ): + model_weight_config = config.MODEL_WEIGHT_CONFIG_LOOKUP.get( + self.model_weights, None + ) + default_negative_prompt = config.DEFAULT_NEGATIVE_PROMPT + if model_weight_config: + default_negative_prompt = model_weight_config.defaults.get( + "negative_prompt", default_negative_prompt + ) + + self.negative_prompt = [WeightedPrompt(text=default_negative_prompt)] return self @field_validator("prompt_strength") @@ -426,10 +497,10 @@ class ImaginePrompt(BaseModel, protected_namespaces=()): raise ValueError(msg) return v - @field_validator("model", mode="before") + @field_validator("model_weights", mode="before") def set_default_diffusion_model(cls, v): if v is None: - return config.DEFAULT_MODEL + return config.DEFAULT_MODEL_WEIGHTS return v @@ -444,33 +515,32 @@ class ImaginePrompt(BaseModel, protected_namespaces=()): return v - @field_validator("sampler_type", mode="after") - def validate_sampler_type(cls, v, info: core_schema.FieldValidationInfo): - from imaginairy.samplers import SamplerName + @field_validator("solver_type", mode="after") + def validate_solver_type(cls, v, info: core_schema.FieldValidationInfo): + from imaginairy.samplers import SolverName if v is None: - v = config.DEFAULT_SAMPLER + v = config.DEFAULT_SOLVER v = v.lower() - if info.data.get("model") == "SD-2.0-v" and v == SamplerName.PLMS: - raise ValueError("PLMS sampler is not supported for SD-2.0-v model.") + if info.data.get("model") == "SD-2.0-v" and v == SolverName.PLMS: + raise ValueError("PLMS solver is not supported for SD-2.0-v model.") if info.data.get("model") == "edit" and v in ( - SamplerName.PLMS, - SamplerName.DDIM, + SolverName.PLMS, + SolverName.DDIM, ): - msg = "PLMS and DDIM samplers are not supported for pix2pix edit model." + msg = "PLMS and DDIM solvers are not supported for pix2pix edit model." raise ValueError(msg) return v @field_validator("steps") def validate_steps(cls, v, info: core_schema.FieldValidationInfo): - from imaginairy.samplers import SAMPLER_LOOKUP + steps_lookup = {"ddim": 50, "dpmpp": 20} if v is None: - SamplerCls = SAMPLER_LOOKUP[info.data["sampler_type"]] - v = SamplerCls.default_steps + v = steps_lookup[info.data["solver_type"]] return int(v) @@ -486,14 +556,26 @@ class ImaginePrompt(BaseModel, protected_namespaces=()): return self - @field_validator("height", "width") + @field_validator("size", mode="before") def validate_image_size(cls, v, info: core_schema.FieldValidationInfo): from imaginairy.model_manager import get_model_default_image_size + from imaginairy.utils.named_resolutions import normalize_image_size if v is None: - v = get_model_default_image_size(info.data["model"]) + v = get_model_default_image_size(info.data["model_architecture"]) - return v + width, height = normalize_image_size(v) + min_size = 8 + max_size = 100_000 + if not min_size <= width <= max_size: + msg = f"Width must be between {min_size} and {max_size}. Got: {width}" + raise ValueError(msg) + + if not min_size <= height <= max_size: + msg = f"Height must be between {min_size} and {max_size}. 
Got: {height}" + raise ValueError(msg) + + return width, height @field_validator("caption_text", mode="before") def validate_caption_text(cls, v): @@ -507,7 +589,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()): return self.prompt @property - def prompt_text(self): + def prompt_text(self) -> str: if not self.prompt: return "" if len(self.prompt) == 1: @@ -515,18 +597,31 @@ class ImaginePrompt(BaseModel, protected_namespaces=()): return "|".join(str(p) for p in self.prompt) @property - def negative_prompt_text(self): + def negative_prompt_text(self) -> str: if not self.negative_prompt: return "" if len(self.negative_prompt) == 1: return self.negative_prompt[0].text return "|".join(str(p) for p in self.negative_prompt) + @property + def width(self) -> int: + return self.size[0] + + @property + def height(self) -> int: + return self.size[1] + def prompt_description(self): return ( f'"{self.prompt_text}" {self.width}x{self.height}px ' f'negative-prompt:"{self.negative_prompt_text}" ' - f"seed:{self.seed} prompt-strength:{self.prompt_strength} steps:{self.steps} sampler-type:{self.sampler_type} init-image-strength:{self.init_image_strength} model:{self.model}" + f"seed:{self.seed} " + f"prompt-strength:{self.prompt_strength} " + f"steps:{self.steps} solver-type:{self.solver_type} " + f"init-image-strength:{self.init_image_strength} " + f"arch:{self.model_architecture} " + f"weights: {self.model_weights}" ) def logging_dict(self): @@ -547,7 +642,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()): new_prompt = new_prompt.model_validate(dict(new_prompt)) return new_prompt - def make_concrete_copy(self): + def make_concrete_copy(self) -> Self: seed = self.seed if self.seed is not None else random.randint(1, 1_000_000_000) return self.full_copy( deep=False, @@ -574,10 +669,6 @@ class ImagineResult: prompt: ImaginePrompt, is_nsfw, safety_score, - upscaled_img=None, - modified_original=None, - mask_binary=None, - mask_grayscale=None, result_images=None, timings=None, progress_latents=None, @@ -594,20 +685,10 @@ class ImagineResult: self.images = {"generated": img} - if upscaled_img: - self.images["upscaled"] = upscaled_img - - if modified_original: - self.images["modified_original"] = modified_original - - if mask_binary: - self.images["mask_binary"] = mask_binary - - if mask_grayscale: - self.images["mask_grayscale"] = mask_grayscale - if result_images: for img_type, r_img in result_images.items(): + if r_img is None: + continue if isinstance(r_img, torch.Tensor): if r_img.shape[1] == 4: r_img = model_latent_to_pillow_img(r_img) @@ -620,7 +701,6 @@ class ImagineResult: # for backward compat self.img = img - self.upscaled_img = upscaled_img self.is_nsfw = is_nsfw self.safety_score = safety_score @@ -628,7 +708,7 @@ class ImagineResult: self.torch_backend = get_device() self.hardware_name = get_hardware_description(get_device()) - def md5(self): + def md5(self) -> str: return hashlib.md5(self.img.tobytes()).hexdigest() def metadata_dict(self): @@ -636,19 +716,19 @@ class ImagineResult: "prompt": self.prompt.logging_dict(), } - def timings_str(self): + def timings_str(self) -> str: if not self.timings: return "" return " ".join(f"{k}:{v:.2f}s" for k, v in self.timings.items()) - def _exif(self): + def _exif(self) -> "Image.Exif": from PIL import Image exif = Image.Exif() exif[ExifCodes.ImageDescription] = self.prompt.prompt_description() exif[ExifCodes.UserComment] = json.dumps(self.metadata_dict()) # help future web scrapes not ingest AI generated art - sd_version = 
self.prompt.model + sd_version = self.prompt.model_weights if len(sd_version) > 20: sd_version = "custom weights" exif[ExifCodes.Software] = f"Imaginairy / Stable Diffusion {sd_version}" @@ -656,7 +736,7 @@ class ImagineResult: exif[ExifCodes.HostComputer] = f"{self.torch_backend}:{self.hardware_name}" return exif - def save(self, save_path, image_type="generated"): + def save(self, save_path: "Path | str", image_type: str = "generated") -> None: img = self.images.get(image_type, None) if img is None: msg = f"Image of type {image_type} not stored. Options are: {self.images.keys()}" @@ -665,6 +745,6 @@ class ImagineResult: img.convert("RGB").save(save_path, exif=self._exif()) -class SafetyMode: +class SafetyMode(str, Enum): STRICT = "strict" RELAXED = "relaxed" diff --git a/imaginairy/surprise_me.py b/imaginairy/surprise_me.py index b82fee7..01752bc 100644 --- a/imaginairy/surprise_me.py +++ b/imaginairy/surprise_me.py @@ -10,7 +10,7 @@ from imaginairy import ImaginePrompt, LazyLoadingImage, imagine_image_files from imaginairy.animations import make_gif_animation from imaginairy.enhancers.facecrop import detect_faces from imaginairy.img_utils import add_caption_to_image, pillow_fit_image_within -from imaginairy.schema import ControlNetInput +from imaginairy.schema import ControlInput preserve_head_kwargs = { "mask_prompt": "head|face", @@ -142,7 +142,7 @@ def surprise_me_prompts( for prompt_text, strength, kwargs in generic_prompts: if use_controlnet: strength = 5 - control_input = ControlNetInput(mode="edit", strength=2) + control_input = ControlInput(mode="edit", strength=2) prompts.append( ImaginePrompt( prompt_text, @@ -163,7 +163,7 @@ def surprise_me_prompts( prompt_text, init_image=img, prompt_strength=strength, - model="edit", + model_weights="edit", steps=steps, width=width, height=height, @@ -178,7 +178,7 @@ def surprise_me_prompts( for prompt_subconfig in prompt_subconfigs: prompt_text, strength, kwargs = prompt_subconfig if use_controlnet: - control_input = ControlNetInput( + control_input = ControlInput( mode="edit", ) prompts.append( @@ -201,7 +201,7 @@ def surprise_me_prompts( prompt_text, init_image=img, prompt_strength=strength, - model="edit", + model_weights="edit", steps=steps, width=width, height=height, diff --git a/imaginairy/train.py b/imaginairy/train.py new file mode 100644 index 0000000..131cf60 --- /dev/null +++ b/imaginairy/train.py @@ -0,0 +1,533 @@ +import datetime +import logging +import os +import signal +import time +from functools import partial + +import numpy as np +import pytorch_lightning as pl +import torch +import torchvision +from omegaconf import OmegaConf +from PIL import Image +from pytorch_lightning import seed_everything +from pytorch_lightning.callbacks import Callback, LearningRateMonitor + +try: + from pytorch_lightning.strategies import DDPStrategy +except ImportError: + # let's not break all of imaginairy just because a training import doesn't exist in an older version of PL + # Use >= 1.6.0 to make this work + DDPStrategy = None +import contextlib + +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.utilities import rank_zero_info +from pytorch_lightning.utilities.distributed import rank_zero_only +from torch.utils.data import DataLoader, Dataset + +from imaginairy import config +from imaginairy.model_manager import get_diffusion_model +from imaginairy.training_tools.single_concept import SingleConceptDataset +from imaginairy.utils import get_device, instantiate_from_config + +mod_logger = logging.getLogger(__name__) + 
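Stepping back from the diff for a moment: taken together, the schema changes above replace `width`/`height` with a single `size`, `sampler_type` with `solver_type`, and `model` with `model_weights`. A hedged sketch of the resulting construction API, with illustrative values:

```python
# Sketch only; the weights alias and size are illustrative values.
from imaginairy import ImaginePrompt

prompt = ImaginePrompt(
    "a bowl of fruit recast in solid gold",
    size="1080p",                  # named size; "640x480", 512, or (512, 512) also accepted
    solver_type="dpmpp",
    model_weights="openjourney-v1",
    seed=1,
)
assert prompt.size == (1920, 1080)   # normalized by the size validator
print(prompt.width, prompt.height)   # width/height survive as read-only properties
print(prompt.steps)                  # defaults per solver (20 for dpmpp)
```

The negative prompt now also defaults from the weights config (for example "poor quality" for openjourney-v1, per the tests below) rather than from a single global constant.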
+referenced_by_string = [LearningRateMonitor] + + +class WrappedDataset(Dataset): + """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset.""" + + def __init__(self, dataset): + self.data = dataset + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +def worker_init_fn(_): + worker_info = torch.utils.data.get_worker_info() + + dataset = worker_info.dataset + worker_id = worker_info.id + + if isinstance(dataset, SingleConceptDataset): + # split_size = dataset.num_records // worker_info.num_workers + # reset num_records to the true number to retain reliable length information + # dataset.sample_ids = dataset.valid_ids[ + # worker_id * split_size : (worker_id + 1) * split_size + # ] + current_id = np.random.choice(len(np.random.get_state()[1]), 1) + return np.random.seed(np.random.get_state()[1][current_id] + worker_id) + + return np.random.seed(np.random.get_state()[1][0] + worker_id) + + +class DataModuleFromConfig(pl.LightningDataModule): + def __init__( + self, + batch_size, + train=None, + validation=None, + test=None, + predict=None, + wrap=False, + num_workers=None, + shuffle_test_loader=False, + use_worker_init_fn=False, + shuffle_val_dataloader=False, + num_val_workers=0, + ): + super().__init__() + self.batch_size = batch_size + self.dataset_configs = {} + self.num_workers = num_workers if num_workers is not None else batch_size * 2 + if num_val_workers is None: + self.num_val_workers = self.num_workers + else: + self.num_val_workers = num_val_workers + self.use_worker_init_fn = use_worker_init_fn + if train is not None: + self.dataset_configs["train"] = train + self.train_dataloader = self._train_dataloader + if validation is not None: + self.dataset_configs["validation"] = validation + self.val_dataloader = partial( + self._val_dataloader, shuffle=shuffle_val_dataloader + ) + if test is not None: + self.dataset_configs["test"] = test + self.test_dataloader = partial( + self._test_dataloader, shuffle=shuffle_test_loader + ) + if predict is not None: + self.dataset_configs["predict"] = predict + self.predict_dataloader = self._predict_dataloader + self.wrap = wrap + self.datasets = None + + def prepare_data(self): + for data_cfg in self.dataset_configs.values(): + instantiate_from_config(data_cfg) + + def setup(self, stage=None): + self.datasets = { + k: instantiate_from_config(c) for k, c in self.dataset_configs.items() + } + if self.wrap: + self.datasets = {k: WrappedDataset(v) for k, v in self.datasets.items()} + + def _train_dataloader(self): + is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset) + if is_iterable_dataset or self.use_worker_init_fn: + pass + else: + pass + return DataLoader( + self.datasets["train"], + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + worker_init_fn=worker_init_fn, + ) + + def _val_dataloader(self, shuffle=False): + if ( + isinstance(self.datasets["validation"], SingleConceptDataset) + or self.use_worker_init_fn + ): + init_fn = worker_init_fn + else: + init_fn = None + return DataLoader( + self.datasets["validation"], + batch_size=self.batch_size, + num_workers=self.num_val_workers, + worker_init_fn=init_fn, + shuffle=shuffle, + ) + + def _test_dataloader(self, shuffle=False): + is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset) + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + is_iterable_dataset = False + + # do not shuffle 
dataloader for iterable dataset + shuffle = shuffle and (not is_iterable_dataset) + + return DataLoader( + self.datasets["test"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle, + ) + + def _predict_dataloader(self, shuffle=False): + if ( + isinstance(self.datasets["predict"], SingleConceptDataset) + or self.use_worker_init_fn + ): + init_fn = worker_init_fn + else: + init_fn = None + return DataLoader( + self.datasets["predict"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + ) + + +class SetupCallback(Callback): + def __init__( + self, + resume, + now, + logdir, + ckptdir, + cfgdir, + ): + super().__init__() + self.resume = resume + self.now = now + self.logdir = logdir + self.ckptdir = ckptdir + self.cfgdir = cfgdir + + def on_keyboard_interrupt(self, trainer, pl_module): + if trainer.global_rank == 0: + mod_logger.info("Stopping execution and saving final checkpoint.") + ckpt_path = os.path.join(self.ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + def on_fit_start(self, trainer, pl_module): + if trainer.global_rank == 0: + # Create logdirs and save configs + os.makedirs(self.logdir, exist_ok=True) + os.makedirs(self.ckptdir, exist_ok=True) + os.makedirs(self.cfgdir, exist_ok=True) + + else: + # ModelCheckpoint callback created log directory --- remove it + if not self.resume and os.path.exists(self.logdir): + dst, name = os.path.split(self.logdir) + dst = os.path.join(dst, "child_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + with contextlib.suppress(FileNotFoundError): + os.rename(self.logdir, dst) + + +class ImageLogger(Callback): + def __init__( + self, + batch_frequency, + max_images, + clamp=True, + increase_log_steps=True, + rescale=True, + disabled=False, + log_on_batch_idx=False, + log_first_step=False, + log_images_kwargs=None, + log_all_val=False, + concept_label=None, + ): + super().__init__() + self.rescale = rescale + self.batch_freq = batch_frequency + self.max_images = max_images + self.logger_log_images = {} + self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)] + if not increase_log_steps: + self.log_steps = [self.batch_freq] + self.clamp = clamp + self.disabled = disabled + self.log_on_batch_idx = log_on_batch_idx + self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.log_first_step = log_first_step + self.log_all_val = log_all_val + self.concept_label = concept_label + + @rank_zero_only + def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): + root = os.path.join(save_dir, "logs", "images", split) + for k in images: + grid = torchvision.utils.make_grid(images[k], nrow=4) + if self.rescale: + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) + grid = grid.numpy() + grid = (grid * 255).astype(np.uint8) + filename = ( + f"{k}_gs-{global_step:06}_e-{current_epoch:06}_b-{batch_idx:06}.png" + ) + path = os.path.join(root, filename) + os.makedirs(os.path.split(path)[0], exist_ok=True) + Image.fromarray(grid).save(path) + + def log_img(self, pl_module, batch, batch_idx, split="train"): + # always generate the concept label + batch["txt"][0] = self.concept_label + + check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step + if self.log_all_val and split == "val": + should_log = True + else: + should_log = self.check_frequency(check_idx) + if ( + should_log + and (batch_idx % self.batch_freq == 
0) + and hasattr(pl_module, "log_images") + and callable(pl_module.log_images) + and self.max_images > 0 + ): + logger = type(pl_module.logger) + + is_train = pl_module.training + if is_train: + pl_module.eval() + + with torch.no_grad(): + images = pl_module.log_images( + batch, split=split, **self.log_images_kwargs + ) + + for k in images: + N = min(images[k].shape[0], self.max_images) + images[k] = images[k][:N] + if isinstance(images[k], torch.Tensor): + images[k] = images[k].detach().cpu() + if self.clamp: + images[k] = torch.clamp(images[k], -1.0, 1.0) + + self.log_local( + pl_module.logger.save_dir, + split, + images, + pl_module.global_step, + pl_module.current_epoch, + batch_idx, + ) + + logger_log_images = self.logger_log_images.get( + logger, lambda *args, **kwargs: None + ) + logger_log_images(pl_module, images, pl_module.global_step, split) + + if is_train: + pl_module.train() + + def check_frequency(self, check_idx): + if (check_idx % self.batch_freq) == 0 and ( + check_idx > 0 or self.log_first_step + ): + return True + return False + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): + self.log_img(pl_module, batch, batch_idx, split="train") + + def on_validation_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ): + if not self.disabled and pl_module.global_step > 0: + self.log_img(pl_module, batch, batch_idx, split="val") + if ( + hasattr(pl_module, "calibrate_grad_norm") + and (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) + and batch_idx > 0 + ): + self.log_gradients(trainer, pl_module, batch_idx=batch_idx) + + +class CUDACallback(Callback): + # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py + def on_train_epoch_start(self, trainer, pl_module): + # Reset the memory use counter + if "cuda" in get_device(): + torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index) + torch.cuda.synchronize(trainer.strategy.root_device.index) + self.start_time = time.time() + + def on_train_epoch_end(self, trainer, pl_module): + if "cuda" in get_device(): + torch.cuda.synchronize(trainer.strategy.root_device.index) + max_memory = ( + torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) + / 2**20 + ) + epoch_time = time.time() - self.start_time + + try: + max_memory = trainer.training_type_plugin.reduce(max_memory) + epoch_time = trainer.training_type_plugin.reduce(epoch_time) + + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") + rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") + except AttributeError: + pass + + +def train_diffusion_model( + concept_label, + concept_images_dir, + class_label, + class_images_dir, + weights_location=config.DEFAULT_MODEL_WEIGHTS, + logdir="logs", + learning_rate=1e-6, + accumulate_grad_batches=32, + resume=None, +): + """ + Train a diffusion model on a single concept. 
+ + accumulate_grad_batches used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf + """ + if DDPStrategy is None: + msg = "Please install pytorch-lightning>=1.6.0 to train a model" + raise ImportError(msg) + + batch_size = 1 + seed = 23 + num_workers = 1 + num_val_workers = 0 + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # noqa: DTZ005 + logdir = os.path.join(logdir, now) + + ckpt_output_dir = os.path.join(logdir, "checkpoints") + cfg_output_dir = os.path.join(logdir, "configs") + seed_everything(seed) + model = get_diffusion_model( + weights_location=weights_location, half_mode=False, for_training=True + )._model + model.learning_rate = learning_rate * accumulate_grad_batches * batch_size + + # add callback which sets up log directory + default_callbacks_cfg = { + "setup_callback": { + "target": "imaginairy.train.SetupCallback", + "params": { + "resume": False, + "now": now, + "logdir": logdir, + "ckptdir": ckpt_output_dir, + "cfgdir": cfg_output_dir, + }, + }, + "image_logger": { + "target": "imaginairy.train.ImageLogger", + "params": { + "batch_frequency": 10, + "max_images": 1, + "clamp": True, + "increase_log_steps": False, + "log_first_step": True, + "log_all_val": True, + "concept_label": concept_label, + "log_images_kwargs": { + "use_ema_scope": True, + "inpaint": False, + "plot_progressive_rows": False, + "plot_diffusion_rows": False, + "N": 1, + "unconditional_guidance_scale:": 7.5, + "unconditional_guidance_label": [""], + "ddim_steps": 20, + }, + }, + }, + "learning_rate_logger": { + "target": "imaginairy.train.LearningRateMonitor", + "params": { + "logging_interval": "step", + # "log_momentum": True + }, + }, + "cuda_callback": {"target": "imaginairy.train.CUDACallback"}, + } + + default_modelckpt_cfg = { + "target": "pytorch_lightning.callbacks.ModelCheckpoint", + "params": { + "dirpath": ckpt_output_dir, + "filename": "{epoch:06}", + "verbose": True, + "save_last": True, + "every_n_train_steps": 50, + "save_top_k": -1, + "monitor": None, + }, + } + + modelckpt_cfg = OmegaConf.create(default_modelckpt_cfg) + default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg}) + + callbacks_cfg = OmegaConf.create(default_callbacks_cfg) + + dataset_config = { + "concept_label": concept_label, + "concept_images_dir": concept_images_dir, + "class_label": class_label, + "class_images_dir": class_images_dir, + "image_transforms": [ + { + "target": "torchvision.transforms.Resize", + "params": {"size": 512, "interpolation": 3}, + }, + {"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}}, + ], + } + + data_module_config = { + "batch_size": batch_size, + "num_workers": num_workers, + "num_val_workers": num_val_workers, + "train": { + "target": "imaginairy.training_tools.single_concept.SingleConceptDataset", + "params": dataset_config, + }, + } + trainer = Trainer( + benchmark=True, + num_sanity_val_steps=0, + accumulate_grad_batches=accumulate_grad_batches, + strategy=DDPStrategy(), + callbacks=[instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg], + gpus=1, + default_root_dir=".", + ) + trainer.logdir = logdir + + data = DataModuleFromConfig(**data_module_config) + # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html + # calling these ourselves should not be necessary but it is. 
+ # lightning still takes care of proper multiprocessing though + data.prepare_data() + data.setup() + + def melk(*args, **kwargs): + if trainer.global_rank == 0: + mod_logger.info("Summoning checkpoint.") + ckpt_path = os.path.join(ckpt_output_dir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + signal.signal(signal.SIGUSR1, melk) + try: + try: + trainer.fit(model, data) + except Exception: + melk() + raise + finally: + mod_logger.info(trainer.profiler.summary()) diff --git a/imaginairy/utils/data_distorter.py b/imaginairy/utils/data_distorter.py index 3d133a3..a55f6ca 100644 --- a/imaginairy/utils/data_distorter.py +++ b/imaginairy/utils/data_distorter.py @@ -170,7 +170,21 @@ def replace_value_at_path(data, path, new_value): parent = get_path(data, path[:-1]) last_key = path[-1] if new_value == NODE_DELETE: - del parent[last_key] + if isinstance(parent, tuple): + grandparent = get_path(data, path[:-2]) + grandparent_key = path[-2] + new_parent = list(parent) + del new_parent[last_key] + grandparent[grandparent_key] = tuple(new_parent) + else: + del parent[last_key] else: - parent[last_key] = new_value + if isinstance(parent, tuple): + grandparent = get_path(data, path[:-2]) + grandparent_key = path[-2] + new_parent = list(parent) + new_parent[last_key] = new_value + grandparent[grandparent_key] = tuple(new_parent) + else: + parent[last_key] = new_value return data diff --git a/imaginairy/utils/named_resolutions.py b/imaginairy/utils/named_resolutions.py index 8e77cda..dd19072 100644 --- a/imaginairy/utils/named_resolutions.py +++ b/imaginairy/utils/named_resolutions.py @@ -43,27 +43,37 @@ _NAMED_RESOLUTIONS = { "SVD": (1024, 576), # stable video diffusion } +_NAMED_RESOLUTIONS = {k.upper(): v for k, v in _NAMED_RESOLUTIONS.items()} -def get_named_resolution(resolution: str): - resolution = resolution.upper() - size = _NAMED_RESOLUTIONS.get(resolution) - - if size is None: - # is it WIDTHxHEIGHT format? - try: - width, height = resolution.split("X") - size = (int(width), int(height)) - except ValueError: - pass - - if size is None: - # is it just a single number? - with contextlib.suppress(ValueError): - size = (int(resolution), int(resolution)) - - if size is None: - msg = f"Unknown resolution: {resolution}" +def normalize_image_size(resolution: str | int | tuple[int, int]) -> tuple[int, int]: + match resolution: + case (int(), int()): + size = resolution + case int(): + size = resolution, resolution + case str(): + resolution = resolution.strip().upper() + resolution = resolution.replace(" ", "").replace("X", ",").replace("*", ",") + size = _NAMED_RESOLUTIONS.get(resolution.upper()) + if size is None: + # is it WIDTH,HEIGHT format? + try: + width, height = resolution.split(",") + size = int(width), int(height) + except ValueError: + pass + if size is None: + # is it just a single number? 
+ with contextlib.suppress(ValueError): + size = (int(resolution), int(resolution)) + if size is None: + msg = f"Invalid resolution: '{resolution}'" + raise ValueError(msg) + case _: + msg = f"Invalid resolution: {resolution!r}" + raise ValueError(msg) + if size[0] <= 0 or size[1] <= 0: + msg = f"Invalid resolution: {resolution!r}" raise ValueError(msg) - return size diff --git a/imaginairy/video_sample.py b/imaginairy/video_sample.py index dfe5a27..7fe6c74 100644 --- a/imaginairy/video_sample.py +++ b/imaginairy/video_sample.py @@ -87,7 +87,7 @@ def generate_video( device="cpu", num_frames=num_frames, num_steps=num_steps, - weights_url=video_model_config["weights_url"], + weights_url=video_model_config["weights_location"], ) torch.manual_seed(seed) diff --git a/imaginairy/weight_management/generate_weight_info.py b/imaginairy/weight_management/generate_weight_info.py index 3cccb86..723d0b1 100644 --- a/imaginairy/weight_management/generate_weight_info.py +++ b/imaginairy/weight_management/generate_weight_info.py @@ -3,7 +3,7 @@ import safetensors from imaginairy.model_manager import ( get_cached_url_path, open_weights, - resolve_model_paths, + resolve_model_weights_config, ) from imaginairy.weight_management import utils from imaginairy.weight_management.pattern_collapse import find_state_dict_key_patterns @@ -11,15 +11,12 @@ from imaginairy.weight_management.utils import save_model_info def save_compvis_patterns(): - ( - model_metadata, - weights_url, - config_path, - control_weights_paths, - ) = resolve_model_paths( - weights_path="openjourney-v1", + model_weights_config = resolve_model_weights_config( + model_weights="openjourney-v1", + ) + weights_path = get_cached_url_path( + model_weights_config.weights_location, category="weights" ) - weights_path = get_cached_url_path(weights_url, category="weights") with safetensors.safe_open(weights_path, "pytorch") as f: weights_keys = f.keys() @@ -98,7 +95,7 @@ def save_weight_info( model_name, component_name, format_name, weights_url=None, weights_keys=None ): if weights_keys is None and weights_url is None: - msg = "Either weights_keys or weights_url must be provided" + msg = "Either weights_keys or weights_location must be provided" raise ValueError(msg) if weights_keys is None: diff --git a/tests/conftest.py b/tests/conftest.py index f7fc6a6..18d8ad5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,6 @@ from urllib3 import HTTPConnectionPool from imaginairy import ImaginePrompt, api, imagine from imaginairy.log_utils import configure_logging, suppress_annoying_logs_and_warnings -from imaginairy.samplers import SAMPLER_TYPE_OPTIONS from imaginairy.utils import ( fix_torch_group_norm, fix_torch_nn_layer_norm, @@ -26,13 +25,13 @@ if "pytest" in str(sys.argv): logger = logging.getLogger(__name__) -SAMPLERS_FOR_TESTING = SAMPLER_TYPE_OPTIONS -if get_device() == "mps:0": - SAMPLERS_FOR_TESTING = ["plms", "k_euler_a"] -elif get_device() == "cpu": - SAMPLERS_FOR_TESTING = [] +# SOLVERS_FOR_TESTING = SOLVER_TYPE_OPTIONS +# if get_device() == "mps:0": +# SOLVERS_FOR_TESTING = ["plms", "k_euler_a"] +# elif get_device() == "cpu": +# SOLVERS_FOR_TESTING = [] -SAMPLERS_FOR_TESTING = ["ddim", "k_dpmpp_2m"] +SOLVERS_FOR_TESTING = ["ddim", "dpmpp"] @pytest.fixture(scope="session", autouse=True) @@ -90,8 +89,8 @@ def filename_base_for_orig_outputs(request): return filename_base -@pytest.fixture(params=SAMPLERS_FOR_TESTING) -def sampler_type(request): +@pytest.fixture(params=SOLVERS_FOR_TESTING) +def solver_type(request): return 
request.param @@ -118,11 +117,10 @@ def default_model_loaded(): """ prompt = ImaginePrompt( "dogs lying on a hot pink couch", - width=64, - height=64, + size=64, steps=2, seed=1, - sampler_type="ddim", + solver_type="ddim", ) next(imagine(prompt)) diff --git a/tests/expected_output/test_imagine[dpmpp]_.png b/tests/expected_output/test_imagine[dpmpp]_.png new file mode 100644 index 0000000..ea9ebfe Binary files /dev/null and b/tests/expected_output/test_imagine[dpmpp]_.png differ diff --git a/tests/expected_output/test_img2img_beach_to_sunset[dpmpp]_.png b/tests/expected_output/test_img2img_beach_to_sunset[dpmpp]_.png new file mode 100644 index 0000000..24d4fa7 Binary files /dev/null and b/tests/expected_output/test_img2img_beach_to_sunset[dpmpp]_.png differ diff --git a/tests/expected_output/test_img2img_low_noise[dpmpp]_.png b/tests/expected_output/test_img2img_low_noise[dpmpp]_.png new file mode 100644 index 0000000..a396df6 Binary files /dev/null and b/tests/expected_output/test_img2img_low_noise[dpmpp]_.png differ diff --git a/tests/expected_output/test_img_to_img_from_url_cats[dpmpp]_.png b/tests/expected_output/test_img_to_img_from_url_cats[dpmpp]_.png new file mode 100644 index 0000000..ed8aa77 Binary files /dev/null and b/tests/expected_output/test_img_to_img_from_url_cats[dpmpp]_.png differ diff --git a/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0.05]_.png b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0.05]_.png new file mode 100644 index 0000000..d4ce628 Binary files /dev/null and b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0.05]_.png differ diff --git a/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0.2]_.png b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0.2]_.png new file mode 100644 index 0000000..2613763 Binary files /dev/null and b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0.2]_.png differ diff --git a/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0]_.png b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0]_.png new file mode 100644 index 0000000..264f6a8 Binary files /dev/null and b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0]_.png differ diff --git a/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-1]_.png b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-1]_.png new file mode 100644 index 0000000..e32a842 Binary files /dev/null and b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-1]_.png differ diff --git a/tests/test_api.py b/tests/test_api.py index cf8eb58..38abef4 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -6,22 +6,22 @@ from imaginairy import LazyLoadingImage from imaginairy.api import imagine, imagine_image_files from imaginairy.img_processors.control_modes import CONTROL_MODES from imaginairy.img_utils import pillow_fit_image_within -from imaginairy.schema import ControlNetInput, ImaginePrompt +from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode from imaginairy.utils import get_device from . import TESTS_FOLDER from .utils import assert_image_similar_to_expectation -def test_imagine(sampler_type, filename_base_for_outputs): +def test_imagine(solver_type, filename_base_for_outputs): prompt_text = "a scenic old-growth forest with diffuse light poking through the canopy. 
high resolution nature photography" prompt = ImaginePrompt( - prompt_text, width=512, height=512, steps=20, seed=1, sampler_type=sampler_type + prompt_text, size=512, steps=20, seed=1, solver_type=solver_type ) result = next(imagine(prompt)) threshold_lookup = {"k_dpm_2_a": 26000} - threshold = threshold_lookup.get(sampler_type, 10000) + threshold = threshold_lookup.get(solver_type, 10000) img_path = f"{filename_base_for_outputs}.png" assert_image_similar_to_expectation( @@ -49,25 +49,25 @@ def test_model_versions(filename_base_for_orig_outputs, model_version): ImaginePrompt( prompt_text, seed=1, - model=model_version, + model_weights=model_version, ) ) threshold = 35000 - - for i, result in enumerate(imagine(prompts)): - img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model}.png" + results = list(imagine(prompts)) + for i, result in enumerate(results): + img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model_weights}.png" result.img.save(img_path) - for i, result in enumerate(imagine(prompts)): - img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model}.png" + for i, result in enumerate(results): + img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model_weights}.png" assert_image_similar_to_expectation( result.img, img_path=img_path, threshold=threshold ) def test_img2img_beach_to_sunset( - sampler_type, filename_base_for_outputs, filename_base_for_orig_outputs + solver_type, filename_base_for_outputs, filename_base_for_orig_outputs ): img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg") prompt = ImaginePrompt( @@ -77,11 +77,10 @@ def test_img2img_beach_to_sunset( prompt_strength=15, mask_prompt="(sky|clouds) AND !(buildings|trees)", mask_mode="replace", - width=512, - height=512, + size=512, steps=40 * 2, seed=1, - sampler_type=sampler_type, + solver_type=solver_type, ) result = next(imagine(prompt)) @@ -91,7 +90,7 @@ def test_img2img_beach_to_sunset( def test_img_to_img_from_url_cats( - sampler_type, + solver_type, filename_base_for_outputs, mocked_responses, filename_base_for_orig_outputs, @@ -113,11 +112,10 @@ def test_img_to_img_from_url_cats( "dogs lying on a hot pink couch", init_image=img, init_image_strength=0.5, - width=512, - height=512, + size=512, steps=50, seed=1, - sampler_type=sampler_type, + solver_type=solver_type, ) result = next(imagine(prompt)) @@ -130,7 +128,7 @@ def test_img_to_img_from_url_cats( def test_img2img_low_noise( filename_base_for_outputs, - sampler_type, + solver_type, ): fruit_path = os.path.join(TESTS_FOLDER, "data", "bowl_of_fruit.jpg") img = LazyLoadingImage(filepath=fruit_path) @@ -144,17 +142,18 @@ def test_img2img_low_noise( mask_mode="replace", # steps=40, seed=1, - sampler_type=sampler_type, + solver_type=solver_type, ) result = next(imagine(prompt)) threshold_lookup = { + "dpmpp": 26000, "k_dpm_2_a": 26000, "k_euler_a": 18000, "k_dpm_adaptive": 13000, } - threshold = threshold_lookup.get(sampler_type, 14000) + threshold = threshold_lookup.get(solver_type, 14000) img_path = f"{filename_base_for_outputs}.png" assert_image_similar_to_expectation( @@ -165,7 +164,7 @@ def test_img2img_low_noise( @pytest.mark.parametrize("init_strength", [0, 0.05, 0.2, 1]) def test_img_to_img_fruit_2_gold( filename_base_for_outputs, - sampler_type, + solver_type, init_strength, filename_base_for_orig_outputs, ): @@ -183,7 +182,7 @@ def test_img_to_img_fruit_2_gold( 
mask_mode="replace", steps=needed_steps, seed=1, - sampler_type=sampler_type, + solver_type=solver_type, ) result = next(imagine(prompt)) @@ -194,7 +193,7 @@ def test_img_to_img_fruit_2_gold( "k_dpm_adaptive": 13000, "k_dpmpp_2s": 16000, } - threshold = threshold_lookup.get(sampler_type, 16000) + threshold = threshold_lookup.get(solver_type, 16000) pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}__orig.jpg") img_path = f"{filename_base_for_outputs}.png" @@ -227,7 +226,7 @@ def test_img_to_img_fruit_2_gold_repeat(): ] for result in imagine(prompts, debug_img_callback=None): result.img.save( - f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold_{result.prompt.sampler_type}_{get_device()}_run-{run_count:02}.jpg" + f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold_{result.prompt.solver_type}_{get_device()}_run-{run_count:02}.jpg" ) run_count += 1 @@ -236,9 +235,8 @@ def test_img_to_img_fruit_2_gold_repeat(): def test_img_to_file(): prompt = ImaginePrompt( "an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo", - width=512 + 64, - height=512 - 64, - steps=20, + size=(512 + 64, 512 - 64), + steps=2, seed=2, upscale=True, ) @@ -254,8 +252,7 @@ def test_inpainting_bench(filename_base_for_outputs, filename_base_for_orig_outp init_image=img, init_image_strength=0.4, mask_image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2_mask.png"), - width=512, - height=512, + size=512, steps=40, seed=1, ) @@ -279,9 +276,8 @@ def test_cliptext_inpainting_pearl_doctor( init_image=img, init_image_strength=0.2, mask_prompt="face AND NOT (bandana OR hair OR blue fabric){*5}", - mask_mode=ImaginePrompt.MaskMode.KEEP, - width=512, - height=512, + mask_mode=MaskMode.KEEP, + size=512, steps=40, seed=181509347, ) @@ -297,8 +293,7 @@ def test_tile_mode(filename_base_for_outputs): prompt_text = "gold coins" prompt = ImaginePrompt( prompt_text, - width=400, - height=400, + size=400, steps=15, seed=1, tile_mode="xy", @@ -317,7 +312,7 @@ control_modes = list(CONTROL_MODES.keys()) def test_controlnet(filename_base_for_outputs, control_mode): prompt_text = "a photo of a woman sitting on a bench" img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png") - control_input = ControlNetInput( + control_input = ControlInput( mode=control_mode, image=img, ) @@ -327,30 +322,27 @@ def test_controlnet(filename_base_for_outputs, control_mode): prompt_text = "a wise old man" seed = 1 mask_image = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2_mask.png") - control_input = ControlNetInput( + control_input = ControlInput( mode=control_mode, image=mask_image, ) prompt = ImaginePrompt( prompt_text, - width=512, - height=512, + size=512, steps=45, seed=seed, init_image=img, init_image_strength=0, control_inputs=[control_input], fix_faces=True, - sampler="ddim", + solver_type="ddim", ) prompt.steps = 1 - prompt.width = 256 - prompt.height = 256 + prompt.size = 256 result = next(imagine(prompt)) prompt.steps = 15 - prompt.width = 512 - prompt.height = 512 + prompt.size = 512 result = next(imagine(prompt)) img_path = f"{filename_base_for_outputs}.png" @@ -365,8 +357,7 @@ def test_large_image(filename_base_for_outputs): prompt_text = "a stormy ocean. 
oil painting" prompt = ImaginePrompt( prompt_text, - width=1920, - height=1080, + size="1080p", steps=30, seed=0, ) diff --git a/tests/test_cmds.py b/tests/test_cmds.py index 5da6650..5c401db 100644 --- a/tests/test_cmds.py +++ b/tests/test_cmds.py @@ -72,8 +72,7 @@ def test_edit_demo(monkeypatch): ImaginePrompt( "", steps=1, - width=256, - height=256, + size=256, # model="empty", ) ] @@ -89,7 +88,7 @@ def test_edit_demo(monkeypatch): f"{TESTS_FOLDER}/test_output", ], ) - assert result.exit_code == 0 + assert result.exit_code == 0, result.stdout def test_upscale(monkeypatch): diff --git a/tests/test_config.py b/tests/test_config.py index 5bc784c..34a7ac8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,6 @@ -from imaginairy import config -from imaginairy.samplers import SAMPLER_TYPE_OPTIONS - - -def test_sampler_options(): - assert set(config.SAMPLER_TYPE_OPTIONS) == set(SAMPLER_TYPE_OPTIONS) +# from imaginairy import config +# from imaginairy.samplers import SOLVER_TYPE_OPTIONS +# +# +# def test_sampler_options(): +# assert set(config.SOLVER_TYPE_NAMES) == set(SOLVER_TYPE_OPTIONS) diff --git a/tests/test_enhancers.py b/tests/test_enhancers.py index 087b10c..3bd4d74 100644 --- a/tests/test_enhancers.py +++ b/tests/test_enhancers.py @@ -58,7 +58,7 @@ def test_clip_masking(filename_base_for_outputs): upscale=False, fix_faces=True, seed=42, - # sampler_type="plms", + # solver_type="plms", ) result = next(imagine(prompt)) diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py index ac1be82..16bcb0d 100644 --- a/tests/test_model_manager.py +++ b/tests/test_model_manager.py @@ -1,24 +1,25 @@ from imaginairy import config -from imaginairy.model_manager import resolve_model_paths +from imaginairy.model_manager import resolve_model_weights_config def test_resolved_paths(): """Test that the resolved model path is correct.""" - ( - model_metadata, - weights_path, - config_path, - control_weights_path, - ) = resolve_model_paths() - assert model_metadata.short_name == config.DEFAULT_MODEL - assert model_metadata.config_path == config_path - default_config_path = config_path + model_weights_config = resolve_model_weights_config(config.DEFAULT_MODEL_WEIGHTS) + assert config.DEFAULT_MODEL_WEIGHTS.lower() in model_weights_config.aliases + assert ( + config.DEFAULT_MODEL_ARCHITECTURE in model_weights_config.architecture.aliases + ) - ( - model_metadata, - weights_path, - config_path, - control_weights_path, - ) = resolve_model_paths(weights_path="foo.ckpt") - assert weights_path == "foo.ckpt" - assert config_path == default_config_path + model_weights_config = resolve_model_weights_config( + model_weights="foo.ckpt", + default_model_architecture="sd15", + ) + print(model_weights_config) + assert model_weights_config.aliases == [] + assert "sd15" in model_weights_config.architecture.aliases + + model_weights_config = resolve_model_weights_config( + model_weights="foo.ckpt", default_model_architecture="sd15", for_inpainting=True + ) + assert model_weights_config.aliases == [] + assert "sd15-inpaint" in model_weights_config.architecture.aliases diff --git a/tests/test_schema/test_controlnetinput.py b/tests/test_schema/test_controlnetinput.py index 90d9a09..fa257fa 100644 --- a/tests/test_schema/test_controlnetinput.py +++ b/tests/test_schema/test_controlnetinput.py @@ -2,7 +2,7 @@ import pytest from pydantic import ValidationError from imaginairy import LazyLoadingImage -from imaginairy.schema import ControlNetInput +from imaginairy.schema import ControlInput from tests 
import TESTS_FOLDER @@ -12,29 +12,29 @@ def _lazy_img(): def test_controlnetinput_basic(lazy_img): - ControlNetInput(mode="canny", image=lazy_img) - ControlNetInput(mode="canny", image_raw=lazy_img) + ControlInput(mode="canny", image=lazy_img) + ControlInput(mode="canny", image_raw=lazy_img) def test_controlnetinput_invalid_mode(lazy_img): with pytest.raises(ValueError, match=r".*Invalid controlnet mode.*"): - ControlNetInput(mode="pizza", image=lazy_img) + ControlInput(mode="pizza", image=lazy_img) def test_controlnetinput_both_images(lazy_img): with pytest.raises(ValueError, match=r".*cannot specify both.*"): - ControlNetInput(mode="canny", image=lazy_img, image_raw=lazy_img) + ControlInput(mode="canny", image=lazy_img, image_raw=lazy_img) def test_controlnetinput_filepath_input(lazy_img): """Test that we accept filepaths here.""" - c = ControlNetInput(mode="canny", image=f"{TESTS_FOLDER}/data/red.png") + c = ControlInput(mode="canny", image=f"{TESTS_FOLDER}/data/red.png") c.image.convert("RGB") - c = ControlNetInput(mode="canny", image_raw=f"{TESTS_FOLDER}/data/red.png") + c = ControlInput(mode="canny", image_raw=f"{TESTS_FOLDER}/data/red.png") c.image_raw.convert("RGB") def test_controlnetinput_big(lazy_img): - ControlNetInput(mode="canny", strength=2) + ControlInput(mode="canny", strength=2) with pytest.raises(ValidationError, match=r".*float_type.*"): - ControlNetInput(mode="canny", strength=2**2048) + ControlInput(mode="canny", strength=2**2048) diff --git a/tests/test_schema/test_imagineprompt.py b/tests/test_schema/test_imagineprompt.py index 1245e08..4ac25b2 100644 --- a/tests/test_schema/test_imagineprompt.py +++ b/tests/test_schema/test_imagineprompt.py @@ -2,13 +2,21 @@ import pytest from pydantic import ValidationError from imaginairy import LazyLoadingImage, config -from imaginairy.schema import ControlNetInput, ImaginePrompt, WeightedPrompt +from imaginairy.schema import ControlInput, ImaginePrompt, WeightedPrompt from imaginairy.utils.data_distorter import DataDistorter from tests import TESTS_FOLDER +def test_imagine_prompt_default(): + prompt = ImaginePrompt() + assert prompt.prompt == [] + assert prompt.negative_prompt == [ + WeightedPrompt(text=config.DEFAULT_NEGATIVE_PROMPT) + ] + + def test_imagine_prompt_has_default_negative(): - prompt = ImaginePrompt("fruit salad", model="foobar") + prompt = ImaginePrompt("fruit salad", model_weights="foobar") assert isinstance(prompt.prompt[0], WeightedPrompt) assert isinstance(prompt.negative_prompt[0], WeightedPrompt) @@ -21,10 +29,10 @@ def test_imagine_prompt_custom_negative_prompt(): def test_imagine_prompt_model_specific_negative_prompt(): - prompt = ImaginePrompt("fruit salad", model="openjourney-v1") + prompt = ImaginePrompt("fruit salad", model_weights="openjourney-v1") assert isinstance(prompt.prompt[0], WeightedPrompt) assert isinstance(prompt.negative_prompt[0], WeightedPrompt) - assert prompt.negative_prompt[0].text == "" + assert prompt.negative_prompt[0].text == "poor quality" def test_imagine_prompt_weighted_prompts(): @@ -84,7 +92,7 @@ def test_imagine_prompt_control_inputs(): prompt = ImaginePrompt( "fruit", control_inputs=[ - ControlNetInput(mode="depth", image=img), + ControlInput(mode="depth", image=img), ], ) prompt.control_inputs[0].image.convert("RGB") @@ -98,7 +106,7 @@ def test_imagine_prompt_control_inputs(): "fruit", init_image=img, control_inputs=[ - ControlNetInput(mode="depth"), + ControlInput(mode="depth"), ], ) assert prompt.control_inputs[0].image is not None @@ -107,7 +115,7 @@ def 
test_imagine_prompt_control_inputs(): prompt = ImaginePrompt( "fruit", control_inputs=[ - ControlNetInput(mode="depth"), + ControlInput(mode="depth"), ], ) assert prompt.control_inputs[0].image is None @@ -136,8 +144,8 @@ def test_imagine_prompt_mask_params(): def test_imagine_prompt_default_model(): - prompt = ImaginePrompt("fruit", model=None) - assert prompt.model == config.DEFAULT_MODEL + prompt = ImaginePrompt("fruit", model_weights=None) + assert prompt.model_weights == config.DEFAULT_MODEL_WEIGHTS def test_imagine_prompt_default_negative(): @@ -152,7 +160,7 @@ def test_imagine_prompt_fix_faces_fidelity(): def test_imagine_prompt_init_strength_zero(): lazy_img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png") prompt = ImaginePrompt( - "fruit", control_inputs=[ControlNetInput(mode="depth", image=lazy_img)] + "fruit", control_inputs=[ControlInput(mode="depth", image=lazy_img)] ) assert prompt.init_image_strength == 0.0 @@ -171,12 +179,12 @@ def test_distorted_prompts(): init_image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"), init_image_strength=0.5, control_inputs=[ - ControlNetInput( + ControlInput( mode="details", image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"), strength=2, ), - ControlNetInput( + ControlInput( mode="depth", image_raw=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"), strength=3, @@ -187,13 +195,11 @@ def test_distorted_prompts(): mask_mode="replace", mask_modify_original=False, outpaint="all5,up0,down20", - model=config.DEFAULT_MODEL, - model_config_path=None, - sampler_type=config.DEFAULT_SAMPLER, + model_weights=config.DEFAULT_MODEL_WEIGHTS, + solver_type=config.DEFAULT_SOLVER, seed=42, steps=10, - height=256, - width=256, + size=256, upscale=True, fix_faces=True, fix_faces_fidelity=0.7, diff --git a/tests/test_utils/test_named_resolutions.py b/tests/test_utils/test_named_resolutions.py index d5162c5..07d6e80 100644 --- a/tests/test_utils/test_named_resolutions.py +++ b/tests/test_utils/test_named_resolutions.py @@ -1,6 +1,6 @@ import pytest -from imaginairy.utils.named_resolutions import get_named_resolution +from imaginairy.utils.named_resolutions import normalize_image_size valid_cases = [ ("HD", (1280, 720)), @@ -12,11 +12,25 @@ valid_cases = [ ("1920x1080", (1920, 1080)), ("1280x720", (1280, 720)), ("1024x768", (1024, 768)), + ("1024,768", (1024, 768)), + ("1024*768", (1024, 768)), + ("1024, 768", (1024, 768)), ("800", (800, 800)), ("1024", (1024, 1024)), + ("1080p", (1920, 1080)), + ("1080P", (1920, 1080)), + (512, (512, 512)), + ((512, 512), (512, 512)), + ("1x1", (1, 1)), ] invalid_cases = [ + None, + 3.14, + (3.14, 3.14), + "", + " ", "abc", + "-512", "1920xABC", "1920x1080x1234", "x1920", @@ -30,10 +44,10 @@ invalid_cases = [ @pytest.mark.parametrize(("named_resolution", "expected"), valid_cases) def test_named_resolutions(named_resolution, expected): - assert get_named_resolution(named_resolution) == expected + assert normalize_image_size(named_resolution) == expected @pytest.mark.parametrize("named_resolution", invalid_cases) def test_invalid_inputs(named_resolution): - with pytest.raises(ValueError, match="Unknown resolution"): - get_named_resolution(named_resolution) + with pytest.raises(ValueError, match="Invalid resolution"): + normalize_image_size(named_resolution)