You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
imaginAIry/imaginairy/vendored/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py

286 lines
11 KiB
Python

from typing import cast
from torch import Tensor, device as Device, dtype as DType
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.range_adapter import (
RangeAdapter2d,
RangeEncoder,
compute_sinusoidal_embedding,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
ResidualAccumulator,
ResidualBlock,
ResidualConcatenator,
)
class TextTimeEmbedding(fl.Chain):
    """Embed SDXL's pooled text embedding together with its ``time_ids``.

    Both inputs are read from the "diffusion" context. The time ids go through
    ``fl.Unsqueeze(dim=-1)``, a sinusoidal embedding of size
    ``time_ids_embedding_dim`` (256) and ``fl.Reshape(-1)``. The result is
    concatenated with the pooled text embedding along dim 1. It is then
    projected by Linear(2816 -> 1280) -> SiLU -> Linear(1280 -> 1280).
    """

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        # These sizes are stored on self before super().__init__ because the
        # layers built below read them.
        self.timestep_embedding_dim = 1280
        self.time_ids_embedding_dim = 256
        # Width of the concatenated features fed to the first Linear.
        # NOTE(review): presumably 1280 (pooled text) + 6 * 256 (six time ids) — confirm.
        self.text_time_embedding_dim = 2816

        # Layers are created in the same order as the nested-call form so that
        # parameter initialisation consumes the RNG identically.
        pooled_text_embedding = fl.UseContext(context="diffusion", key="pooled_text_embedding")
        time_ids_embedding = fl.Chain(
            fl.UseContext(context="diffusion", key="time_ids"),
            fl.Unsqueeze(dim=-1),
            fl.Lambda(func=self.compute_sinusoidal_embedding),
            fl.Reshape(-1),
        )
        concatenated = fl.Concatenate(pooled_text_embedding, time_ids_embedding, dim=1)
        # Presumably casts the features to the module's dtype while leaving the device alone — confirm.
        to_dtype = fl.Converter(set_device=False, set_dtype=True)
        first_projection = fl.Linear(
            in_features=self.text_time_embedding_dim,
            out_features=self.timestep_embedding_dim,
            device=device,
            dtype=dtype,
        )
        activation = fl.SiLU()
        second_projection = fl.Linear(
            in_features=self.timestep_embedding_dim,
            out_features=self.timestep_embedding_dim,
            device=device,
            dtype=dtype,
        )
        super().__init__(concatenated, to_dtype, first_projection, activation, second_projection)

    def compute_sinusoidal_embedding(self, x: Tensor) -> Tensor:
        """Apply the module-level ``compute_sinusoidal_embedding`` with ``time_ids_embedding_dim`` channels."""
        # Inside a method the bare name resolves to the imported module-level
        # function, not to this method: class scope is not searched.
        embedding_dim = self.time_ids_embedding_dim
        return compute_sinusoidal_embedding(x=x, embedding_dim=embedding_dim)
class TimestepEncoder(fl.Passthrough):
    """Build the SDXL timestep embedding and publish it to the "range_adapter" context.

    Two branches are summed:

    * the diffusion timestep (read from the "diffusion" context) passed through a
      RangeEncoder (320-dim sinusoidal input, 1280-dim output);
    * a TextTimeEmbedding of the pooled text embedding and time ids.

    The sum is stored under the "timestep_embedding" key of the "range_adapter" context.
    NOTE(review): as an ``fl.Passthrough`` this presumably returns its input unchanged — confirm.
    """

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.timestep_embedding_dim = 1280

        # Branches are built in the original order (timestep first, then text/time).
        timestep_branch = fl.Chain(
            fl.UseContext(context="diffusion", key="timestep"),
            RangeEncoder(
                sinusoidal_embedding_dim=320,
                embedding_dim=self.timestep_embedding_dim,
                device=device,
                dtype=dtype,
            ),
        )
        text_time_branch = TextTimeEmbedding(device=device, dtype=dtype)

        combined = fl.Sum(timestep_branch, text_time_branch)
        publish = fl.SetContext(context="range_adapter", key="timestep_embedding")
        super().__init__(combined, publish)
