@@ -2,6 +2,7 @@ import math
from typing import Literal

import torch

from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.layers.chain import ChainError
from refiners.foundationals.latent_diffusion import (
    StableDiffusion_1 as RefinerStableDiffusion_1,
@@ -210,3 +211,50 @@ class SD1AutoencoderSliced(SD1Autoencoder):
        )
        return final_tensor


def add_sliced_attention_to_scaled_dot_product_attention(cls):
    """Patch refiners' ScaledDotProductAttention so that it uses sliced attention.

    Attention is computed over slices of the queries rather than over all of
    them at once, which reduces peak memory usage.
    """

    def _sliced_attention(self, query, key, value, slice_size, is_causal=None):
        _, num_queries, _ = query.shape
        output = torch.zeros_like(query)
        # Attend one slice of queries at a time against the full key/value,
        # so only a slice-sized attention matrix is materialized per step.
        for start_idx in range(0, num_queries, slice_size):
            end_idx = min(start_idx + slice_size, num_queries)
            output[:, start_idx:end_idx, :] = self._process_attention(
                query[:, start_idx:end_idx, :], key, value, is_causal
            )
        return output

    cls._sliced_attention = _sliced_attention

    # The slice size is hardcoded at 2048 queries; shorter sequences take a
    # single iteration and so behave exactly like the unsliced implementation.
    def new_forward(self, query, key, value, is_causal=None):
        return self._sliced_attention(
            query, key, value, is_causal=is_causal, slice_size=2048
        )

    cls.forward = new_forward
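
    # Because the assignment above targets the class rather than an instance,
    # every ScaledDotProductAttention layer in the process picks up the sliced
    # path: this is a plain monkeypatch, not a subclass.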

    def _process_attention(self, query, key, value, is_causal=None):
        # Standard multi-head attention (split heads, scaled dot product,
        # merge heads), factored out so it can be applied per query slice.
        return self.merge_multi_head(
            x=self.dot_product(
                query=self.split_to_multi_head(query),
                key=self.split_to_multi_head(key),
                value=self.split_to_multi_head(value),
                is_causal=(
                    is_causal
                    if is_causal is not None
                    else (self.is_causal if self.is_causal is not None else False)
                ),
            )
        )

    cls._process_attention = _process_attention
    logger.debug(f"Patched {cls.__name__} with sliced attention")


add_sliced_attention_to_scaled_dot_product_attention(ScaledDotProductAttention)
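
# Sanity-check sketch (illustrative, not part of the patch; the constructor
# arguments are assumptions about the refiners API): each query slice still
# attends over the full key/value, so the sliced output should match the
# unsliced computation up to numerical noise.
#
#     attn = ScaledDotProductAttention(num_heads=8)
#     q = k = v = torch.randn(1, 4096, 640)
#     out = attn(q, k, v)  # computed internally in 2048-query slices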