# Mirror of https://github.com/brycedrennan/imaginAIry
# (synced 2024-11-19 03:25:41 +00:00; 311 lines, 10 KiB, Python)
"""Refinement modules for image generation"""
|
|
|
|
import logging
|
|
import math
|
|
from functools import lru_cache
|
|
from typing import Literal
|
|
|
|
import torch
|
|
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
|
from refiners.fluxion.layers.chain import ChainError
|
|
from refiners.foundationals.latent_diffusion import (
|
|
SD1ControlnetAdapter,
|
|
SD1UNet,
|
|
StableDiffusion_1 as RefinerStableDiffusion_1,
|
|
StableDiffusion_1_Inpainting as RefinerStableDiffusion_1_Inpainting,
|
|
)
|
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
|
|
Controlnet,
|
|
)
|
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
|
SD1Autoencoder,
|
|
)
|
|
from torch import Tensor, nn
|
|
from torch.nn import functional as F
|
|
|
|
from imaginairy.utils.feather_tile import rebuild_image, tile_image
|
|
from imaginairy.weight_management.conversion import cast_weights
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
# Accepted values for `TileModeMixin.set_tile_mode`: an axis letter present in
# the string ("x" and/or "y") selects circular padding on that axis.
TileModeType = Literal["", "x", "y", "xy"]
|
|
|
|
|
|
def _tile_mode_conv2d_conv_forward(
    self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor  # noqa
):
    """
    Replacement for ``nn.Conv2d._conv_forward`` that can pad each axis with its own mode.

    ``TileModeMixin.set_tile_mode`` binds this function onto Conv2d instances and
    sets the attributes it reads: ``padding_mode_x``, ``padding_mode_y``,
    ``padding_x``, ``padding_y`` and ``_orig_conv_forward``.

    Args:
        input: Tensor to convolve.
        weight: Convolution weight.
        bias: Convolution bias (forwarded unchanged to the convolution).

    Returns:
        The convolution output.
    """
    if self.padding_mode_x == self.padding_mode_y:
        # Both axes use the same mode, so the stock implementation can do the
        # padding itself once it is told which mode to use.
        self.padding_mode = self.padding_mode_x
        return self._orig_conv_forward(input, weight, bias)

    # Mixed modes: pad one axis at a time, each with its own mode. The attribute
    # names must match the ones assigned in `set_tile_mode` (snake_case).
    padded = F.pad(input, self.padding_x, mode=self.padding_mode_x)
    del input  # drop our reference to the unpadded tensor as early as possible

    padded = F.pad(padded, self.padding_y, mode=self.padding_mode_y)

    # Padding was applied manually above, so the convolution itself uses none.
    return F.conv2d(
        padded, weight, bias, self.stride, (0, 0), self.dilation, self.groups
    )
|
|
|
|
|
|
class TileModeMixin(nn.Module):
    def set_tile_mode(self, tile_mode: TileModeType = ""):
        """
        Switch the padding of every Conv2d submodule to produce seamless tiles.

        Args:
            tile_mode: One of "", "x", "y", "xy". An "x" makes the image tile
                horizontally, a "y" makes it tile vertically, and "xy" does both.
                Axes that are not named use "constant" padding.
        """
        mode_x = "circular" if "x" in tile_mode else "constant"
        mode_y = "circular" if "y" in tile_mode else "constant"

        conv_layers = (mod for mod in self.modules() if isinstance(mod, nn.Conv2d))
        for conv in conv_layers:
            already_patched = hasattr(conv, "_orig_conv_forward")
            if not already_patched:
                # First visit: remember the stock behaviour, then swap in a
                # forward that can pad each axis with a different mode.
                conv._initial_padding_mode = conv.padding_mode  # type: ignore
                conv._orig_conv_forward = conv._conv_forward  # type: ignore
                conv._conv_forward = _tile_mode_conv2d_conv_forward.__get__(conv, nn.Conv2d)  # type: ignore

            conv.padding_mode_x = mode_x  # type: ignore
            conv.padding_mode_y = mode_y  # type: ignore

            # Split the layer's own padding into one tuple per axis.
            pads: list[int] = conv._reversed_padding_repeated_twice
            conv.padding_x = (pads[0], pads[1], 0, 0)  # type: ignore
            conv.padding_y = (0, 0, pads[2], pads[3])  # type: ignore
