|
|
|
@ -1,13 +1,20 @@
|
|
|
|
|
import logging
|
|
|
|
|
import math
|
|
|
|
|
from functools import lru_cache
|
|
|
|
|
from typing import Literal
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
|
|
|
|
from refiners.fluxion.layers.chain import ChainError
|
|
|
|
|
from refiners.foundationals.latent_diffusion import (
|
|
|
|
|
SD1ControlnetAdapter,
|
|
|
|
|
SD1UNet,
|
|
|
|
|
StableDiffusion_1 as RefinerStableDiffusion_1,
|
|
|
|
|
StableDiffusion_1_Inpainting as RefinerStableDiffusion_1_Inpainting,
|
|
|
|
|
)
|
|
|
|
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
|
|
|
|
|
Controlnet,
|
|
|
|
|
)
|
|
|
|
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
|
|
|
|
SD1Autoencoder,
|
|
|
|
|
)
|
|
|
|
@ -16,7 +23,9 @@ from torch.nn import functional as F
|
|
|
|
|
from torch.nn.modules.utils import _pair
|
|
|
|
|
|
|
|
|
|
from imaginairy.feather_tile import rebuild_image, tile_image
|
|
|
|
|
from imaginairy.modules.autoencoder import logger
|
|
|
|
|
from imaginairy.weight_management.conversion import cast_weights
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
TileModeType = Literal["", "x", "y", "xy"]
|
|
|
|
|
|
|
|
|
@ -258,3 +267,54 @@ def add_sliced_attention_to_scaled_dot_product_attention(cls):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
add_sliced_attention_to_scaled_dot_product_attention(ScaledDotProductAttention)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache
|
|
|
|
|
def monkeypatch_sd1controlnetadapter():
|
|
|
|
|
"""
|
|
|
|
|
Another horrible thing.
|
|
|
|
|
|
|
|
|
|
I needed to be able to cache the controlnet objects so I wouldn't be making new ones on every image generation.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
target: SD1UNet,
|
|
|
|
|
name: str,
|
|
|
|
|
weights_location: str,
|
|
|
|
|
) -> None:
|
|
|
|
|
self.name = name
|
|
|
|
|
controlnet = get_controlnet(
|
|
|
|
|
name=name,
|
|
|
|
|
weights_location=weights_location,
|
|
|
|
|
device=target.device,
|
|
|
|
|
dtype=target.dtype,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self._controlnet: list[Controlnet] = [controlnet] # not registered by PyTorch
|
|
|
|
|
|
|
|
|
|
with self.setup_adapter(target):
|
|
|
|
|
super(SD1ControlnetAdapter, self).__init__(target)
|
|
|
|
|
|
|
|
|
|
SD1ControlnetAdapter.__init__ = __init__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
monkeypatch_sd1controlnetadapter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=4)
|
|
|
|
|
def get_controlnet(name, weights_location, device, dtype):
|
|
|
|
|
from imaginairy.model_manager import load_state_dict
|
|
|
|
|
|
|
|
|
|
controlnet_state_dict = load_state_dict(weights_location, half_mode=False)
|
|
|
|
|
controlnet_state_dict = cast_weights(
|
|
|
|
|
source_weights=controlnet_state_dict,
|
|
|
|
|
source_model_name="controlnet-1-1",
|
|
|
|
|
source_component_name="all",
|
|
|
|
|
source_format="diffusers",
|
|
|
|
|
dest_format="refiners",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
controlnet = Controlnet(name=name, scale=1, device=device, dtype=dtype)
|
|
|
|
|
controlnet.load_state_dict(controlnet_state_dict)
|
|
|
|
|
return controlnet
|
|
|
|
|