imaginAIry/imaginairy/modules/refiners_sd.py
2023-12-21 05:48:02 -08:00

311 lines
10 KiB
Python

"""Refinement modules for image generation"""
import logging
import math
from functools import lru_cache
from typing import Literal
import torch
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.layers.chain import ChainError
from refiners.foundationals.latent_diffusion import (
SD1ControlnetAdapter,
SD1UNet,
StableDiffusion_1 as RefinerStableDiffusion_1,
StableDiffusion_1_Inpainting as RefinerStableDiffusion_1_Inpainting,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
Controlnet,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
SD1Autoencoder,
)
from torch import Tensor, nn
from torch.nn import functional as F
from imaginairy.utils.feather_tile import rebuild_image, tile_image
from imaginairy.weight_management.conversion import cast_weights
logger = logging.getLogger(__name__)
TileModeType = Literal["", "x", "y", "xy"]
def _tile_mode_conv2d_conv_forward(
self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor # noqa
):
if self.padding_mode_x == self.padding_mode_y:
self.padding_mode = self.padding_mode_x
return self._orig_conv_forward(input, weight, bias)
w1 = F.pad(input, self.padding_x, mode=self.padding_modeX)
del input
w2 = F.pad(w1, self.padding_y, mode=self.padding_modeY)
del w1
return F.conv2d(w2, weight, bias, self.stride, (0, 0), self.dilation, self.groups)
class TileModeMixin(nn.Module):
def set_tile_mode(self, tile_mode: TileModeType = ""):
"""
For creating seamless tile images.
Args:
tile_mode: One of "", "x", "y", "xy". If "x", the image will be tiled horizontally. If "y", the image will be
tiled vertically. If "xy", the image will be tiled both horizontally and vertically.
"""
padding_mode_x = "circular" if "x" in tile_mode else "constant"
padding_mode_y = "circular" if "y" in tile_mode else "constant"
for m in self.modules():
if not isinstance(m, nn.Conv2d):
continue
if not hasattr(m, "_orig_conv_forward"):
# patch with a function that can handle tiling in a single direction
m._initial_padding_mode = m.padding_mode # type: ignore
m._orig_conv_forward = m._conv_forward # type: ignore
m._conv_forward = _tile_mode_conv2d_conv_forward.__get__(m, nn.Conv2d) # type: ignore
m.padding_mode_x = padding_mode_x # type: ignore
m.padding_mode_y = padding_mode_y # type: ignore
rprt: list[int] = m._reversed_padding_repeated_twice
m.padding_x = (rprt[0], rprt[1], 0, 0) # type: ignore
m.padding_y = (0, 0, rprt[2], rprt[3]) # type: ignore
class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):
pass
class StableDiffusion_1_Inpainting(TileModeMixin, RefinerStableDiffusion_1_Inpainting):
def compute_self_attention_guidance(
self,
x: Tensor,
noise: Tensor,
step: int,
*,
clip_text_embedding: Tensor,
**kwargs: Tensor,
) -> Tensor:
sag = self._find_sag_adapter()
assert sag is not None
assert self.mask_latents is not None
assert self.target_image_latents is not None
degraded_latents = sag.compute_degraded_latents(
scheduler=self.scheduler,
latents=x,
noise=noise,
step=step,
classifier_free_guidance=True,
)
negative_embedding, _ = clip_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
self.set_unet_context(
timestep=timestep, clip_text_embedding=negative_embedding, **kwargs
)
x = torch.cat(
tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
dim=1,
)
degraded_noise = self.unet(x)
return sag.scale * (noise - degraded_noise)
class SD1AutoencoderSliced(SD1Autoencoder):
max_chunk_size = 2048
min_chunk_size = 64
def encode(self, x: Tensor) -> Tensor:
return self.sliced_encode(x)
def sliced_encode(self, x: Tensor, chunk_size: int = 128 * 8) -> Tensor:
"""
Encodes the image in slices (for lower memory usage).
"""
b, c, h, w = x.size()
final_tensor = torch.zeros(
[1, 4, math.floor(h / 8), math.floor(w / 8)], device=x.device
)
overlap_pct = 0.5
for x_img in x.split(1):
chunks = tile_image(
x_img, tile_size=chunk_size, overlap_percent=overlap_pct
)
encoded_chunks = [super(SD1Autoencoder, self).encode(ic) for ic in chunks]
final_tensor = rebuild_image(
encoded_chunks,
base_img=final_tensor,
tile_size=chunk_size // 8,
overlap_percent=overlap_pct,
)
return final_tensor
def decode(self, x):
while self.__class__.max_chunk_size > self.__class__.min_chunk_size:
if self.max_chunk_size**2 > x.shape[2] * x.shape[3]:
try:
return self.decode_all_at_once(x)
except ChainError as e:
if "OutOfMemoryError" not in str(e):
raise
self.__class__.max_chunk_size = (
int(math.sqrt(x.shape[2] * x.shape[3])) // 2
)
logger.info(
f"Ran out of memory. Trying tiled decode with chunk size {self.__class__.max_chunk_size}"
)
else:
try:
return self.decode_sliced(x, chunk_size=self.max_chunk_size)
except ChainError as e:
if "OutOfMemoryError" not in str(e):
raise
self.__class__.max_chunk_size = self.max_chunk_size // 2
self.__class__.max_chunk_size = max(
self.__class__.max_chunk_size, self.__class__.min_chunk_size
)
logger.info(
f"Ran out of memory. Trying tiled decode with chunk size {self.__class__.max_chunk_size}"
)
raise RuntimeError("Could not decode image")
def decode_all_at_once(self, x: Tensor) -> Tensor:
decoder = self[1]
x = decoder(x / self.encoder_scale)
return x
def decode_sliced(self, x, chunk_size=128):
"""
decodes the tensor in slices.
This results in image portions that don't exactly match, so we overlap, feather, and merge to reduce
(but not completely eliminate) impact.
"""
b, c, h, w = x.size()
final_tensor = torch.zeros([1, 3, h * 8, w * 8], device=x.device)
for x_latent in x.split(1):
decoded_chunks = []
overlap_pct = 0.5
chunks = tile_image(
x_latent, tile_size=chunk_size, overlap_percent=overlap_pct
)
for latent_chunk in chunks:
# latent_chunk = self.post_quant_conv(latent_chunk)
dec = self.decode_all_at_once(latent_chunk)
decoded_chunks.append(dec)
final_tensor = rebuild_image(
decoded_chunks,
base_img=final_tensor,
tile_size=chunk_size * 8,
overlap_percent=overlap_pct,
)
return final_tensor
def add_sliced_attention_to_scaled_dot_product_attention(cls):
"""
Patch refiners ScaledDotProductAttention so that it uses sliced attention
It reduces peak memory usage.
"""
def _sliced_attention(self, query, key, value, slice_size, is_causal=None):
_, num_queries, _ = query.shape
output = torch.zeros_like(query)
for start_idx in range(0, num_queries, slice_size):
end_idx = min(start_idx + slice_size, num_queries)
output[:, start_idx:end_idx, :] = self._process_attention(
query[:, start_idx:end_idx, :], key, value, is_causal
)
return output
cls._sliced_attention = _sliced_attention
def new_forward(self, query, key, value, is_causal=None):
return self._sliced_attention(
query, key, value, is_causal=is_causal, slice_size=2048
)
cls.forward = new_forward
def _process_attention(self, query, key, value, is_causal=None):
return self.merge_multi_head(
x=self.dot_product(
query=self.split_to_multi_head(query),
key=self.split_to_multi_head(key),
value=self.split_to_multi_head(value),
is_causal=(
is_causal
if is_causal is not None
else (self.is_causal if self.is_causal is not None else False)
),
)
)
cls._process_attention = _process_attention
logger.debug(f"Patched {cls.__name__} with sliced attention")
add_sliced_attention_to_scaled_dot_product_attention(ScaledDotProductAttention)
@lru_cache
def monkeypatch_sd1controlnetadapter():
"""
Another horrible thing.
I needed to be able to cache the controlnet objects so I wouldn't be making new ones on every image generation.
"""
def __init__(
self,
target: SD1UNet,
name: str,
weights_location: str,
) -> None:
self.name = name
controlnet = get_controlnet(
name=name,
weights_location=weights_location,
device=target.device,
dtype=target.dtype,
)
self._controlnet: list[Controlnet] = [ # type: ignore
controlnet
] # not registered by PyTorch
with self.setup_adapter(target):
super(SD1ControlnetAdapter, self).__init__(target)
SD1ControlnetAdapter.__init__ = __init__
monkeypatch_sd1controlnetadapter()
@lru_cache(maxsize=4)
def get_controlnet(name, weights_location, device, dtype):
from imaginairy.utils.model_manager import load_state_dict
controlnet_state_dict = load_state_dict(weights_location, half_mode=False)
controlnet_state_dict = cast_weights(
source_weights=controlnet_state_dict,
source_model_name="controlnet-1-1",
source_component_name="all",
source_format="diffusers",
dest_format="refiners",
)
controlnet = Controlnet(name=name, scale=1, device=device, dtype=dtype)
controlnet.load_state_dict(controlnet_state_dict)
return controlnet