feature: sdxl inpaint support (#450)

pull/453/head
Bryce Drennan 4 months ago committed by GitHub
parent 700cb457b9
commit 502ffbdc63
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -250,7 +250,10 @@ def generate_single_image(
comp_image_t = comp_image_t.to(sd.lda.device, dtype=sd.lda.dtype)
init_latent = sd.lda.encode(comp_image_t)
compose_control_inputs: list[ControlInput]
if prompt.model_weights.architecture.primary_alias == "sdxl":
if prompt.model_weights.architecture.primary_alias in (
"sdxl",
"sdxlinpaint",
):
compose_control_inputs = []
else:
compose_control_inputs = [

@ -58,6 +58,12 @@ MODEL_ARCHITECTURES = [
output_modality="image",
defaults={"size": "1024"},
),
ModelArchitecture(
name="Stable Diffusion XL - Inpainting",
aliases=["sdxlinpaint", "sd-xlinpaint", "sdxl-inpaint"],
output_modality="image",
defaults={"size": "1024"},
),
ModelArchitecture(
name="Stable Video Diffusion",
aliases=["svd", "stablevideo"],
@ -162,7 +168,7 @@ MODEL_WEIGHT_CONFIGS = [
defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
),
ModelWeightsConfig(
name="Modern Disney",
name="Redshift Diffusion",
aliases=["redshift-diffusion", "red", "redshift-diffusion-15", "red15"],
architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
weights_location="https://huggingface.co/nitrosocke/redshift-diffusion/tree/80837fe18df05807861ab91c3bad3693c9342e4c/",
@ -179,6 +185,16 @@ MODEL_WEIGHT_CONFIGS = [
},
weights_location="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/462165984030d82259a11f4367a4eed129e94a7b/",
),
ModelWeightsConfig(
name="Stable Diffusion XL - Inpainting",
aliases=MODEL_ARCHITECTURE_LOOKUP["sdxl-inpaint"].aliases,
architecture=MODEL_ARCHITECTURE_LOOKUP["sdxl-inpaint"],
defaults={
"negative_prompt": DEFAULT_NEGATIVE_PROMPT,
"composition_strength": 0.6,
},
weights_location="https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1/tree/115134f363124c53c7d878647567d04daf26e41e/",
),
ModelWeightsConfig(
name="OpenDalle V1.1",
aliases=["opendalle11", "odv11", "opendalle11", "opendalle", "od"],

@ -5,7 +5,9 @@ import math
from functools import lru_cache
from typing import Any, List, Literal
import numpy as np
import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType, nn
from torch.nn import functional as F
@ -16,6 +18,7 @@ from imaginairy.vendored.refiners.fluxion.layers.attentions import (
ScaledDotProductAttention,
)
from imaginairy.vendored.refiners.fluxion.layers.chain import ChainError
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpolate
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
CLIPTextEncoderL,
)
@ -395,6 +398,131 @@ class StableDiffusion_1_Inpainting(TileModeMixin, RefinerStableDiffusion_1_Inpai
return conditioning
class StableDiffusion_XL_Inpainting(StableDiffusion_XL):
    """SDXL pipeline whose UNet input is extended with inpainting conditions.

    Every UNet call receives the latents concatenated, along ``dim=1``, with
    a mask and with the encoded masked target image. Both are produced and
    stored by :meth:`set_inpainting_conditions`, which must therefore be
    called before :meth:`forward` or :meth:`compute_self_attention_guidance`.
    """

    def __init__(
        self,
        unet: SDXLUNet | None = None,
        lda: SDXLAutoencoder | None = None,
        clip_text_encoder: DoubleTextEncoder | None = None,
        scheduler: Scheduler | None = None,
        device: Device | str | None = "cpu",
        dtype: DType | None = None,
    ) -> None:
        """Create the pipeline; all arguments are forwarded to the parent class."""
        # Both stay ``None`` until set_inpainting_conditions() is called.
        # They are assigned before super().__init__(), as in the original
        # ordering.
        self.mask_latents: Tensor | None = None
        self.target_image_latents: Tensor | None = None
        super().__init__(
            unet=unet,
            lda=lda,
            clip_text_encoder=clip_text_encoder,
            scheduler=scheduler,
            device=device,
            dtype=dtype,
        )

    def _require_inpainting_conditions(self) -> tuple[Tensor, Tensor]:
        """Return ``(mask_latents, target_image_latents)`` or raise if unset.

        An explicit exception is used instead of ``assert`` so the check
        still runs under ``python -O`` and reports what the caller forgot,
        rather than failing later inside ``torch.cat`` on a ``None`` value.

        Raises:
            RuntimeError: if set_inpainting_conditions() has not been called.
        """
        if self.mask_latents is None or self.target_image_latents is None:
            msg = (
                "Inpainting conditions are not set; call "
                "set_inpainting_conditions() before running the model."
            )
            raise RuntimeError(msg)
        return self.mask_latents, self.target_image_latents

    def forward(
        self,
        x: Tensor,
        step: int,
        *,
        clip_text_embedding: Tensor,
        pooled_text_embedding: Tensor,
        time_ids: Tensor | None = None,
        condition_scale: float = 5.0,
        **_: Tensor,
    ) -> Tensor:
        """Run the parent ``forward`` on latents extended with the conditions.

        ``x`` is concatenated with the stored mask and masked-image latents
        along ``dim=1`` and passed, with the remaining named arguments, to
        ``super().forward``. Extra keyword arguments are accepted and
        ignored.

        Raises:
            RuntimeError: if set_inpainting_conditions() has not been called.
        """
        mask_latents, target_image_latents = self._require_inpainting_conditions()
        x = torch.cat(tensors=(x, mask_latents, target_image_latents), dim=1)
        return super().forward(
            x=x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=condition_scale,
        )

    def set_inpainting_conditions(
        self,
        target_image: Image.Image,
        mask: Image.Image,
        latents_size: tuple[int, int] = (64, 64),
    ) -> tuple[Tensor, Tensor]:
        """Compute and store the mask and masked-image latents.

        Args:
            target_image: image to inpaint; converted to RGB.
            mask: inpainting mask; converted to single-channel ``L`` mode and
                binarized at 0.5 after scaling pixel values by 1/255.
            latents_size: size the mask is interpolated to.
                NOTE(review): the (64, 64) default may not match the latent
                size of SDXL-resolution images; confirm that callers pass
                the actual latent size.

        Returns:
            ``(mask_latents, target_image_latents)``, the same tensors that
            are stored on the instance.
        """
        target_image = target_image.convert(mode="RGB")
        mask = mask.convert(mode="L")

        mask_tensor = torch.tensor(
            data=np.array(object=mask).astype(dtype=np.float32) / 255.0
        ).to(device=self.device)
        # Binarize, then add two leading dimensions before interpolation.
        mask_tensor = (
            (mask_tensor > 0.5)
            .unsqueeze(dim=0)
            .unsqueeze(dim=0)
            .to(dtype=self.unet.dtype)
        )
        mask_latents = interpolate(x=mask_tensor, factor=torch.Size(latents_size))

        # ``* 2 - 1`` rescales the tensor; presumably image_to_tensor yields
        # values in [0, 1], giving [-1, 1] here (confirm against the helper).
        init_image_tensor = (
            image_to_tensor(
                image=target_image, device=self.device, dtype=self.unet.dtype
            )
            * 2
            - 1
        )
        # Pixels where the mask is 1 are zeroed before encoding.
        masked_init_image = init_image_tensor * (1 - mask_tensor)
        target_image_latents = self.lda.encode(
            x=masked_init_image.to(dtype=self.lda.dtype)
        ).to(dtype=self.unet.dtype)

        # Store both only after every step succeeded, so a failure above
        # never leaves the instance with a half-updated pair of conditions.
        self.mask_latents = mask_latents
        self.target_image_latents = target_image_latents
        return mask_latents, target_image_latents

    def compute_self_attention_guidance(
        self,
        x: Tensor,
        noise: Tensor,
        step: int,
        *,
        clip_text_embedding: Tensor,
        pooled_text_embedding: Tensor,
        time_ids: Tensor,
        **kwargs: Tensor,
    ) -> Tensor:
        """Return the self-attention-guidance term for the current step.

        The SAG adapter computes degraded latents. The UNet is then run on
        them, extended with the inpainting conditions, using the first half
        of each embedding/time-id batch as context. The result is
        ``sag.scale * (noise - degraded_noise)``.

        Raises:
            RuntimeError: if no SAG adapter is found or the inpainting
                conditions have not been set.
        """
        sag = self._find_sag_adapter()
        if sag is None:
            msg = "Self-attention guidance requested but no SAG adapter was found."
            raise RuntimeError(msg)
        mask_latents, target_image_latents = self._require_inpainting_conditions()

        degraded_latents = sag.compute_degraded_latents(
            scheduler=self.scheduler,
            latents=x,
            noise=noise,
            step=step,
            classifier_free_guidance=True,
        )

        # Only the first chunk of each batch is kept; the variable names
        # indicate it is the negative (unconditional) half.
        negative_embedding, _ = clip_text_embedding.chunk(2)
        negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
        time_ids, _ = time_ids.chunk(2)

        self.set_unet_context(
            timestep=timestep,
            clip_text_embedding=negative_embedding,
            pooled_text_embedding=negative_pooled_embedding,
            time_ids=time_ids,
            **kwargs,
        )
        x = torch.cat(
            tensors=(degraded_latents, mask_latents, target_image_latents),
            dim=1,
        )
        degraded_noise = self.unet(x)

        return sag.scale * (noise - degraded_noise)
class SlicedEncoderMixin(nn.Module):
max_chunk_size = 2048
min_chunk_size = 32

@ -21,7 +21,11 @@ from safetensors.torch import load_file
from imaginairy import config as iconfig
from imaginairy.config import IMAGE_WEIGHTS_SHORT_NAMES, ModelArchitecture
from imaginairy.modules import attention
from imaginairy.modules.refiners_sd import SDXLAutoencoderSliced, StableDiffusion_XL
from imaginairy.modules.refiners_sd import (
SDXLAutoencoderSliced,
StableDiffusion_XL,
StableDiffusion_XL_Inpainting,
)
from imaginairy.utils import clear_gpu_cache, get_device, instantiate_from_config
from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.utils.named_resolutions import normalize_image_size
@ -268,8 +272,10 @@ def _get_diffusion_model_refiners(
device=device,
dtype=dtype,
)
elif architecture.primary_alias == "sdxl":
sd = load_sdxl_pipeline(base_url=weights_location, device=device)
elif architecture.primary_alias in ("sdxl", "sdxlinpaint"):
sd = load_sdxl_pipeline(
base_url=weights_location, device=device, for_inpainting=for_inpainting
)
else:
msg = f"Invalid architecture {architecture.primary_alias}"
raise ValueError(msg)
@ -734,7 +740,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None):
def load_sdxl_pipeline_from_diffusers_weights(
base_url: str, device=None, dtype=torch.float16
base_url: str, for_inpainting=False, device=None, dtype=torch.float16
):
from imaginairy.utils import get_device
@ -764,7 +770,10 @@ def load_sdxl_pipeline_from_diffusers_weights(
source_path=unet_weights_path,
device="cpu",
)
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
if for_inpainting:
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=9)
else:
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
unet.load_state_dict(unet_weights, assign=True)
del unet_weights
@ -789,7 +798,12 @@ def load_sdxl_pipeline_from_diffusers_weights(
lda = lda.to(device=device, dtype=torch.float32)
unet = unet.to(device=device)
text_encoder = text_encoder.to(device=device)
sd = StableDiffusion_XL(
if for_inpainting:
StableDiffusionCls = StableDiffusion_XL_Inpainting
else:
StableDiffusionCls = StableDiffusion_XL
sd = StableDiffusionCls(
device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
)
@ -797,7 +811,7 @@ def load_sdxl_pipeline_from_diffusers_weights(
def load_sdxl_pipeline_from_compvis_weights(
base_url: str, device=None, dtype=torch.float16
base_url: str, for_inpainting=False, device=None, dtype=torch.float16
):
from imaginairy.utils import get_device
@ -809,7 +823,10 @@ def load_sdxl_pipeline_from_compvis_weights(
lda.load_state_dict(vae_weights, assign=True)
del vae_weights
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
if for_inpainting:
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=9)
else:
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
unet.load_state_dict(unet_weights, assign=True)
del unet_weights
@ -819,21 +836,31 @@ def load_sdxl_pipeline_from_compvis_weights(
lda = lda.to(device=device, dtype=torch.float32)
unet = unet.to(device=device)
text_encoder = text_encoder.to(device=device)
sd = StableDiffusion_XL(
if for_inpainting:
StableDiffusionCls = StableDiffusion_XL_Inpainting
else:
StableDiffusionCls = StableDiffusion_XL
sd = StableDiffusionCls(
device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
)
return sd
def load_sdxl_pipeline(base_url, device=None, for_inpainting=False):
    """Load an SDXL pipeline from ``base_url``.

    Args:
        base_url: location of the weights. When ``is_diffusers_repo_url``
            accepts it the diffusers loader is used, otherwise the compvis
            loader.
        device: target device; falls back to ``get_device()`` when falsy.
        for_inpainting: forwarded to the selected loader to request the
            inpainting variant of the pipeline.

    Returns:
        The pipeline object returned by the selected loader.
    """
    logger.info(f"Loading SDXL weights from {base_url}")
    device = device or get_device()
    with logger.timed_info(f"Loaded SDXL pipeline from {base_url}"):
        # Pick the loader once and call it a single time, always passing
        # ``for_inpainting`` so the weights are never loaded twice.
        if is_diffusers_repo_url(base_url):
            loader = load_sdxl_pipeline_from_diffusers_weights
        else:
            loader = load_sdxl_pipeline_from_compvis_weights
        sd = loader(base_url, for_inpainting=for_inpainting, device=device)
    return sd

@ -3,7 +3,8 @@
"text_model.embeddings.token_embedding": "Sum.TokenEncoder",
"text_model.embeddings.position_embedding": "Sum.PositionalEncoder.Embedding",
"text_model.final_layer_norm": "LayerNorm",
"text_projection": "Linear"
"text_projection": "Linear",
"text_model.embeddings.position_ids": null
},
"regex_map": {
"text_model\\.encoder\\.layers\\.(?P<layer>\\d+)\\.layer_norm(?P<norm>\\d+)": "TransformerLayer_{int(layer) + 1}.Residual_{norm}.LayerNorm",

Loading…
Cancel
Save