class SDXLCrossAttention(CrossAttentionBlock2d):
    """CrossAttentionBlock2d preconfigured for SDXL.

    Every instance attends to the "clip_text_embedding" context key with
    ``context_embedding_dim=2048``, ``use_bias=False`` and
    ``use_linear_projection=True``. Only the channel count, the number of
    attention layers and the number of heads vary between UNet stages.
    """

    def __init__(
        self,
        channels: int,
        num_attention_layers: int = 1,
        num_attention_heads: int = 10,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(
            # Settings shared by every SDXL cross-attention block.
            context_key="clip_text_embedding",
            context_embedding_dim=2048,
            use_bias=False,
            use_linear_projection=True,
            # Per-stage settings.
            channels=channels,
            num_attention_layers=num_attention_layers,
            num_attention_heads=num_attention_heads,
            device=device,
            dtype=dtype,
        )
class DownBlocks(fl.Chain):
    """Down path of the SDXL UNet: an input convolution followed by eight blocks.

    Layout (9 child chains in total):

    * input Conv2d (``in_channels`` -> 320, 3x3, padding 1);
    * 320 channels: two ResidualBlocks, then a Downsample (scale_factor=2);
    * 640 channels: two ResidualBlock + cross-attention (2 layers, 10 heads), then a Downsample;
    * 1280 channels: two ResidualBlock + cross-attention (10 layers, 20 heads).
    """

    def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.in_channels = in_channels

        # Small factories so each block reads as its channel sizes only; every
        # module is still created in the same order as an explicit listing.
        def residual(channels_in: int, channels_out: int):
            return ResidualBlock(in_channels=channels_in, out_channels=channels_out, device=device, dtype=dtype)

        def attention(channels: int, layers: int, heads: int):
            return SDXLCrossAttention(
                channels=channels, num_attention_layers=layers, num_attention_heads=heads, device=device, dtype=dtype
            )

        def downsample(channels: int):
            return fl.Downsample(channels=channels, scale_factor=2, padding=1, device=device, dtype=dtype)

        in_block = fl.Chain(
            fl.Conv2d(in_channels=in_channels, out_channels=320, kernel_size=3, padding=1, device=device, dtype=dtype)
        )
        # 320 channels: no attention at this stage.
        stage_320 = [
            fl.Chain(residual(320, 320)),
            fl.Chain(residual(320, 320)),
            fl.Chain(downsample(320)),
        ]
        stage_640 = [
            fl.Chain(residual(320, 640), attention(640, 2, 10)),
            fl.Chain(residual(640, 640), attention(640, 2, 10)),
            fl.Chain(downsample(640)),
        ]
        # 1280 channels: deepest stage, no downsample afterwards.
        stage_1280 = [
            fl.Chain(residual(640, 1280), attention(1280, 10, 20)),
            fl.Chain(residual(1280, 1280), attention(1280, 10, 20)),
        ]
        super().__init__(in_block, *stage_320, *stage_640, *stage_1280)
class UpBlocks(fl.Chain):
    """Up path of the SDXL UNet: nine blocks going from 1280 back down to 320 channels.

    Each block starts with a ResidualBlock whose input width is larger than the
    running channel count. SDXLUNet inserts a ResidualConcatenator at the front
    of each of these blocks, so the extra width is presumably the concatenated
    skip connection — confirm.

    * 1280 channels: three ResidualBlock + cross-attention (10 layers, 20 heads),
      with an Upsample after the third;
    * 640 channels: three ResidualBlock + cross-attention (2 layers, 10 heads),
      with an Upsample after the third;
    * 320 channels: three ResidualBlocks, no attention.
    """

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        # Factories keep the listing compact; modules are still constructed in
        # exactly the same order as an explicit listing would create them.
        def residual(channels_in: int, channels_out: int):
            return ResidualBlock(in_channels=channels_in, out_channels=channels_out, device=device, dtype=dtype)

        def attention(channels: int, layers: int, heads: int):
            return SDXLCrossAttention(
                channels=channels, num_attention_layers=layers, num_attention_heads=heads, device=device, dtype=dtype
            )

        def upsample(channels: int):
            return fl.Upsample(channels=channels, device=device, dtype=dtype)

        stage_1280 = [
            fl.Chain(residual(2560, 1280), attention(1280, 10, 20)),
            fl.Chain(residual(2560, 1280), attention(1280, 10, 20)),
            fl.Chain(residual(1920, 1280), attention(1280, 10, 20), upsample(1280)),
        ]
        stage_640 = [
            fl.Chain(residual(1920, 640), attention(640, 2, 10)),
            fl.Chain(residual(1280, 640), attention(640, 2, 10)),
            fl.Chain(residual(960, 640), attention(640, 2, 10), upsample(640)),
        ]
        # 320 channels: no attention at this stage.
        stage_320 = [
            fl.Chain(residual(960, 320)),
            fl.Chain(residual(640, 320)),
            fl.Chain(residual(640, 320)),
        ]
        super().__init__(*stage_1280, *stage_640, *stage_320)
class MiddleBlock(fl.Chain):
    """Bottleneck of the SDXL UNet at 1280 channels.

    Runs ResidualBlock -> cross-attention (10 layers, 20 heads) -> ResidualBlock.
    """

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        width = 1280
        # Built in forward order, which is also the original construction order.
        entry_residual = ResidualBlock(in_channels=width, out_channels=width, device=device, dtype=dtype)
        cross_attention = SDXLCrossAttention(
            channels=width, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
        )
        exit_residual = ResidualBlock(in_channels=width, out_channels=width, device=device, dtype=dtype)
        super().__init__(entry_residual, cross_attention, exit_residual)
class OutputBlock(fl.Chain):
    """Final head of the SDXL UNet: GroupNorm -> SiLU -> 3x3 Conv2d mapping 320 channels to 4."""

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        normalization = fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype)
        activation = fl.SiLU()
        projection = fl.Conv2d(
            in_channels=320,
            out_channels=4,
            kernel_size=3,
            stride=1,
            padding=1,
            device=device,
            dtype=dtype,
        )
        super().__init__(normalization, activation, projection)
