|
|
@ -21,7 +21,11 @@ from safetensors.torch import load_file
|
|
|
|
from imaginairy import config as iconfig
|
|
|
|
from imaginairy import config as iconfig
|
|
|
|
from imaginairy.config import IMAGE_WEIGHTS_SHORT_NAMES, ModelArchitecture
|
|
|
|
from imaginairy.config import IMAGE_WEIGHTS_SHORT_NAMES, ModelArchitecture
|
|
|
|
from imaginairy.modules import attention
|
|
|
|
from imaginairy.modules import attention
|
|
|
|
from imaginairy.modules.refiners_sd import SDXLAutoencoderSliced, StableDiffusion_XL
|
|
|
|
from imaginairy.modules.refiners_sd import (
|
|
|
|
|
|
|
|
SDXLAutoencoderSliced,
|
|
|
|
|
|
|
|
StableDiffusion_XL,
|
|
|
|
|
|
|
|
StableDiffusion_XL_Inpainting,
|
|
|
|
|
|
|
|
)
|
|
|
|
from imaginairy.utils import clear_gpu_cache, get_device, instantiate_from_config
|
|
|
|
from imaginairy.utils import clear_gpu_cache, get_device, instantiate_from_config
|
|
|
|
from imaginairy.utils.model_cache import memory_managed_model
|
|
|
|
from imaginairy.utils.model_cache import memory_managed_model
|
|
|
|
from imaginairy.utils.named_resolutions import normalize_image_size
|
|
|
|
from imaginairy.utils.named_resolutions import normalize_image_size
|
|
|
@ -268,8 +272,10 @@ def _get_diffusion_model_refiners(
|
|
|
|
device=device,
|
|
|
|
device=device,
|
|
|
|
dtype=dtype,
|
|
|
|
dtype=dtype,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
elif architecture.primary_alias == "sdxl":
|
|
|
|
elif architecture.primary_alias in ("sdxl", "sdxlinpaint"):
|
|
|
|
sd = load_sdxl_pipeline(base_url=weights_location, device=device)
|
|
|
|
sd = load_sdxl_pipeline(
|
|
|
|
|
|
|
|
base_url=weights_location, device=device, for_inpainting=for_inpainting
|
|
|
|
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
msg = f"Invalid architecture {architecture.primary_alias}"
|
|
|
|
msg = f"Invalid architecture {architecture.primary_alias}"
|
|
|
|
raise ValueError(msg)
|
|
|
|
raise ValueError(msg)
|
|
|
@ -734,7 +740,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_sdxl_pipeline_from_diffusers_weights(
|
|
|
|
def load_sdxl_pipeline_from_diffusers_weights(
|
|
|
|
base_url: str, device=None, dtype=torch.float16
|
|
|
|
base_url: str, for_inpainting=False, device=None, dtype=torch.float16
|
|
|
|
):
|
|
|
|
):
|
|
|
|
from imaginairy.utils import get_device
|
|
|
|
from imaginairy.utils import get_device
|
|
|
|
|
|
|
|
|
|
|
@ -764,7 +770,10 @@ def load_sdxl_pipeline_from_diffusers_weights(
|
|
|
|
source_path=unet_weights_path,
|
|
|
|
source_path=unet_weights_path,
|
|
|
|
device="cpu",
|
|
|
|
device="cpu",
|
|
|
|
)
|
|
|
|
)
|
|
|
|
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
|
|
|
|
if for_inpainting:
|
|
|
|
|
|
|
|
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=9)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
|
|
|
|
unet.load_state_dict(unet_weights, assign=True)
|
|
|
|
unet.load_state_dict(unet_weights, assign=True)
|
|
|
|
del unet_weights
|
|
|
|
del unet_weights
|
|
|
|
|
|
|
|
|
|
|
@ -789,7 +798,12 @@ def load_sdxl_pipeline_from_diffusers_weights(
|
|
|
|
lda = lda.to(device=device, dtype=torch.float32)
|
|
|
|
lda = lda.to(device=device, dtype=torch.float32)
|
|
|
|
unet = unet.to(device=device)
|
|
|
|
unet = unet.to(device=device)
|
|
|
|
text_encoder = text_encoder.to(device=device)
|
|
|
|
text_encoder = text_encoder.to(device=device)
|
|
|
|
sd = StableDiffusion_XL(
|
|
|
|
if for_inpainting:
|
|
|
|
|
|
|
|
StableDiffusionCls = StableDiffusion_XL_Inpainting
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
StableDiffusionCls = StableDiffusion_XL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sd = StableDiffusionCls(
|
|
|
|
device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
|
|
|
|
device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
@ -797,7 +811,7 @@ def load_sdxl_pipeline_from_diffusers_weights(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_sdxl_pipeline_from_compvis_weights(
|
|
|
|
def load_sdxl_pipeline_from_compvis_weights(
|
|
|
|
base_url: str, device=None, dtype=torch.float16
|
|
|
|
base_url: str, for_inpainting=False, device=None, dtype=torch.float16
|
|
|
|
):
|
|
|
|
):
|
|
|
|
from imaginairy.utils import get_device
|
|
|
|
from imaginairy.utils import get_device
|
|
|
|
|
|
|
|
|
|
|
@ -809,7 +823,10 @@ def load_sdxl_pipeline_from_compvis_weights(
|
|
|
|
lda.load_state_dict(vae_weights, assign=True)
|
|
|
|
lda.load_state_dict(vae_weights, assign=True)
|
|
|
|
del vae_weights
|
|
|
|
del vae_weights
|
|
|
|
|
|
|
|
|
|
|
|
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
|
|
|
|
if for_inpainting:
|
|
|
|
|
|
|
|
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=9)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
|
|
|
|
unet.load_state_dict(unet_weights, assign=True)
|
|
|
|
unet.load_state_dict(unet_weights, assign=True)
|
|
|
|
del unet_weights
|
|
|
|
del unet_weights
|
|
|
|
|
|
|
|
|
|
|
@ -819,21 +836,31 @@ def load_sdxl_pipeline_from_compvis_weights(
|
|
|
|
lda = lda.to(device=device, dtype=torch.float32)
|
|
|
|
lda = lda.to(device=device, dtype=torch.float32)
|
|
|
|
unet = unet.to(device=device)
|
|
|
|
unet = unet.to(device=device)
|
|
|
|
text_encoder = text_encoder.to(device=device)
|
|
|
|
text_encoder = text_encoder.to(device=device)
|
|
|
|
sd = StableDiffusion_XL(
|
|
|
|
|
|
|
|
|
|
|
|
if for_inpainting:
|
|
|
|
|
|
|
|
StableDiffusionCls = StableDiffusion_XL_Inpainting
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
StableDiffusionCls = StableDiffusion_XL
|
|
|
|
|
|
|
|
sd = StableDiffusionCls(
|
|
|
|
device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
|
|
|
|
device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
return sd
|
|
|
|
return sd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_sdxl_pipeline(base_url, device=None):
|
|
|
|
def load_sdxl_pipeline(base_url, device=None, for_inpainting=False):
|
|
|
|
|
|
|
|
logger.info(f"Loading SDXL weights from {base_url}")
|
|
|
|
device = device or get_device()
|
|
|
|
device = device or get_device()
|
|
|
|
|
|
|
|
|
|
|
|
with logger.timed_info(f"Loaded SDXL pipeline from {base_url}"):
|
|
|
|
with logger.timed_info(f"Loaded SDXL pipeline from {base_url}"):
|
|
|
|
if is_diffusers_repo_url(base_url):
|
|
|
|
if is_diffusers_repo_url(base_url):
|
|
|
|
sd = load_sdxl_pipeline_from_diffusers_weights(base_url, device=device)
|
|
|
|
sd = load_sdxl_pipeline_from_diffusers_weights(
|
|
|
|
|
|
|
|
base_url, for_inpainting=for_inpainting, device=device
|
|
|
|
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
sd = load_sdxl_pipeline_from_compvis_weights(base_url, device=device)
|
|
|
|
sd = load_sdxl_pipeline_from_compvis_weights(
|
|
|
|
|
|
|
|
base_url, for_inpainting=for_inpainting, device=device
|
|
|
|
|
|
|
|
)
|
|
|
|
return sd
|
|
|
|
return sd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|