|
|
|
|
|
|
class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):
    """Refiners ``StableDiffusion_1`` combined with ``TileModeMixin`` (adds ``set_tile_mode``)."""
|
|
|
|
|
|
class StableDiffusion_1_Inpainting(TileModeMixin, RefinerStableDiffusion_1_Inpainting):
    def compute_self_attention_guidance(
        self,
        x: Tensor,
        noise: Tensor,
        step: int,
        *,
        clip_text_embedding: Tensor,
        **kwargs: Tensor,
    ) -> Tensor:
        """
        Compute the self-attention-guidance term for one step.

        The SAG adapter produces degraded latents, which are concatenated with
        the mask latents and target-image latents along dim 1 and passed
        through the UNet. The UNet context uses the first of the two chunks of
        ``clip_text_embedding``.

        Args:
            x: Current latents.
            noise: Noise prediction the degraded prediction is compared against.
            step: Index into ``self.scheduler.timesteps``.
            clip_text_embedding: Embedding that is split into two chunks; only
                the first chunk is used here.
            **kwargs: Forwarded to ``set_unet_context``.

        Returns:
            ``sag.scale * (noise - degraded_noise)``.
        """
        sag_adapter = self._find_sag_adapter()
        assert sag_adapter is not None
        assert self.mask_latents is not None
        assert self.target_image_latents is not None

        blurred_latents = sag_adapter.compute_degraded_latents(
            scheduler=self.scheduler,
            latents=x,
            noise=noise,
            step=step,
            classifier_free_guidance=True,
        )

        # Only the first half of the embedding conditions the degraded pass.
        negative_embedding, _unused_half = clip_text_embedding.chunk(2)
        current_timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
        self.set_unet_context(
            timestep=current_timestep,
            clip_text_embedding=negative_embedding,
            **kwargs,
        )

        unet_input = torch.cat(
            tensors=(blurred_latents, self.mask_latents, self.target_image_latents),
            dim=1,
        )
        degraded_noise = self.unet(unet_input)

        return sag_adapter.scale * (noise - degraded_noise)
