imaginAIry/imaginairy/vendored/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py

from typing import Iterable, cast
from torch import Tensor, device as Device, dtype as DType
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d, RangeEncoder


# Reads the current timestep from the "diffusion" context, encodes it with a
# RangeEncoder (320 -> 1280) and stores the embedding under the "range_adapter"
# context, where the injected RangeAdapter2d layers pick it up.
class TimestepEncoder(fl.Passthrough):
    def __init__(
        self,
        context_key: str = "timestep_embedding",
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(
            fl.UseContext("diffusion", "timestep"),
            RangeEncoder(320, 1280, device=device, dtype=dtype),
            fl.SetContext("range_adapter", context_key),
        )


# Two GroupNorm -> SiLU -> Conv2d(3x3) stages summed with a shortcut branch:
# a 1x1 convolution when the channel count changes, identity otherwise.
class ResidualBlock(fl.Sum):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_groups: int = 32,
        eps: float = 1e-5,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        if in_channels % num_groups != 0 or out_channels % num_groups != 0:
            raise ValueError("Number of input and output channels must be divisible by num_groups.")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_groups = num_groups
        self.eps = eps
        shortcut = (
            fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
            if in_channels != out_channels
            else fl.Identity()
        )
        super().__init__(
            fl.Chain(
                fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
                fl.SiLU(),
                fl.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    device=device,
                    dtype=dtype,
                ),
                fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
                fl.SiLU(),
                fl.Conv2d(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    device=device,
                    dtype=dtype,
                ),
            ),
            shortcut,
        )


# Spatial cross-attention over the CLIP text embedding (768-dim context,
# 8 attention heads, no bias), read from the "clip_text_embedding" context key.
class CLIPLCrossAttention(CrossAttentionBlock2d):
    def __init__(
        self,
        channels: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(
            channels=channels,
            context_embedding_dim=768,
            context_key="clip_text_embedding",
            num_attention_heads=8,
            use_bias=False,
            device=device,
            dtype=dtype,
        )


# Encoder path of the SD1 UNet: an input convolution followed by residual /
# cross-attention stages at 320, 640 and 1280 channels, with three 2x
# downsampling steps in between (12 sub-chains in total).
class DownBlocks(fl.Chain):
    def __init__(
        self,
        in_channels: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ):
        self.in_channels = in_channels
        super().__init__(
            fl.Chain(
                fl.Conv2d(
                    in_channels=in_channels, out_channels=320, kernel_size=3, padding=1, device=device, dtype=dtype
                )
            ),
            fl.Chain(
                ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
            ),
            fl.Chain(fl.Downsample(channels=320, scale_factor=2, padding=1, device=device, dtype=dtype)),
            fl.Chain(
                ResidualBlock(in_channels=320, out_channels=640, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=640, out_channels=640, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
            ),
            fl.Chain(fl.Downsample(channels=640, scale_factor=2, padding=1, device=device, dtype=dtype)),
            fl.Chain(
                ResidualBlock(in_channels=640, out_channels=1280, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(fl.Downsample(channels=1280, scale_factor=2, padding=1, device=device, dtype=dtype)),
            fl.Chain(
                ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
            ),
        )


# Decoder path: each stage consumes the current features concatenated with a
# stored skip residual (hence the enlarged in_channels) and works back down
# from 1280 to 320 channels, with three 2x upsampling steps.
class UpBlocks(fl.Chain):
    def __init__(
        self,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(
            fl.Chain(
                ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
                fl.Upsample(channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=1920, out_channels=1280, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
                fl.Upsample(channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=1920, out_channels=640, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=1280, out_channels=640, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=960, out_channels=640, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
                fl.Upsample(channels=640, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=960, out_channels=320, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
            ),
        )


class MiddleBlock(fl.Chain):
    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__(
            ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
            CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
            ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
        )


# Passthrough that adds the incoming features to residual slot `n` of the
# "unet" context and writes the sum back, while forwarding its input unchanged.
class ResidualAccumulator(fl.Passthrough):
    def __init__(self, n: int) -> None:
        self.n = n
        super().__init__(
            fl.Residual(
                fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n])
            ),
            fl.SetContext(context="unet", key="residuals", callback=self.update),
        )

    def update(self, residuals: list[Tensor | float], x: Tensor) -> None:
        residuals[self.n] = x


# Concatenates the incoming features with residual slot `n` of the "unet"
# context along the channel dimension.
class ResidualConcatenator(fl.Chain):
    def __init__(self, n: int) -> None:
        self.n = n
        super().__init__(
            fl.Concatenate(
                fl.Identity(),
                fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]),
                dim=1,
            ),
        )


# Full Stable Diffusion 1.x denoising UNet: timestep encoder, down blocks,
# middle block (summed with the last residual slot), up blocks and a final
# GroupNorm/SiLU/Conv2d head projecting back to 4 latent channels.
class SD1UNet(fl.Chain):
    def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.in_channels = in_channels
        super().__init__(
            TimestepEncoder(device=device, dtype=dtype),
            DownBlocks(in_channels=in_channels, device=device, dtype=dtype),
            fl.Sum(
                fl.UseContext(context="unet", key="residuals").compose(lambda x: x[-1]),
                MiddleBlock(device=device, dtype=dtype),
            ),
            UpBlocks(),
            fl.Chain(
                fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
                fl.SiLU(),
                fl.Conv2d(
                    in_channels=320,
                    out_channels=4,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    device=device,
                    dtype=dtype,
                ),
            ),
        )
        # Condition every residual block on the timestep embedding.
        for residual_block in self.layers(ResidualBlock):
            chain = residual_block.Chain
            RangeAdapter2d(
                target=chain.Conv2d_1,
                channels=residual_block.out_channels,
                embedding_dim=1280,
                context_key="timestep_embedding",
                device=device,
                dtype=dtype,
            ).inject(chain)
        # Each down block accumulates its output into a residual slot; each up
        # block concatenates the matching slot (in reverse order) before running.
        for n, block in enumerate(cast(Iterable[fl.Chain], self.DownBlocks)):
            block.append(ResidualAccumulator(n))
        for n, block in enumerate(cast(Iterable[fl.Chain], self.UpBlocks)):
            block.insert(0, ResidualConcatenator(-n - 2))

    def init_context(self) -> Contexts:
        # 13 residual slots: one per down block (12) plus a final slot that is
        # summed with the middle block output.
        return {
            "unet": {"residuals": [0.0] * 13},
            "diffusion": {"timestep": None},
            "range_adapter": {"timestep_embedding": None},
            "sampling": {"shapes": []},
        }

    def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
        self.set_context("cross_attention_block", {"clip_text_embedding": clip_text_embedding})

    def set_timestep(self, timestep: Tensor) -> None:
        self.set_context("diffusion", {"timestep": timestep})