You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

95 lines
3.9 KiB

import math
from typing import Any, Callable, Generic, TypeVar
import torch
from torch import Tensor
from torch.fft import fftn, fftshift, ifftn, ifftshift # type: ignore
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualConcatenator, SD1UNet
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
# Type variable for the UNet class an adapter is parameterised over; bounded by SD1UNet | SDXLUNet.
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
# Type variable used to annotate `self` and the return value of `SDFreeUAdapter.inject`,
# so the method is typed as returning the same adapter type it was called on.
TSDFreeUAdapter = TypeVar("TSDFreeUAdapter", bound="SDFreeUAdapter[Any]") # Self (see PEP 673)
def fourier_filter(x: Tensor, scale: float = 1, threshold: int = 1) -> Tensor:
    """Scale the low-frequency band of a feature map in the Fourier domain.

    Fourier filter as introduced in FreeU ("FreeU: Free Lunch in Diffusion U-Net", arXiv:2309.11497).
    This version of the method follows the Hugging Face diffusers implementation
    (`diffusers.utils.torch_utils.fourier_filter`).

    Args:
        x: 4-D tensor unpacked as (batch, channels, height, width).
        scale: Multiplier applied to the low-frequency coefficients.
        threshold: Half-width, in frequency bins, of the square box centred on the
            zero-frequency component whose coefficients are multiplied by `scale`.

    Returns:
        The filtered tensor, with the same shape and dtype as the input `x`.
    """
    batch, channels, height, width = x.shape
    dtype = x.dtype
    device = x.device

    # torch.fft only supports half precision for power-of-two signal lengths,
    # so any other spatial size is transformed in float32 and cast back at the end.
    if not (math.log2(height).is_integer() and math.log2(width).is_integer()):
        x = x.to(dtype=torch.float32)

    # Move to the frequency domain with the zero-frequency component shifted to the centre.
    x_freq = fftn(x, dim=(-2, -1))  # type: ignore
    x_freq = fftshift(x_freq, dim=(-2, -1))  # type: ignore

    mask = torch.ones((batch, channels, height, width), device=device)  # type: ignore

    center_row, center_col = height // 2, width // 2  # type: ignore
    # Clamp the lower bounds at 0: a negative start index would wrap around and select
    # the wrong rows/columns when `threshold` exceeds half the spatial size.
    row_start = max(center_row - threshold, 0)
    col_start = max(center_col - threshold, 0)
    mask[..., row_start : center_row + threshold, col_start : center_col + threshold] = scale
    x_freq = x_freq * mask  # type: ignore

    # Back to the spatial domain; only the real part is kept.
    x_freq = ifftshift(x_freq, dim=(-2, -1))  # type: ignore
    x_filtered = ifftn(x_freq, dim=(-2, -1)).real  # type: ignore

    return x_filtered.to(dtype=dtype)  # type: ignore
class FreeUBackboneFeatures(fl.Module):
    """Multiply the first half of the channels (dim 1) of the input by `backbone_scale`."""

    def __init__(self, backbone_scale: float) -> None:
        """Store the scaling factor applied to the first half of the channels."""
        # The parent module must be initialised before attributes are set on the instance.
        super().__init__()
        self.backbone_scale = backbone_scale

    def forward(self, x: Tensor) -> Tensor:
        """Scale channels `[0, C // 2)` of `x` and return `x`.

        NOTE: the scaled values are written back into `x` itself, so the caller's
        tensor is modified in place and the same object is returned.
        """
        num_half_channels = x.shape[1] // 2
        x[:, :num_half_channels] = x[:, :num_half_channels] * self.backbone_scale
        return x
class FreeUSkipFeatures(fl.Chain):
    """Fetch the n-th UNet residual from the context and apply the FreeU Fourier filter to it."""

    def __init__(self, n: int, skip_scale: float) -> None:
        """Build the chain.

        Args:
            n: Index into the `residuals` entry of the `unet` context.
            skip_scale: Scale passed to `fourier_filter` for the low-frequency band.
        """
        apply_filter: Callable[[Tensor], Tensor] = lambda x: fourier_filter(x, scale=skip_scale)
        # The layers have to be handed to the parent chain; building them without
        # passing them on would leave the chain empty.
        super().__init__(
            fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[n]),
            fl.Lambda(apply_filter),
        )
class FreeUResidualConcatenator(fl.Concatenate):
    """Concatenate the scaled backbone features with the Fourier-filtered skip features."""

    def __init__(self, n: int, backbone_scale: float, skip_scale: float) -> None:
        """Build the concatenation.

        Args:
            n: Index of the residual handed to `FreeUSkipFeatures`.
            backbone_scale: Scale handed to `FreeUBackboneFeatures`.
            skip_scale: Scale handed to `FreeUSkipFeatures`.
        """
        super().__init__(
            FreeUBackboneFeatures(backbone_scale),
            FreeUSkipFeatures(n, skip_scale),
            # Channel dimension: the same dim FreeUBackboneFeatures splits in half.
            dim=1,
        )
class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]):
    """Adapter replacing the residual concatenators of the first UNet up-blocks with FreeU ones."""

    def __init__(self, target: T, backbone_scales: list[float], skip_scales: list[float]) -> None:
        """Set up the adapter around `target`.

        Args:
            target: The UNet to adapt; it must expose `UpBlocks`.
            backbone_scales: One backbone scale per adapted up-block, in block order.
            skip_scales: One skip scale per adapted up-block; same length as `backbone_scales`.
        """
        assert len(backbone_scales) == len(skip_scales), "backbone_scales and skip_scales must have the same length"
        assert len(backbone_scales) <= len(target.UpBlocks), "more scales were given than the target has up-blocks"
        self.backbone_scales = backbone_scales
        self.skip_scales = skip_scales
        with self.setup_adapter(target):
            super().__init__(target)

    def inject(self: TSDFreeUAdapter, parent: fl.Chain | None = None) -> TSDFreeUAdapter:
        """Swap each adapted up-block's `ResidualConcatenator` for a FreeU one, then inject the adapter."""
        for n, (backbone_scale, skip_scale) in enumerate(zip(self.backbone_scales, self.skip_scales)):
            block = self.target.UpBlocks[n]
            concat = block.ensure_find(ResidualConcatenator)
            # Up-block n reads residual index -n - 2, the same index `eject` restores.
            block.replace(concat, FreeUResidualConcatenator(-n - 2, backbone_scale, skip_scale))
        return super().inject(parent)

    def eject(self) -> None:
        """Restore the plain `ResidualConcatenator` in each adapted up-block, then eject the adapter."""
        for n in range(len(self.backbone_scales)):
            block = self.target.UpBlocks[n]
            concat = block.ensure_find(FreeUResidualConcatenator)
            block.replace(concat, ResidualConcatenator(-n - 2))
        # Mirror of `super().inject(parent)` in `inject`.
        super().eject()