diff --git a/imaginairy/api/generate_refiners.py b/imaginairy/api/generate_refiners.py
index 9c68247..7fd2645 100644
--- a/imaginairy/api/generate_refiners.py
+++ b/imaginairy/api/generate_refiners.py
@@ -250,7 +250,10 @@ def generate_single_image(
             comp_image_t = comp_image_t.to(sd.lda.device, dtype=sd.lda.dtype)
             init_latent = sd.lda.encode(comp_image_t)
             compose_control_inputs: list[ControlInput]
-            if prompt.model_weights.architecture.primary_alias == "sdxl":
+            if prompt.model_weights.architecture.primary_alias in (
+                "sdxl",
+                "sdxlinpaint",
+            ):
                 compose_control_inputs = []
             else:
                 compose_control_inputs = [
diff --git a/imaginairy/config.py b/imaginairy/config.py
index cf5a91b..eb945db 100644
--- a/imaginairy/config.py
+++ b/imaginairy/config.py
@@ -58,6 +58,12 @@ MODEL_ARCHITECTURES = [
         output_modality="image",
         defaults={"size": "1024"},
     ),
+    ModelArchitecture(
+        name="Stable Diffusion XL - Inpainting",
+        aliases=["sdxlinpaint", "sd-xlinpaint", "sdxl-inpaint"],
+        output_modality="image",
+        defaults={"size": "1024"},
+    ),
     ModelArchitecture(
         name="Stable Video Diffusion",
         aliases=["svd", "stablevideo"],
@@ -162,7 +168,7 @@ MODEL_WEIGHT_CONFIGS = [
         defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
     ),
     ModelWeightsConfig(
-        name="Modern Disney",
+        name="Redshift Diffusion",
         aliases=["redshift-diffusion", "red", "redshift-diffusion-15", "red15"],
         architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
         weights_location="https://huggingface.co/nitrosocke/redshift-diffusion/tree/80837fe18df05807861ab91c3bad3693c9342e4c/",
@@ -179,6 +185,16 @@ MODEL_WEIGHT_CONFIGS = [
         },
         weights_location="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/462165984030d82259a11f4367a4eed129e94a7b/",
     ),
+    ModelWeightsConfig(
+        name="Stable Diffusion XL - Inpainting",
+        aliases=MODEL_ARCHITECTURE_LOOKUP["sdxl-inpaint"].aliases,
+        architecture=MODEL_ARCHITECTURE_LOOKUP["sdxl-inpaint"],
+        defaults={
+            "negative_prompt": DEFAULT_NEGATIVE_PROMPT,
+            "composition_strength": 0.6,
+        },
+        weights_location="https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1/tree/115134f363124c53c7d878647567d04daf26e41e/",
+    ),
     ModelWeightsConfig(
         name="OpenDalle V1.1",
         aliases=["opendalle11", "odv11", "opendalle11", "opendalle", "od"],
diff --git a/imaginairy/modules/refiners_sd.py b/imaginairy/modules/refiners_sd.py
index 62e483d..e758073 100644
--- a/imaginairy/modules/refiners_sd.py
+++ b/imaginairy/modules/refiners_sd.py
@@ -5,7 +5,9 @@
 import math
 from functools import lru_cache
 from typing import Any, List, Literal
 
+import numpy as np
 import torch
+from PIL import Image
 from torch import Tensor, device as Device, dtype as DType, nn
 from torch.nn import functional as F
@@ -16,6 +18,7 @@ from imaginairy.vendored.refiners.fluxion.layers.attentions import (
     ScaledDotProductAttention,
 )
 from imaginairy.vendored.refiners.fluxion.layers.chain import ChainError
+from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpolate
 from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
     CLIPTextEncoderL,
 )
@@ -395,6 +398,131 @@ class StableDiffusion_1_Inpainting(TileModeMixin, RefinerStableDiffusion_1_Inpai
         return conditioning
 
 
+class StableDiffusion_XL_Inpainting(StableDiffusion_XL):
+    def __init__(
+        self,
+        unet: SDXLUNet | None = None,
+        lda: SDXLAutoencoder | None = None,
+        clip_text_encoder: DoubleTextEncoder | None = None,
+        scheduler: Scheduler | None = None,
+        device: Device | str | None = "cpu",
+        dtype: DType | None = None,
+    ) -> None:
+        self.mask_latents: Tensor | None = None
+        self.target_image_latents: Tensor | None = None
+        super().__init__(
+            unet=unet,
+            lda=lda,
+            clip_text_encoder=clip_text_encoder,
+            scheduler=scheduler,
+            device=device,
+            dtype=dtype,
+        )
+
+    def forward(
+        self,
+        x: Tensor,
+        step: int,
+        *,
+        clip_text_embedding: Tensor,
+        pooled_text_embedding: Tensor,
+        time_ids: Tensor | None = None,
+        condition_scale: float = 5.0,
+        **_: Tensor,
+    ) -> Tensor:
+        assert self.mask_latents is not None
+        assert self.target_image_latents is not None
+        x = torch.cat(tensors=(x, self.mask_latents, self.target_image_latents), dim=1)
+        return super().forward(
+            x=x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            pooled_text_embedding=pooled_text_embedding,
+            time_ids=time_ids,
+            condition_scale=condition_scale,
+        )
+
+    def set_inpainting_conditions(
+        self,
+        target_image: Image.Image,
+        mask: Image.Image,
+        latents_size: tuple[int, int] = (64, 64),
+    ) -> tuple[Tensor, Tensor]:
+        target_image = target_image.convert(mode="RGB")
+        mask = mask.convert(mode="L")
+
+        mask_tensor = torch.tensor(
+            data=np.array(object=mask).astype(dtype=np.float32) / 255.0
+        ).to(device=self.device)
+        mask_tensor = (
+            (mask_tensor > 0.5)
+            .unsqueeze(dim=0)
+            .unsqueeze(dim=0)
+            .to(dtype=self.unet.dtype)
+        )
+
+        self.mask_latents = interpolate(x=mask_tensor, factor=torch.Size(latents_size))
+
+        init_image_tensor = (
+            image_to_tensor(
+                image=target_image, device=self.device, dtype=self.unet.dtype
+            )
+            * 2
+            - 1
+        )
+        masked_init_image = init_image_tensor * (1 - mask_tensor)
+        self.target_image_latents = self.lda.encode(
+            x=masked_init_image.to(dtype=self.lda.dtype)
+        )
+        assert self.target_image_latents is not None
+        self.target_image_latents = self.target_image_latents.to(dtype=self.unet.dtype)
+
+        return self.mask_latents, self.target_image_latents  # type: ignore
+
+    def compute_self_attention_guidance(
+        self,
+        x: Tensor,
+        noise: Tensor,
+        step: int,
+        *,
+        clip_text_embedding: Tensor,
+        pooled_text_embedding: Tensor,
+        time_ids: Tensor,
+        **kwargs: Tensor,
+    ) -> Tensor:
+        sag = self._find_sag_adapter()
+        assert sag is not None
+        assert self.mask_latents is not None
+        assert self.target_image_latents is not None
+
+        degraded_latents = sag.compute_degraded_latents(
+            scheduler=self.scheduler,
+            latents=x,
+            noise=noise,
+            step=step,
+            classifier_free_guidance=True,
+        )
+
+        negative_embedding, _ = clip_text_embedding.chunk(2)
+        negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
+        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        time_ids, _ = time_ids.chunk(2)
+        self.set_unet_context(
+            timestep=timestep,
+            clip_text_embedding=negative_embedding,
+            pooled_text_embedding=negative_pooled_embedding,
+            time_ids=time_ids,
+            **kwargs,
+        )
+        x = torch.cat(
+            tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
+            dim=1,
+        )
+        degraded_noise = self.unet(x)
+
+        return sag.scale * (noise - degraded_noise)
+
+
 class SlicedEncoderMixin(nn.Module):
     max_chunk_size = 2048
     min_chunk_size = 32
diff --git a/imaginairy/utils/model_manager.py b/imaginairy/utils/model_manager.py
index e926d36..4c9dd4a 100644
--- a/imaginairy/utils/model_manager.py
+++ b/imaginairy/utils/model_manager.py
@@ -21,7 +21,11 @@ from safetensors.torch import load_file
 from imaginairy import config as iconfig
 from imaginairy.config import IMAGE_WEIGHTS_SHORT_NAMES, ModelArchitecture
 from imaginairy.modules import attention
-from imaginairy.modules.refiners_sd import SDXLAutoencoderSliced, StableDiffusion_XL
+from imaginairy.modules.refiners_sd import (
+    SDXLAutoencoderSliced,
+    StableDiffusion_XL,
+    StableDiffusion_XL_Inpainting,
+)
 from imaginairy.utils import clear_gpu_cache, get_device, instantiate_from_config
 from imaginairy.utils.model_cache import memory_managed_model
 from imaginairy.utils.named_resolutions import normalize_image_size
@@ -268,8 +272,10 @@ def _get_diffusion_model_refiners(
             device=device,
             dtype=dtype,
         )
-    elif architecture.primary_alias == "sdxl":
-        sd = load_sdxl_pipeline(base_url=weights_location, device=device)
+    elif architecture.primary_alias in ("sdxl", "sdxlinpaint"):
+        sd = load_sdxl_pipeline(
+            base_url=weights_location, device=device, for_inpainting=for_inpainting
+        )
     else:
         msg = f"Invalid architecture {architecture.primary_alias}"
         raise ValueError(msg)
@@ -734,7 +740,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None):
 
 
 def load_sdxl_pipeline_from_diffusers_weights(
-    base_url: str, device=None, dtype=torch.float16
+    base_url: str, for_inpainting=False, device=None, dtype=torch.float16
 ):
     from imaginairy.utils import get_device
 
@@ -764,7 +770,10 @@ def load_sdxl_pipeline_from_diffusers_weights(
         source_path=unet_weights_path,
         device="cpu",
     )
-    unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
+    if for_inpainting:
+        unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=9)
+    else:
+        unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
     unet.load_state_dict(unet_weights, assign=True)
     del unet_weights
 
@@ -789,7 +798,12 @@ def load_sdxl_pipeline_from_diffusers_weights(
     lda = lda.to(device=device, dtype=torch.float32)
     unet = unet.to(device=device)
     text_encoder = text_encoder.to(device=device)
-    sd = StableDiffusion_XL(
+    if for_inpainting:
+        StableDiffusionCls = StableDiffusion_XL_Inpainting
+    else:
+        StableDiffusionCls = StableDiffusion_XL
+
+    sd = StableDiffusionCls(
         device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
     )
 
@@ -797,7 +811,7 @@
 
 
 def load_sdxl_pipeline_from_compvis_weights(
-    base_url: str, device=None, dtype=torch.float16
+    base_url: str, for_inpainting=False, device=None, dtype=torch.float16
 ):
     from imaginairy.utils import get_device
 
@@ -809,7 +823,10 @@ def load_sdxl_pipeline_from_compvis_weights(
 
     lda.load_state_dict(vae_weights, assign=True)
     del vae_weights
-    unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
+    if for_inpainting:
+        unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=9)
+    else:
+        unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
     unet.load_state_dict(unet_weights, assign=True)
     del unet_weights
 
@@ -819,21 +836,31 @@ def load_sdxl_pipeline_from_compvis_weights(
     lda = lda.to(device=device, dtype=torch.float32)
     unet = unet.to(device=device)
     text_encoder = text_encoder.to(device=device)
-    sd = StableDiffusion_XL(
+
+    if for_inpainting:
+        StableDiffusionCls = StableDiffusion_XL_Inpainting
+    else:
+        StableDiffusionCls = StableDiffusion_XL
+    sd = StableDiffusionCls(
         device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
     )
 
     return sd
 
 
-def load_sdxl_pipeline(base_url, device=None):
+def load_sdxl_pipeline(base_url, device=None, for_inpainting=False):
+    logger.info(f"Loading SDXL weights from {base_url}")
     device = device or get_device()
 
     with logger.timed_info(f"Loaded SDXL pipeline from {base_url}"):
         if is_diffusers_repo_url(base_url):
-            sd = load_sdxl_pipeline_from_diffusers_weights(base_url, device=device)
+            sd = load_sdxl_pipeline_from_diffusers_weights(
+                base_url, for_inpainting=for_inpainting, device=device
+            )
         else:
-            sd = load_sdxl_pipeline_from_compvis_weights(base_url, device=device)
+            sd = load_sdxl_pipeline_from_compvis_weights(
+                base_url, for_inpainting=for_inpainting, device=device
+            )
 
     return sd
 
diff --git a/imaginairy/weight_management/weight_maps/Transformers-ClipTextEncoder.weightmap.json b/imaginairy/weight_management/weight_maps/Transformers-ClipTextEncoder.weightmap.json
index 3962c9d..d9e23ab 100644
--- a/imaginairy/weight_management/weight_maps/Transformers-ClipTextEncoder.weightmap.json
+++ b/imaginairy/weight_management/weight_maps/Transformers-ClipTextEncoder.weightmap.json
@@ -3,7 +3,8 @@
     "text_model.embeddings.token_embedding": "Sum.TokenEncoder",
     "text_model.embeddings.position_embedding": "Sum.PositionalEncoder.Embedding",
     "text_model.final_layer_norm": "LayerNorm",
-    "text_projection": "Linear"
+    "text_projection": "Linear",
+    "text_model.embeddings.position_ids": null
   },
   "regex_map": {
     "text_model\\.encoder\\.layers\\.(?P<layer>\\d+)\\.layer_norm(?P<norm>\\d+)": "TransformerLayer_{int(layer) + 1}.Residual_{norm}.LayerNorm",
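
Note on the `in_channels=9` change above: the SDXL inpainting UNet consumes the noised image latents (4 channels) concatenated with a single-channel downsampled mask and the VAE encoding of the masked source image (4 channels), which is exactly the concatenation performed in `StableDiffusion_XL_Inpainting.forward`. A minimal shape-only sketch with dummy tensors, assuming 1024x1024 inputs and the VAE's 8x downsampling:

```python
import torch

# Dummy tensors mirroring the concat in StableDiffusion_XL_Inpainting.forward.
# SDXL latents have 4 channels at 1/8 resolution, so 1024x1024 -> 128x128.
latents = torch.randn(1, 4, 128, 128)                # noised image latents
mask_latents = torch.rand(1, 1, 128, 128)            # mask resized to latent size
masked_image_latents = torch.randn(1, 4, 128, 128)   # lda.encode() of masked image

unet_input = torch.cat((latents, mask_latents, masked_image_latents), dim=1)
assert unet_input.shape[1] == 9  # matches SDXLUNet(..., in_channels=9)
```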
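
For manual verification, a sketch of how the new weights can be exercised end to end, assuming imaginAIry's existing `imagine`/`ImaginePrompt`/`LazyLoadingImage` API; the input path, mask prompt, and output filename below are hypothetical:

```python
from imaginairy.api import imagine
from imaginairy.schema import ImaginePrompt, LazyLoadingImage

# "sdxlinpaint" resolves through the new alias registered in config.py above.
prompt = ImaginePrompt(
    "a red leather couch",
    model_weights="sdxlinpaint",
    init_image=LazyLoadingImage(filepath="living_room.jpg"),  # hypothetical path
    mask_prompt="couch",  # region selected via CLIP-based masking
    mask_mode="replace",  # repaint the masked region
)

for result in imagine([prompt]):
    result.save("living_room_red_couch.jpg")
```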