refactor: simplify model_weights/architecture

pull/411/head^2
Bryce 6 months ago committed by Bryce Drennan
parent 37ecd1e5e0
commit 203747b14f

@ -3,6 +3,8 @@ import os
import re
from typing import TYPE_CHECKING, Callable
from imaginairy.utils.named_resolutions import normalize_image_size
if TYPE_CHECKING:
from imaginairy.schema import ImaginePrompt
@ -543,21 +545,22 @@ def _generate_single_image_compvis(
and not is_controlnet_model
and model.cond_stage_key != "edit"
):
default_size = get_model_default_image_size(
prompt.model_weights.architecture
)
if prompt.init_image:
comp_image = _generate_composition_image(
prompt=prompt,
target_height=init_image.height,
target_width=init_image.width,
cutoff=get_model_default_image_size(
prompt.model_weights.model_architecture
),
cutoff=default_size,
)
else:
comp_image = _generate_composition_image(
prompt=prompt,
target_height=prompt.height,
target_width=prompt.width,
cutoff=get_model_default_image_size(prompt.model_architecture),
cutoff=default_size,
)
if comp_image is not None:
result_images["composition"] = comp_image
@ -668,18 +671,15 @@ def _prompts_to_embeddings(prompts, model):
return conditioning
def calc_scale_to_fit_within(
height,
width,
max_size,
):
if max(height, width) < max_size:
def calc_scale_to_fit_within(height: int, width: int, max_size) -> float:
max_width, max_height = normalize_image_size(max_size)
if width <= max_width and height <= max_height:
return 1
if width > height:
return max_size / width
width_ratio = max_width / width
height_ratio = max_height / height
return max_size / height
return min(width_ratio, height_ratio)
def _scale_latent(
@ -698,14 +698,19 @@ def _scale_latent(
def _generate_composition_image(
prompt, target_height, target_width, cutoff=512, dtype=None
prompt,
target_height,
target_width,
cutoff: tuple[int, int] = (512, 512),
dtype=None,
):
from PIL import Image
from imaginairy.api_refiners import _generate_single_image
from imaginairy.utils import default, get_default_dtype
if prompt.width <= cutoff and prompt.height <= cutoff:
cutoff = normalize_image_size(cutoff)
if prompt.width <= cutoff[0] and prompt.height <= cutoff[1]:
return None, None
dtype = default(dtype, get_default_dtype)
@ -727,6 +732,7 @@ def _generate_composition_image(
"upscale": False,
"fix_faces": False,
"mask_modify_original": False,
"allow_compose_phase": False,
},
)

@ -14,8 +14,6 @@ def _generate_single_image(
progress_img_interval_steps=3,
progress_img_interval_min_s=0.1,
add_caption=False,
# controlnet, finetune, naive, auto
inpaint_method="finetune",
return_latent=False,
dtype=None,
half_mode=None,
@ -63,20 +61,11 @@ def _generate_single_image(
clear_gpu_cache()
prompt = prompt.make_concrete_copy()
control_modes = []
control_inputs = prompt.control_inputs or []
control_inputs = control_inputs.copy()
for_inpainting = bool(prompt.mask_image or prompt.mask_prompt or prompt.outpaint)
if control_inputs:
control_modes = [c.mode for c in prompt.control_inputs]
sd = get_diffusion_model_refiners(
weights_location=prompt.model_weights,
model_architecture=prompt.model_architecture,
control_weights_locations=tuple(control_modes),
weights_config=prompt.model_weights,
for_inpainting=prompt.should_use_inpainting
and prompt.inpaint_method == "finetune",
dtype=dtype,
for_inpainting=for_inpainting and inpaint_method == "finetune",
)
seed_everything(prompt.seed)
@ -126,6 +115,14 @@ def _generate_single_image(
init_latent = None
noise_step = None
control_modes = []
control_inputs = prompt.control_inputs or []
control_inputs = control_inputs.copy()
if control_inputs:
control_modes = [c.mode for c in prompt.control_inputs]
if prompt.init_image:
starting_image = prompt.init_image
first_step = int((prompt.steps) * prompt.init_image_strength)
@ -175,7 +172,7 @@ def _generate_single_image(
pillow_mask_to_latent_mask(
mask_image, downsampling_factor=downsampling_factor
).to(get_device())
if inpaint_method == "controlnet":
if prompt.inpaint_method == "controlnet":
result_images["control-inpaint"] = mask_image
control_inputs.append(
ControlInput(mode="inpaint", image=mask_image)

@ -33,7 +33,7 @@ remove_option(edit_options, "allow_compose_phase")
"--model",
help=f"Model to use. Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.",
show_default=True,
default="SD-1.5",
default="sd15",
)
@click.option(
"--negative-prompt",

@ -133,6 +133,15 @@ def _imagine_cmd(
prompt_expanding_iterators = {}
from imaginairy.enhancers.prompt_expansion import expand_prompts
if model_weights_path.lower() not in config.MODEL_WEIGHT_CONFIG_LOOKUP:
model_weights_path = config.ModelWeightsConfig(
name="custom weights",
aliases=["custom"],
weights_location=model_weights_path,
architecture=model_architecture,
defaults={"negative_prompt": config.DEFAULT_NEGATIVE_PROMPT},
)
for _ in range(repeats):
for prompt_text in prompt_texts:
if prompt_text not in prompt_expanding_iterators:
@ -174,7 +183,6 @@ def _imagine_cmd(
tile_mode=_tile_mode,
allow_compose_phase=allow_compose_phase,
model_weights=model_weights_path,
model_architecture=model_architecture,
caption_text=caption_text,
)
from imaginairy.prompt_schedules import (

@ -26,15 +26,6 @@ class ModelArchitecture:
config_path: str | None = None
@dataclass
class ModelWeightsConfig:
name: str
aliases: List[str]
architecture: ModelArchitecture
defaults: dict[str, Any]
weights_location: str
MODEL_ARCHITECTURES = [
ModelArchitecture(
name="Stable Diffusion 1.5",
@ -107,6 +98,22 @@ for m in MODEL_ARCHITECTURES:
MODEL_ARCHITECTURE_LOOKUP[a] = m
@dataclass
class ModelWeightsConfig:
name: str
aliases: List[str]
architecture: ModelArchitecture
defaults: dict[str, Any]
weights_location: str
def __post_init__(self):
if isinstance(self.architecture, str):
self.architecture = MODEL_ARCHITECTURE_LOOKUP[self.architecture]
if not isinstance(self.architecture, ModelArchitecture):
msg = f"You must specify an architecture {self.architecture}"
raise ValueError(msg) # noqa
MODEL_WEIGHT_CONFIGS = [
ModelWeightsConfig(
name="Stable Diffusion 1.5",

@ -14,6 +14,7 @@ from huggingface_hub import (
)
from omegaconf import OmegaConf
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from safetensors.torch import load_file
from imaginairy import config as iconfig
@ -22,6 +23,7 @@ from imaginairy.modules import attention
from imaginairy.paths import PKG_ROOT
from imaginairy.utils import get_device, instantiate_from_config
from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.utils.named_resolutions import normalize_image_size
logger = logging.getLogger(__name__)
@ -224,73 +226,25 @@ def _get_diffusion_model(
def get_diffusion_model_refiners(
weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
model_architecture=None,
control_weights_locations=None,
dtype=None,
weights_config: iconfig.ModelWeightsConfig,
for_inpainting=False,
):
"""
Load a diffusion model.
Weights location may also be shortcut name, e.g. "SD-1.5"
"""
try:
return _get_diffusion_model_refiners(
weights_location,
model_architecture=model_architecture,
for_inpainting=for_inpainting,
dtype=dtype,
control_weights_locations=control_weights_locations,
)
except HuggingFaceAuthorizationError as e:
if for_inpainting:
logger.warning(
f"Failed to load inpainting model. Attempting to fall-back to standard model. {e!s}"
)
return _get_diffusion_model_refiners(
iconfig.DEFAULT_MODEL_WEIGHTS,
model_architecture=model_architecture,
dtype=dtype,
for_inpainting=False,
control_weights_locations=control_weights_locations,
)
raise
def _get_diffusion_model_refiners(
weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
model_architecture=None,
for_inpainting=False,
control_weights_locations=None,
device=None,
dtype=torch.float16,
):
"""
Load a diffusion model.
Weights location may also be shortcut name, e.g. "SD-1.5"
"""
sd = _get_diffusion_model_refiners_only(
weights_location=weights_location,
model_architecture=model_architecture,
dtype=None,
) -> LatentDiffusionModel:
"""Load a diffusion model."""
return _get_diffusion_model_refiners(
weights_location=weights_config.weights_location,
for_inpainting=for_inpainting,
device=device,
dtype=dtype,
)
return sd
@lru_cache(maxsize=1)
def _get_diffusion_model_refiners_only(
weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
model_architecture=None,
for_inpainting=False,
def _get_diffusion_model_refiners(
weights_location: str,
for_inpainting: bool = False,
device=None,
dtype=torch.float16,
):
) -> LatentDiffusionModel:
"""
Load a diffusion model.
@ -306,17 +260,11 @@ def _get_diffusion_model_refiners_only(
device = device or get_device()
model_weights_config = resolve_model_weights_config(
model_weights=weights_location,
default_model_architecture=model_architecture,
for_inpainting=for_inpainting,
)
(
vae_weights,
unet_weights,
text_encoder_weights,
) = load_stable_diffusion_compvis_weights(model_weights_config.weights_location)
) = load_stable_diffusion_compvis_weights(weights_location)
if for_inpainting:
unet = SD1UNet(in_channels=9)
@ -380,11 +328,23 @@ def load_controlnet(control_weights_location, half_mode):
def resolve_model_weights_config(
model_weights: str,
model_weights: str | iconfig.ModelWeightsConfig,
default_model_architecture: str | None = None,
for_inpainting: bool = False,
) -> iconfig.ModelWeightsConfig:
"""Resolve weight and config path if they happen to be shortcuts."""
if isinstance(model_weights, iconfig.ModelWeightsConfig):
return model_weights
if not isinstance(model_weights, str):
msg = f"Invalid model weights: {model_weights}"
raise ValueError(msg) # noqa
if default_model_architecture is not None and not isinstance(
default_model_architecture, str
):
msg = f"Invalid model architecture: {default_model_architecture}"
raise ValueError(msg)
if for_inpainting:
model_weights_config = iconfig.MODEL_WEIGHT_CONFIG_LOOKUP.get(
@ -441,6 +401,7 @@ def get_model_default_image_size(model_architecture: str | ModelArchitecture):
if default_size is None:
default_size = 512
default_size = normalize_image_size(default_size)
return default_size
@ -648,7 +609,6 @@ def open_weights(filepath, device=None):
return state_dict
@lru_cache
def load_stable_diffusion_compvis_weights(weights_url):
from imaginairy.model_manager import get_cached_url_path
from imaginairy.utils import get_device

@ -9,7 +9,7 @@ import random
from datetime import datetime, timezone
from enum import Enum
from io import BytesIO
from typing import TYPE_CHECKING, Any, List, cast
from typing import TYPE_CHECKING, Any, List, Literal, cast
from pydantic import (
BaseModel,
@ -250,6 +250,7 @@ class MaskMode(str, Enum):
MaskInput = MaskMode | str
PromptInput = str | WeightedPrompt | list[WeightedPrompt] | list[str] | None
InpaintMethod = Literal["finetune", "control"]
class ImaginePrompt(BaseModel, protected_namespaces=()):
@ -278,8 +279,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
mask_mode: MaskMode = MaskMode.REPLACE
mask_modify_original: bool = True
outpaint: str | None = ""
model_architecture: str | None = None
model_weights: str = Field(
model_weights: config.ModelWeightsConfig = Field(
default=config.DEFAULT_MODEL_WEIGHTS, validate_default=True
)
solver_type: str = Field(default=config.DEFAULT_SOLVER, validate_default=True)
@ -297,6 +297,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
caption_text: str = Field(
"", description="text to be overlaid on the image", validate_default=True
)
inpaint_method: InpaintMethod = "finetune"
def __init__(
self,
@ -312,8 +313,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
mask_mode: MaskInput = MaskMode.REPLACE,
mask_modify_original: bool = True,
outpaint: str | None = "",
model_architecture: str | None = None,
model_weights: str = config.DEFAULT_MODEL_WEIGHTS,
model_weights: str | config.ModelWeightsConfig = config.DEFAULT_MODEL_WEIGHTS,
solver_type: str = config.DEFAULT_SOLVER,
seed: int | None = None,
steps: int | None = None,
@ -327,6 +327,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
is_intermediate: bool = False,
collect_progress_latents: bool = False,
caption_text: str = "",
inpaint_method: InpaintMethod = "finetune",
):
super().__init__(
prompt=prompt,
@ -340,7 +341,6 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
mask_mode=mask_mode,
mask_modify_original=mask_modify_original,
outpaint=outpaint,
model_architecture=model_architecture,
model_weights=model_weights,
solver_type=solver_type,
seed=seed,
@ -355,6 +355,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
is_intermediate=is_intermediate,
collect_progress_latents=collect_progress_latents,
caption_text=caption_text,
inpaint_method=inpaint_method,
)
@field_validator("prompt", "negative_prompt", mode="before")
@ -398,12 +399,9 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
@model_validator(mode="after")
def validate_negative_prompt(self):
if self.negative_prompt == []:
model_weight_config = config.MODEL_WEIGHT_CONFIG_LOOKUP.get(
self.model_weights, None
)
default_negative_prompt = config.DEFAULT_NEGATIVE_PROMPT
if model_weight_config:
default_negative_prompt = model_weight_config.defaults.get(
if self.model_weights:
default_negative_prompt = self.model_weights.defaults.get(
"negative_prompt", default_negative_prompt
)
@ -496,12 +494,30 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
raise ValueError(msg)
return v
@field_validator("model_weights", mode="before")
def set_default_diffusion_model(cls, v):
if v is None:
return config.DEFAULT_MODEL_WEIGHTS
@model_validator(mode="before")
def resolve_model_weights(cls, data: Any):
if not isinstance(data, dict):
return data
return v
model_weights = data.get("model_weights")
if model_weights is None:
model_weights = config.DEFAULT_MODEL_WEIGHTS
from imaginairy.model_manager import resolve_model_weights_config
should_use_inpainting = (
data.get("mask_image") or data.get("mask_prompt") or data.get("outpaint")
)
should_use_inpainting_weights = (
should_use_inpainting and data.get("inpaint_method") == "finetune"
)
model_weights_config = resolve_model_weights_config(
model_weights=model_weights,
default_model_architecture=None,
for_inpainting=should_use_inpainting_weights,
)
data["model_weights"] = model_weights_config
return data
@field_validator("seed")
def validate_seed(cls, v):
@ -564,9 +580,15 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
from imaginairy.utils.named_resolutions import normalize_image_size
if v is None:
v = get_model_default_image_size(info.data["model_architecture"])
v = get_model_default_image_size(info.data["model_weights"].architecture)
width, height = normalize_image_size(v)
return width, height
@field_validator("size", mode="after")
def validate_image_size_after(cls, v, info: core_schema.FieldValidationInfo):
width, height = v
min_size = 8
max_size = 100_000
if not min_size <= width <= max_size:
@ -576,8 +598,7 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
if not min_size <= height <= max_size:
msg = f"Height must be between {min_size} and {max_size}. Got: {height}"
raise ValueError(msg)
return width, height
return v
@field_validator("caption_text", mode="before")
def validate_caption_text(cls, v):
@ -614,6 +635,18 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
def height(self) -> int:
return self.size[1]
@property
def should_use_inpainting(self) -> bool:
return bool(self.outpaint or self.mask_image or self.mask_prompt)
@property
def should_use_inpainting_weights(self) -> bool:
return self.should_use_inpainting and self.inpaint_method == "finetune"
@property
def model_architecture(self) -> config.ModelArchitecture:
return self.model_weights.architecture
def prompt_description(self):
return (
f'"{self.prompt_text}" {self.width}x{self.height}px '
@ -622,8 +655,8 @@ class ImaginePrompt(BaseModel, protected_namespaces=()):
f"prompt-strength:{self.prompt_strength} "
f"steps:{self.steps} solver-type:{self.solver_type} "
f"init-image-strength:{self.init_image_strength} "
f"arch:{self.model_architecture} "
f"weights: {self.model_weights}"
f"arch:{self.model_architecture.aliases[0]} "
f"weights: {self.model_weights.aliases[0]}"
)
def logging_dict(self):
@ -730,8 +763,8 @@ class ImagineResult:
exif[ExifCodes.ImageDescription] = self.prompt.prompt_description()
exif[ExifCodes.UserComment] = json.dumps(self.metadata_dict())
# help future web scrapes not ingest AI generated art
sd_version = self.prompt.model_weights
if len(sd_version) > 20:
sd_version = self.prompt.model_weights.name
if len(sd_version) > 40:
sd_version = "custom weights"
exif[ExifCodes.Software] = f"Imaginairy / Stable Diffusion {sd_version}"
exif[ExifCodes.DateTime] = self.created_at.isoformat(sep=" ")[:19]

@ -55,11 +55,11 @@ def test_model_versions(filename_base_for_orig_outputs, model_version):
threshold = 35000
results = list(imagine(prompts))
for i, result in enumerate(results):
img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model_weights}.png"
img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model_weights.aliases[0]}.png"
result.img.save(img_path)
for i, result in enumerate(results):
img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model_weights}.png"
img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model_weights.aliases[0]}.png"
assert_image_similar_to_expectation(
result.img, img_path=img_path, threshold=threshold
)

@ -22,9 +22,20 @@ def test_imagine_prompt_default():
prompt = ImaginePrompt(negative_prompt="")
assert prompt.negative_prompt == [WeightedPrompt(text="")]
assert prompt.width == 512
def test_imagine_prompt_has_default_negative():
prompt = ImaginePrompt("fruit salad", model_weights="foobar")
prompt = ImaginePrompt(
"fruit salad",
model_weights=config.ModelWeightsConfig(
name="foobar",
aliases=["foobar"],
weights_location="foobar",
architecture="sd15",
defaults={},
),
)
assert isinstance(prompt.prompt[0], WeightedPrompt)
assert isinstance(prompt.negative_prompt[0], WeightedPrompt)
@ -153,7 +164,7 @@ def test_imagine_prompt_mask_params():
def test_imagine_prompt_default_model():
prompt = ImaginePrompt("fruit", model_weights=None)
assert prompt.model_weights == config.DEFAULT_MODEL_WEIGHTS
assert config.DEFAULT_MODEL_WEIGHTS in prompt.model_weights.aliases
def test_imagine_prompt_default_negative():

@ -4,10 +4,10 @@ norecursedirs = build dist downloads other prolly_delete imaginairy/vendored
filterwarnings =
ignore::DeprecationWarning
ignore::UserWarning
markers =
gputest: uses the gpu
[mypy]
plugins = pydantic.mypy
exclude = ^(downloads|dist|other|testing_support|imaginairy/vendored|imaginairy/modules/sgm)/.*
ignore_missing_imports = True
markers =
gputest: uses the gpu
Loading…
Cancel
Save