feature: patch refiners ScaledDotProductAttention for sliced attention

pull/408/head
Authored by Bryce 6 months ago; committed by Bryce Drennan
parent 1b15d6dcd4
commit ba57393022

@@ -2,6 +2,7 @@ import math
 from typing import Literal
 import torch
+from refiners.fluxion.layers.attentions import ScaledDotProductAttention
 from refiners.fluxion.layers.chain import ChainError
 from refiners.foundationals.latent_diffusion import (
     StableDiffusion_1 as RefinerStableDiffusion_1,
@@ -210,3 +211,50 @@ class SD1AutoencoderSliced(SD1Autoencoder):
         )
         return final_tensor
+
+
+def add_sliced_attention_to_scaled_dot_product_attention(cls):
+    """
+    Patch the refiners ScaledDotProductAttention layer so that it uses sliced
+    attention, which reduces peak memory usage.
+    """
+
+    def _sliced_attention(self, query, key, value, slice_size, is_causal=None):
+        # Process the query in chunks of at most `slice_size` rows; each
+        # chunk attends to the full key/value tensors, so only a slice of
+        # the attention score matrix is materialized at a time.
+        _, num_queries, _ = query.shape
+        output = torch.zeros_like(query)
+        for start_idx in range(0, num_queries, slice_size):
+            end_idx = min(start_idx + slice_size, num_queries)
+            output[:, start_idx:end_idx, :] = self._process_attention(
+                query[:, start_idx:end_idx, :], key, value, is_causal
+            )
+        return output
+
+    cls._sliced_attention = _sliced_attention
+
+    def new_forward(self, query, key, value, is_causal=None):
+        return self._sliced_attention(
+            query, key, value, is_causal=is_causal, slice_size=2048
+        )
+
+    cls.forward = new_forward
+
+    def _process_attention(self, query, key, value, is_causal=None):
+        # Same computation as the original forward: split heads, apply
+        # scaled dot-product attention, then merge the heads back.
+        return self.merge_multi_head(
+            x=self.dot_product(
+                query=self.split_to_multi_head(query),
+                key=self.split_to_multi_head(key),
+                value=self.split_to_multi_head(value),
+                is_causal=(
+                    is_causal
+                    if is_causal is not None
+                    else (self.is_causal if self.is_causal is not None else False)
+                ),
+            )
+        )
+
+    cls._process_attention = _process_attention
+    logger.debug(f"Patched {cls.__name__} with sliced attention")
+
+
+add_sliced_attention_to_scaled_dot_product_attention(ScaledDotProductAttention)
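
For context beyond the diff: query-wise slicing is exact because each output row of scaled dot-product attention depends only on its own query row and the full key/value tensors, so chunking the query dimension only bounds the size of the (queries x keys) score matrix materialized at once. Below is a minimal self-contained sketch in plain PyTorch demonstrating the equivalence; sliced_sdpa, the tensor shapes, and the slice size are illustrative assumptions, not part of the patched class.

    # Hypothetical standalone demo, independent of refiners.
    import torch
    import torch.nn.functional as F

    def sliced_sdpa(query, key, value, slice_size):
        # Same scheme as the patch: iterate over query chunks, each
        # attending to the full key/value tensors.
        output = torch.zeros_like(query)
        for start in range(0, query.shape[-2], slice_size):
            end = min(start + slice_size, query.shape[-2])
            output[..., start:end, :] = F.scaled_dot_product_attention(
                query[..., start:end, :], key, value
            )
        return output

    # (batch, heads, tokens, head_dim)
    q, k, v = (torch.randn(2, 8, 4096, 64) for _ in range(3))
    full = F.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(full, sliced_sdpa(q, k, v, slice_size=1024), atol=1e-5)

Since the commit replaces forward on the class itself, every ScaledDotProductAttention layer, including instances created before the patch runs, takes the sliced path; slice_size=2048 is the memory/speed trade-off hard-coded in new_forward.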
