fix: cache the controlnet models

pull/411/head^2
Bryce 6 months ago committed by Bryce Drennan
parent 9b95e8b0b6
commit c299cfffd9

@ -2,7 +2,6 @@ import logging
from typing import List, Optional
from imaginairy.config import CONTROL_CONFIG_SHORTCUTS
from imaginairy.model_manager import load_controlnet_adapter
from imaginairy.schema import ImaginePrompt, MaskMode, WeightedPrompt
logger = logging.getLogger(__name__)
@ -251,13 +250,14 @@ def _generate_single_image(
if not control_config:
msg = f"Unknown control mode: {control_input.mode}"
raise ValueError(msg)
from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter
controlnet = load_controlnet_adapter(
controlnet = SD1ControlnetAdapter(
name=control_input.mode,
control_weights_location=control_config.weights_location,
target_unet=sd.unet,
scale=control_input.strength,
target=sd.unet,
weights_location=control_config.weights_location,
)
controlnets.append((controlnet, control_image_t))
if prompt.allow_compose_phase:

@ -13,7 +13,7 @@ from huggingface_hub import (
try_to_load_from_cache,
)
from omegaconf import OmegaConf
from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter, SD1UNet
from refiners.foundationals.latent_diffusion import SD1UNet
from safetensors.torch import load_file
from imaginairy import config as iconfig
@ -22,7 +22,6 @@ from imaginairy.modules import attention
from imaginairy.paths import PKG_ROOT
from imaginairy.utils import get_device, instantiate_from_config
from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.weight_management.conversion import cast_weights
logger = logging.getLogger(__name__)
@ -363,32 +362,6 @@ def _load_diffusion_model(config_path, weights_location, half_mode):
return model
def load_controlnet_adapter(
name,
control_weights_location,
target_unet,
scale=1.0,
):
controlnet_state_dict = load_state_dict(control_weights_location, half_mode=False)
controlnet_state_dict = cast_weights(
source_weights=controlnet_state_dict,
source_model_name="controlnet-1-1",
source_component_name="all",
source_format="diffusers",
dest_format="refiners",
)
for key in controlnet_state_dict:
controlnet_state_dict[key] = controlnet_state_dict[key].to(
device=target_unet.device, dtype=target_unet.dtype
)
adapter = SD1ControlnetAdapter(
target=target_unet, name=name, scale=scale, weights=controlnet_state_dict
)
return adapter
@memory_managed_model("controlnet")
def load_controlnet(control_weights_location, half_mode):
controlnet_state_dict = load_state_dict(

@ -1,13 +1,20 @@
import logging
import math
from functools import lru_cache
from typing import Literal
import torch
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.layers.chain import ChainError
from refiners.foundationals.latent_diffusion import (
SD1ControlnetAdapter,
SD1UNet,
StableDiffusion_1 as RefinerStableDiffusion_1,
StableDiffusion_1_Inpainting as RefinerStableDiffusion_1_Inpainting,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
Controlnet,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
SD1Autoencoder,
)
@ -16,7 +23,9 @@ from torch.nn import functional as F
from torch.nn.modules.utils import _pair
from imaginairy.feather_tile import rebuild_image, tile_image
from imaginairy.modules.autoencoder import logger
from imaginairy.weight_management.conversion import cast_weights
logger = logging.getLogger(__name__)
TileModeType = Literal["", "x", "y", "xy"]
@ -258,3 +267,54 @@ def add_sliced_attention_to_scaled_dot_product_attention(cls):
add_sliced_attention_to_scaled_dot_product_attention(ScaledDotProductAttention)
@lru_cache
def monkeypatch_sd1controlnetadapter():
"""
Another horrible thing.
I needed to be able to cache the controlnet objects so I wouldn't be making new ones on every image generation.
"""
def __init__(
self,
target: SD1UNet,
name: str,
weights_location: str,
) -> None:
self.name = name
controlnet = get_controlnet(
name=name,
weights_location=weights_location,
device=target.device,
dtype=target.dtype,
)
self._controlnet: list[Controlnet] = [controlnet] # not registered by PyTorch
with self.setup_adapter(target):
super(SD1ControlnetAdapter, self).__init__(target)
SD1ControlnetAdapter.__init__ = __init__
monkeypatch_sd1controlnetadapter()
@lru_cache(maxsize=4)
def get_controlnet(name, weights_location, device, dtype):
from imaginairy.model_manager import load_state_dict
controlnet_state_dict = load_state_dict(weights_location, half_mode=False)
controlnet_state_dict = cast_weights(
source_weights=controlnet_state_dict,
source_model_name="controlnet-1-1",
source_component_name="all",
source_format="diffusers",
dest_format="refiners",
)
controlnet = Controlnet(name=name, scale=1, device=device, dtype=dtype)
controlnet.load_state_dict(controlnet_state_dict)
return controlnet

Loading…
Cancel
Save