|
|
|
|
|
|
class SD1AutoencoderSliced(SD1Autoencoder):
    """
    ``SD1Autoencoder`` variant that can encode and decode in overlapping tiles.

    ``decode`` first tries to decode in one pass and falls back to smaller and
    smaller tiles when the underlying chain reports an out-of-memory error.
    The tile size that finally worked is stored on the class, so later calls
    (on any instance) start from it.
    """

    # Upper bound for the latent tile edge used by `decode`; lowered at class
    # level whenever a decode attempt runs out of memory.
    max_chunk_size = 2048
    # `decode` never tiles with an edge smaller than this.
    min_chunk_size = 64

    def encode(self, x: Tensor) -> Tensor:
        """Encode ``x`` with `sliced_encode` using its default chunk size."""
        return self.sliced_encode(x)

    def sliced_encode(self, x: Tensor, chunk_size: int = 128 * 8) -> Tensor:
        """
        Encodes the image in slices (for lower memory usage).

        Args:
            x: 4-dimensional image tensor.
            chunk_size: Tile edge passed to ``tile_image``; the rebuilt latent
                uses ``chunk_size // 8`` as its tile edge.

        Returns:
            The latent rebuilt from the encoded tiles.
        """
        _, _, h, w = x.size()
        # NOTE(review): the canvas always has batch size 1 and every batch item
        # is merged into it -- presumably only single-image batches are
        # expected here; confirm against callers.
        final_tensor = torch.zeros([1, 4, h // 8, w // 8], device=x.device)
        overlap_pct = 0.5

        for x_img in x.split(1):
            chunks = tile_image(
                x_img, tile_size=chunk_size, overlap_percent=overlap_pct
            )
            # Explicit two-argument super(): calls the `encode` defined above
            # SD1Autoencoder in the MRO, bypassing this class's own override.
            encoded_chunks = [super(SD1Autoencoder, self).encode(ic) for ic in chunks]

            final_tensor = rebuild_image(
                encoded_chunks,
                base_img=final_tensor,
                tile_size=chunk_size // 8,
                overlap_percent=overlap_pct,
            )

        return final_tensor

    def decode(self, x):
        """
        Decode latents, shrinking the tile size after out-of-memory errors.

        While ``max_chunk_size ** 2`` exceeds the latent's height * width the
        latent is decoded in one pass; otherwise it is decoded in tiles of
        ``max_chunk_size``. After an out-of-memory ``ChainError`` the class-level
        ``max_chunk_size`` is reduced (never below ``min_chunk_size``) and the
        decode is retried. The minimum tile size itself is attempted before
        giving up, and a failure never leaves the class in a state where later
        calls fail without trying.

        Args:
            x: 4-dimensional latent tensor.

        Returns:
            The decoded image tensor.

        Raises:
            ChainError: If the chain fails for a reason other than running out
                of memory.
            RuntimeError: If decoding still runs out of memory and the tile
                size cannot be reduced any further.
        """
        cls = self.__class__
        latent_area = x.shape[2] * x.shape[3]

        while True:
            chunk_size = cls.max_chunk_size
            decode_whole = chunk_size**2 > latent_area
            try:
                if decode_whole:
                    return self.decode_all_at_once(x)
                return self.decode_sliced(x, chunk_size=chunk_size)
            except ChainError as e:
                if "OutOfMemoryError" not in str(e):
                    raise

            # Out of memory: pick a smaller tile. After a whole-image attempt
            # start from half the latent's edge; after a tiled attempt halve
            # the tile. Either way never go below the configured minimum.
            halved = (int(math.sqrt(latent_area)) if decode_whole else chunk_size) // 2
            next_size = max(halved, cls.min_chunk_size)

            if next_size >= chunk_size:
                # Already at the smallest allowed tile; nothing left to try.
                break
            if decode_whole and next_size**2 > latent_area:
                # Even the smallest tile covers the whole latent, so tiling
                # cannot reduce memory use for this input.
                break

            cls.max_chunk_size = next_size
            logger.info(
                f"Ran out of memory. Trying tiled decode with chunk size {cls.max_chunk_size}"
            )

        raise RuntimeError("Could not decode image")

    def decode_all_at_once(self, x: Tensor) -> Tensor:
        """Decode ``x`` in a single pass through ``self[1]`` after dividing by ``encoder_scale``."""
        decoder = self[1]
        return decoder(x / self.encoder_scale)

    def decode_sliced(self, x, chunk_size=128):
        """
        decodes the tensor in slices.

        This results in image portions that don't exactly match, so we overlap, feather, and merge to reduce
        (but not completely eliminate) impact.

        Args:
            x: 4-dimensional latent tensor.
            chunk_size: Latent tile edge passed to ``tile_image``; the rebuilt
                image uses ``chunk_size * 8`` as its tile edge.

        Returns:
            The image rebuilt from the decoded tiles.
        """
        _, _, h, w = x.size()
        # NOTE(review): as in `sliced_encode`, the canvas has batch size 1 and
        # is shared by every batch item -- confirm only single-image batches
        # reach this method.
        final_tensor = torch.zeros([1, 3, h * 8, w * 8], device=x.device)
        overlap_pct = 0.5

        for x_latent in x.split(1):
            chunks = tile_image(
                x_latent, tile_size=chunk_size, overlap_percent=overlap_pct
            )
            decoded_chunks = [self.decode_all_at_once(chunk) for chunk in chunks]
            final_tensor = rebuild_image(
                decoded_chunks,
                base_img=final_tensor,
                tile_size=chunk_size * 8,
                overlap_percent=overlap_pct,
            )

        return final_tensor
|
|
|
|
|
|
def add_sliced_attention_to_scaled_dot_product_attention(cls):
    """
    Install sliced attention on ``cls`` (meant for refiners' ScaledDotProductAttention).

    Three attributes are set on the class: ``_process_attention``,
    ``_sliced_attention`` and a replacement ``forward`` that processes the
    queries in slices of 2048. Working on a slice of the queries at a time is
    intended to lower peak memory usage.
    """

    def _process_attention(self, query, key, value, is_causal=None):
        # Plain multi-head attention for whatever queries are passed in.
        query_heads = self.split_to_multi_head(query)
        key_heads = self.split_to_multi_head(key)
        value_heads = self.split_to_multi_head(value)

        # An explicit argument wins, then the instance setting, then False.
        causal = is_causal
        if causal is None:
            causal = self.is_causal
            if causal is None:
                causal = False

        attended = self.dot_product(
            query=query_heads,
            key=key_heads,
            value=value_heads,
            is_causal=causal,
        )
        return self.merge_multi_head(x=attended)

    def _sliced_attention(self, query, key, value, slice_size, is_causal=None):
        # Walk over the query axis in windows and fill the result in place.
        _, num_queries, _ = query.shape
        output = torch.zeros_like(query)
        for lo in range(0, num_queries, slice_size):
            window = slice(lo, min(lo + slice_size, num_queries))
            output[:, window, :] = self._process_attention(
                query[:, window, :], key, value, is_causal
            )
        return output

    def new_forward(self, query, key, value, is_causal=None):
        return self._sliced_attention(
            query, key, value, is_causal=is_causal, slice_size=2048
        )

    cls._process_attention = _process_attention
    cls._sliced_attention = _sliced_attention
    cls.forward = new_forward
    logger.debug(f"Patched {cls.__name__} with sliced attention")
|
|
|
|
|
|
# Applied at import time: patches refiners' ScaledDotProductAttention class itself.
add_sliced_attention_to_scaled_dot_product_attention(ScaledDotProductAttention)
|
|
|
|
|
|
@lru_cache
def monkeypatch_sd1controlnetadapter():
    """
    Replace ``SD1ControlnetAdapter.__init__`` with a version that reuses controlnets.

    The replacement obtains its controlnet from ``get_controlnet`` instead of
    building one itself, so controlnet objects can be cached rather than
    recreated on every image generation. The ``lru_cache`` decorator on this
    argument-less function means the patch body runs only on the first call.
    """

    def __init__(
        self,
        target: SD1UNet,
        name: str,
        weights_location: str,
    ) -> None:
        self.name = name

        # Held inside a plain list so that it is not registered by PyTorch.
        self._controlnet: list[Controlnet] = [  # type: ignore
            get_controlnet(
                name=name,
                weights_location=weights_location,
                device=target.device,
                dtype=target.dtype,
            )
        ]

        with self.setup_adapter(target):
            # Explicit two-argument super(): this function is defined outside
            # the class body, so the zero-argument form is unavailable.
            super(SD1ControlnetAdapter, self).__init__(target)

    SD1ControlnetAdapter.__init__ = __init__
|
|
|
|
|
|
# Applied at import time; the function is wrapped in lru_cache, so calling it again re-applies nothing.
monkeypatch_sd1controlnetadapter()
|
|
|
|
|
|
@lru_cache(maxsize=4)
def get_controlnet(name, weights_location, device, dtype):
    """
    Build a ``Controlnet`` and load weights from ``weights_location`` into it.

    The state dict is loaded with ``half_mode=False`` and converted from the
    "diffusers" format to the "refiners" format before being loaded. Results
    are memoized for up to four distinct argument combinations.
    """
    from imaginairy.utils.model_manager import load_state_dict

    refiners_weights = cast_weights(
        source_weights=load_state_dict(weights_location, half_mode=False),
        source_model_name="controlnet-1-1",
        source_component_name="all",
        source_format="diffusers",
        dest_format="refiners",
    )

    net = Controlnet(name=name, scale=1, device=device, dtype=dtype)
    net.load_state_dict(refiners_weights)
    return net
|