class SDXLUNet(fl.Chain):
    """Stable Diffusion XL UNet assembled from fluxion chains.

    Stages, in order: TimestepEncoder, DownBlocks, MiddleBlock, a residual add of
    the last "unet" residuals entry, UpBlocks, OutputBlock.

    Conditioning is not passed as call arguments. It is written into contexts
    through the ``set_*`` methods (text embedding, timestep, time ids, pooled
    text embedding) before the model is called.
    """

    def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.in_channels = in_channels

        # Stages are created in forward order, matching the original construction order.
        timestep_encoder = TimestepEncoder(device=device, dtype=dtype)
        down_blocks = DownBlocks(in_channels=in_channels, device=device, dtype=dtype)
        middle_block = MiddleBlock(device=device, dtype=dtype)
        # Adds residuals[-1] from the "unet" context to the middle-block output.
        # init_context initialises every slot to 0.0.
        middle_residual = fl.Residual(
            fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[-1])
        )
        up_blocks = UpBlocks(device=device, dtype=dtype)
        output_block = OutputBlock(device=device, dtype=dtype)
        super().__init__(timestep_encoder, down_blocks, middle_block, middle_residual, up_blocks, output_block)

        # Wrap Conv2d_1 of every ResidualBlock with a RangeAdapter2d keyed on
        # "timestep_embedding" (the key TimestepEncoder writes). The generator is
        # consumed lazily while injecting, exactly as before.
        for res_block in self.layers(ResidualBlock):
            inner_chain = res_block.Chain
            adapter = RangeAdapter2d(
                target=inner_chain.Conv2d_1,
                channels=res_block.out_channels,
                embedding_dim=1280,
                context_key="timestep_embedding",
                device=device,
                dtype=dtype,
            )
            adapter.inject(inner_chain)

        # Skip connections: down block i gets a trailing ResidualAccumulator(n=i),
        # and up block i gets a leading ResidualConcatenator(n=-i - 2). With 9 down
        # blocks and a 10-slot residuals list, slot -1 is left for the middle residual.
        # NOTE(review): the accumulator/concatenator presumably write/read the
        # "unet" residuals list — confirm in stable_diffusion_1.unet.
        for position, down_block in enumerate(cast(list[fl.Chain], self.DownBlocks)):
            down_block.append(module=ResidualAccumulator(n=position))
        for position, up_block in enumerate(cast(list[fl.Chain], self.UpBlocks)):
            up_block.insert(index=0, module=ResidualConcatenator(n=-position - 2))

    def init_context(self) -> Contexts:
        """Return fresh default values for every context this UNet reads or writes."""
        # A new dict (and new lists) on every call, so contexts are never shared.
        unet_defaults = {"residuals": [0.0] * 10}
        diffusion_defaults = {"timestep": None, "time_ids": None, "pooled_text_embedding": None}
        return {
            "unet": unet_defaults,
            "diffusion": diffusion_defaults,
            "range_adapter": {"timestep_embedding": None},
            "sampling": {"shapes": []},
        }

    def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
        """Store the CLIP text embedding used as cross-attention context."""
        self.set_context(context="cross_attention_block", value={"clip_text_embedding": clip_text_embedding})

    def set_timestep(self, timestep: Tensor) -> None:
        """Store the diffusion timestep in the "diffusion" context."""
        self.set_context(context="diffusion", value={"timestep": timestep})

    def set_time_ids(self, time_ids: Tensor) -> None:
        """Store the SDXL time ids in the "diffusion" context."""
        self.set_context(context="diffusion", value={"time_ids": time_ids})

    def set_pooled_text_embedding(self, pooled_text_embedding: Tensor) -> None:
        """Store the pooled text embedding in the "diffusion" context."""
        self.set_context(context="diffusion", value={"pooled_text_embedding": pooled_text_embedding})