build: vendorize refiners

so we can still work in conda envs
pull/438/head
Authored by Bryce 4 months ago, committed by Bryce Drennan
parent f84406f12c
commit 55e27160f5

@@ -201,6 +201,19 @@ vendorize_normal_map:
make af
vendorize_refiners:
export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=20c229903f53d05dc1c44659ec97603660ef964c && \
make download_repo REPO=$$REPO PKG=$$PKG COMMIT=$$COMMIT && \
mkdir -p ./imaginairy/vendored/$$PKG && \
rm -rf ./imaginairy/vendored/$$PKG/* && \
cp -R ./downloads/refiners/src/refiners/* ./imaginairy/vendored/$$PKG/ && \
cp ./downloads/refiners/LICENSE ./imaginairy/vendored/$$PKG/ && \
rm -rf ./imaginairy/vendored/$$PKG/training_utils && \
echo "vendored from $$REPO @ $$COMMIT" | tee ./imaginairy/vendored/$$PKG/readme.txt
find ./imaginairy/vendored/refiners/ -type f -name "*.py" -exec sed -i '' 's/from refiners/from imaginairy.vendored.refiners/g' {} + &&\
find ./imaginairy/vendored/refiners/ -type f -name "*.py" -exec sed -i '' 's/import refiners/import imaginairy.vendored.refiners/g' {} + &&\
make af
vendorize: ## vendorize a github repo. `make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip`
mkdir -p ./downloads

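For clarity, here is a rough Python equivalent of the two `sed` invocations in the target above (a hypothetical helper, not part of this commit): it rewrites absolute `refiners` imports so the vendored copy only references `imaginairy.vendored.refiners`. The real Makefile uses plain substring substitution; the word boundaries below are added only as a safety margin.

```python
import re
from pathlib import Path

def rewrite_vendored_imports(root: str = "./imaginairy/vendored/refiners") -> None:
    """Rewrite `refiners` imports in every vendored .py file (sketch of the sed commands)."""
    for path in Path(root).rglob("*.py"):
        text = path.read_text()
        text = re.sub(r"\bfrom refiners\b", "from imaginairy.vendored.refiners", text)
        text = re.sub(r"\bimport refiners\b", "import imaginairy.vendored.refiners", text)
        path.write_text(text)
```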
@@ -12,17 +12,17 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _generate_single_image_compvis(
def _generate_single_image(
prompt: "ImaginePrompt",
debug_img_callback=None,
progress_img_callback=None,
progress_img_interval_steps=3,
progress_img_interval_min_s=0.1,
half_mode=None,
add_caption=False,
# controlnet, finetune, naive, auto
inpaint_method="finetune",
return_latent=False,
dtype=None,
):
import torch.nn
from PIL import Image, ImageOps
@@ -96,7 +96,7 @@ def _generate_single_image_compvis(
weights_location=prompt.model_weights,
config_path=prompt.model_architecture,
control_weights_locations=control_modes,
half_mode=half_mode,
half_mode=dtype == torch.float16,
for_inpainting=for_inpainting and inpaint_method == "finetune",
)
is_controlnet_model = hasattr(model, "control_key")
@@ -502,7 +502,6 @@ def _generate_composition_image(
):
from PIL import Image
from imaginairy.api.generate_refiners import generate_single_image
from imaginairy.utils import default, get_default_dtype
cutoff = normalize_image_size(cutoff)
@@ -532,7 +531,7 @@ def _generate_composition_image(
},
)
result = generate_single_image(composition_prompt, dtype=dtype)
result = _generate_single_image(composition_prompt, dtype=dtype)
img = result.images["generated"]
while img.width < target_width:
from imaginairy.enhancers.upscale_realesrgan import upscale_image

@@ -27,7 +27,6 @@ def generate_single_image(
):
import torch.nn
from PIL import Image, ImageOps
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DPMSolver
from tqdm import tqdm
from imaginairy.api.generate import (
@@ -61,6 +60,10 @@ def generate_single_image(
prepare_image_for_outpaint,
)
from imaginairy.utils.safety import create_safety_score
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers import (
DDIM,
DPMSolver,
)
if dtype is None:
dtype = torch.float16
@@ -513,7 +516,9 @@ def prep_control_input(
if not control_config:
msg = f"Unknown control mode: {control_input.mode}"
raise ValueError(msg)
from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
SD1ControlnetAdapter,
)
controlnet = SD1ControlnetAdapter( # type: ignore
name=control_input.mode,

@@ -5,42 +5,52 @@ import math
from functools import lru_cache
from typing import Any, List, Literal
import refiners.fluxion.layers as fl
import torch
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.layers.chain import ChainError
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.model import (
from torch import Tensor, device as Device, dtype as DType, nn
from torch.nn import functional as F
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.schema import WeightedPrompt
from imaginairy.utils.feather_tile import rebuild_image, tile_image
from imaginairy.vendored.refiners.fluxion.layers.attentions import (
ScaledDotProductAttention,
)
from imaginairy.vendored.refiners.fluxion.layers.chain import ChainError
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
CLIPTextEncoderL,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import (
TLatentDiffusionModel,
)
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.self_attention_guidance import (
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import (
DDIM,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import (
Scheduler,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.self_attention_guidance import (
SelfAttentionMap,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
Controlnet,
SD1ControlnetAdapter,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
SD1Autoencoder,
SD1UNet,
StableDiffusion_1 as RefinerStableDiffusion_1,
StableDiffusion_1_Inpainting as RefinerStableDiffusion_1_Inpainting,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import (
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import (
SDXLAutoencoder,
StableDiffusion_XL as RefinerStableDiffusion_XL,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import (
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import (
DoubleTextEncoder,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from torch import Tensor, device as Device, dtype as DType, nn
from torch.nn import functional as F
from imaginairy.schema import WeightedPrompt
from imaginairy.utils.feather_tile import rebuild_image, tile_image
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import (
SDXLUNet,
)
from imaginairy.weight_management.conversion import cast_weights
logger = logging.getLogger(__name__)
@@ -375,8 +385,8 @@ class StableDiffusion_1_Inpainting(TileModeMixin, RefinerStableDiffusion_1_Inpai
import torch
total_weight = sum(wp.weight for wp in prompts)
if str(self.clip_text_encoder.device) == "cpu":
self.clip_text_encoder = self.clip_text_encoder.to(dtype=torch.float32)
if str(self.clip_text_encoder.device) == "cpu": # type: ignore
self.clip_text_encoder = self.clip_text_encoder.to(dtype=torch.float32) # type: ignore
conditioning = sum(
self.clip_text_encoder(wp.text) * (wp.weight / total_weight)
for wp in prompts

@@ -16,9 +16,6 @@ from huggingface_hub import (
try_to_load_from_cache,
)
from omegaconf import OmegaConf
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion import DoubleTextEncoder, SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from safetensors.torch import load_file
from imaginairy import config as iconfig
@@ -29,6 +26,17 @@ from imaginairy.utils import clear_gpu_cache, get_device, instantiate_from_confi
from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.utils.named_resolutions import normalize_image_size
from imaginairy.utils.paths import PKG_ROOT
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
CLIPTextEncoderL,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
DoubleTextEncoder,
SD1UNet,
SDXLUNet,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import (
LatentDiffusionModel,
)
from imaginairy.weight_management import translators
logger = logging.getLogger(__name__)
@@ -823,7 +831,7 @@ def open_weights(filepath, device=None):
device = get_device()
if "safetensor" in filepath.lower():
from refiners.fluxion.utils import safe_open
from imaginairy.vendored.refiners.fluxion.utils import safe_open
with safe_open(path=filepath, framework="pytorch", device=device) as tensors:
state_dict = {

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Lagon Technologies
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@@ -0,0 +1,3 @@
from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors, manual_seed, norm, pad, save_to_safetensors
__all__ = ["norm", "manual_seed", "save_to_safetensors", "load_from_safetensors", "pad"]

@@ -0,0 +1,3 @@
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
__all__ = ["Adapter"]

@@ -0,0 +1,101 @@
import contextlib
from typing import Any, Generic, Iterator, TypeVar
import imaginairy.vendored.refiners.fluxion.layers as fl
T = TypeVar("T", bound=fl.Module)
TAdapter = TypeVar("TAdapter", bound="Adapter[Any]") # Self (see PEP 673)
class Adapter(Generic[T]):
# we store _target in a one-element list to avoid pytorch registering it as a submodule
_target: "list[T]"
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
assert issubclass(cls, fl.Chain), f"Adapter {cls.__name__} must be a Chain"
@property
def target(self) -> T:
return self._target[0]
@contextlib.contextmanager
def setup_adapter(self, target: T) -> Iterator[None]:
assert isinstance(self, fl.Chain)
assert (not hasattr(self, "_modules")) or (
len(self) == 0
), "Call the Chain constructor in the setup_adapter context."
self._target = [target]
if not isinstance(self.target, fl.ContextModule):
yield
return
_old_can_refresh_parent = target._can_refresh_parent
target._can_refresh_parent = False
yield
target._can_refresh_parent = _old_can_refresh_parent
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
assert isinstance(self, fl.Chain)
if (parent is None) and isinstance(self.target, fl.ContextModule):
parent = self.target.parent
if parent is not None:
assert isinstance(parent, fl.Chain), f"{self.target} has invalid parent {parent}"
target_parent = self.find_parent(self.target)
if parent is None:
if isinstance(self.target, fl.ContextModule):
self.target._set_parent(target_parent) # type: ignore[reportPrivateUsage]
return self
# In general, `true_parent` is `parent`. We do this to support multiple adaptation,
# i.e. initializing two adapters before injecting them.
true_parent = parent.ensure_find_parent(self.target)
true_parent.replace(
old_module=self.target,
new_module=self,
old_module_parent=target_parent,
)
return self
def eject(self) -> None:
assert isinstance(self, fl.Chain)
# In general, the "actual target" is the target.
# Here we deal with the edge case where the target
# is part of the replacement block and has been adapted by
# another adapter after this one. For instance, this is the
# case when stacking Controlnets.
actual_target = lookup_top_adapter(self, self.target)
if (parent := self.parent) is None:
if isinstance(actual_target, fl.ContextModule):
actual_target._set_parent(None) # type: ignore[reportPrivateUsage]
else:
parent.replace(old_module=self, new_module=actual_target)
def _pre_structural_copy(self) -> None:
if isinstance(self.target, fl.Chain):
raise RuntimeError("Chain adapters typically cannot be copied, eject them first.")
def _post_structural_copy(self: TAdapter, source: TAdapter) -> None:
self._target = [source.target]
def lookup_top_adapter(top: fl.Chain, target: fl.Module) -> fl.Module:
"""Lookup and return last adapter in parents tree (or target if none)."""
target_parent = top.find_parent(target)
if (target_parent is None) or (target_parent == top):
return target
r, p = target, target_parent
while p != top:
if isinstance(p, Adapter):
r = p
assert p.parent, f"parent tree of {top} is broken"
p = p.parent
return r

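A minimal usage sketch of the Adapter pattern above. The `PassthroughAdapter` class and the toy chain are hypothetical, for illustration only: an adapter is built around a target, swapped into the target's parent with `inject`, and removed again with `eject`.

```python
import torch
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter

# Hypothetical adapter: wraps a Linear and adds a skip connection around it.
class PassthroughAdapter(fl.Sum, Adapter[fl.Linear]):
    def __init__(self, target: fl.Linear) -> None:
        with self.setup_adapter(target):
            super().__init__(target, fl.Identity())

chain = fl.Chain(fl.Linear(in_features=8, out_features=8))
adapter = PassthroughAdapter(chain.Linear)
adapter.inject(chain)         # replaces the Linear inside `chain` with the adapter
y = chain(torch.randn(2, 8))  # now computes Linear(x) + x
adapter.eject()               # puts the original Linear back
```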
@@ -0,0 +1,130 @@
from typing import Any, Generic, Iterable, TypeVar
from torch import Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter
from torch.nn.init import normal_, zeros_
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
T = TypeVar("T", bound=fl.Chain)
TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PEP 673)
class Lora(fl.Chain):
def __init__(
self,
in_features: int,
out_features: int,
rank: int = 16,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_features = in_features
self.out_features = out_features
self.rank = rank
self.scale: float = 1.0
super().__init__(
fl.Linear(in_features=in_features, out_features=rank, bias=False, device=device, dtype=dtype),
fl.Linear(in_features=rank, out_features=out_features, bias=False, device=device, dtype=dtype),
fl.Lambda(func=self.scale_outputs),
)
normal_(tensor=self.Linear_1.weight, std=1 / self.rank)
zeros_(tensor=self.Linear_2.weight)
def scale_outputs(self, x: Tensor) -> Tensor:
return x * self.scale
def set_scale(self, scale: float) -> None:
self.scale = scale
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
@property
def up_weight(self) -> Tensor:
return self.Linear_2.weight.data
@property
def down_weight(self) -> Tensor:
return self.Linear_1.weight.data
class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
def __init__(
self,
target: fl.Linear,
rank: int = 16,
scale: float = 1.0,
) -> None:
self.in_features = target.in_features
self.out_features = target.out_features
self.rank = rank
self.scale = scale
with self.setup_adapter(target):
super().__init__(
target,
Lora(
in_features=target.in_features,
out_features=target.out_features,
rank=rank,
device=target.device,
dtype=target.dtype,
),
)
self.Lora.set_scale(scale=scale)
class LoraAdapter(Generic[T], fl.Chain, Adapter[T]):
def __init__(
self,
target: T,
sub_targets: Iterable[tuple[fl.Linear, fl.Chain]],
rank: int | None = None,
scale: float = 1.0,
weights: list[Tensor] | None = None,
) -> None:
with self.setup_adapter(target):
super().__init__(target)
if weights is not None:
assert len(weights) % 2 == 0
weights_rank = weights[0].shape[1]
if rank is None:
rank = weights_rank
else:
assert rank == weights_rank
assert rank is not None, "either pass a rank or weights"
self.sub_targets = sub_targets
self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = []
for linear, parent in self.sub_targets:
self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent))
if weights is not None:
assert len(self.sub_adapters) == (len(weights) // 2)
for i, (adapter, _) in enumerate(self.sub_adapters):
lora = adapter.Lora
assert (
lora.rank == weights[i * 2].shape[1]
), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter:
for adapter, adapter_parent in self.sub_adapters:
adapter.inject(adapter_parent)
return super().inject(parent)
def eject(self) -> None:
for adapter, _ in self.sub_adapters:
adapter.eject()
super().eject()
@property
def weights(self) -> list[Tensor]:
return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]]

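A usage sketch of `SingleLoraAdapter`, assuming a toy single-Linear chain: because the up-projection weight is zero-initialized, injecting the adapter leaves the model's output unchanged until LoRA weights are trained or loaded.

```python
import torch
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.lora import SingleLoraAdapter

chain = fl.Chain(fl.Linear(in_features=16, out_features=16))
x = torch.randn(2, 16)
before = chain(x)

adapter = SingleLoraAdapter(target=chain.Linear, rank=4, scale=1.0)
adapter.inject(chain)                  # the Linear is now target + low-rank branch
after = chain(x)
assert torch.allclose(before, after)   # zero-initialized up-projection: behaviour unchanged

adapter.eject()                        # restores the original Linear in place
```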
@@ -0,0 +1,53 @@
from typing import Any
from torch import Tensor
Context = dict[str, Any]
Contexts = dict[str, Context]
class ContextProvider:
def __init__(self) -> None:
self.contexts: Contexts = {}
def set_context(self, key: str, value: Context) -> None:
self.contexts[key] = value
def get_context(self, key: str) -> Any:
return self.contexts.get(key)
def update_contexts(self, new_contexts: Contexts) -> None:
for key, value in new_contexts.items():
if key not in self.contexts:
self.contexts[key] = value
else:
self.contexts[key].update(value)
@staticmethod
def create(contexts: Contexts) -> "ContextProvider":
provider = ContextProvider()
provider.update_contexts(contexts)
return provider
def __add__(self, other: "ContextProvider") -> "ContextProvider":
self.contexts.update(other.contexts)
return self
def __lshift__(self, other: "ContextProvider") -> "ContextProvider":
other.contexts.update(self.contexts)
return other
def __bool__(self) -> bool:
return bool(self.contexts)
def _get_repr_for_value(self, value: Any) -> str:
if isinstance(value, Tensor):
return f"Tensor(shape={value.shape}, dtype={value.dtype}, device={value.device})"
return repr(value)
def _get_repr_for_dict(self, context_dict: Context) -> dict[str, str]:
return {key: self._get_repr_for_value(value) for key, value in context_dict.items()}
def __repr__(self) -> str:
contexts_repr = {key: self._get_repr_for_dict(value) for key, value in self.contexts.items()}
return f"{self.__class__.__name__}(contexts={contexts_repr})"

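A small sketch of `ContextProvider` behaviour (illustrative keys and values only): `update_contexts` merges per key, while `set_context` replaces a context wholesale.

```python
from imaginairy.vendored.refiners.fluxion.context import ContextProvider

provider = ContextProvider.create({"sampling": {"step": 0}})
provider.update_contexts({"sampling": {"total_steps": 30}})  # merged into the existing dict
print(provider.get_context("sampling"))                      # {'step': 0, 'total_steps': 30}

provider.set_context("sampling", {"step": 5})                # replaces the context wholesale
print(provider.get_context("sampling"))                      # {'step': 5}
```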
@@ -0,0 +1,110 @@
from imaginairy.vendored.refiners.fluxion.layers.activations import GLU, ApproximateGeLU, GeLU, ReLU, Sigmoid, SiLU
from imaginairy.vendored.refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d
from imaginairy.vendored.refiners.fluxion.layers.basics import (
Buffer,
Chunk,
Cos,
Flatten,
GetArg,
Identity,
Multiply,
Parameter,
Permute,
Reshape,
Sin,
Slicing,
Squeeze,
Transpose,
Unbind,
Unflatten,
Unsqueeze,
View,
)
from imaginairy.vendored.refiners.fluxion.layers.chain import (
Breakpoint,
Chain,
Concatenate,
Distribute,
Lambda,
Matmul,
Parallel,
Passthrough,
Residual,
Return,
SetContext,
Sum,
UseContext,
)
from imaginairy.vendored.refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d
from imaginairy.vendored.refiners.fluxion.layers.converter import Converter
from imaginairy.vendored.refiners.fluxion.layers.embedding import Embedding
from imaginairy.vendored.refiners.fluxion.layers.linear import Linear, MultiLinear
from imaginairy.vendored.refiners.fluxion.layers.maxpool import MaxPool1d, MaxPool2d
from imaginairy.vendored.refiners.fluxion.layers.module import ContextModule, Module, WeightedModule
from imaginairy.vendored.refiners.fluxion.layers.norm import GroupNorm, InstanceNorm2d, LayerNorm, LayerNorm2d
from imaginairy.vendored.refiners.fluxion.layers.padding import ReflectionPad2d
from imaginairy.vendored.refiners.fluxion.layers.pixelshuffle import PixelUnshuffle
from imaginairy.vendored.refiners.fluxion.layers.sampling import Downsample, Interpolate, Upsample
__all__ = [
"Embedding",
"LayerNorm",
"GroupNorm",
"LayerNorm2d",
"InstanceNorm2d",
"GeLU",
"GLU",
"SiLU",
"ReLU",
"ApproximateGeLU",
"Sigmoid",
"Attention",
"SelfAttention",
"SelfAttention2d",
"Identity",
"GetArg",
"View",
"Flatten",
"Unflatten",
"Transpose",
"Permute",
"Squeeze",
"Unsqueeze",
"Reshape",
"Slicing",
"Parameter",
"Sin",
"Cos",
"Chunk",
"Multiply",
"Unbind",
"Matmul",
"Buffer",
"Lambda",
"Return",
"Sum",
"Residual",
"Chain",
"UseContext",
"SetContext",
"Parallel",
"Distribute",
"Passthrough",
"Breakpoint",
"Concatenate",
"Conv2d",
"ConvTranspose2d",
"Linear",
"MultiLinear",
"Downsample",
"Upsample",
"Module",
"WeightedModule",
"ContextModule",
"Interpolate",
"ReflectionPad2d",
"PixelUnshuffle",
"Converter",
"MaxPool1d",
"MaxPool2d",
]

@@ -0,0 +1,77 @@
from torch import Tensor, sigmoid
from torch.nn.functional import (
gelu, # type: ignore
silu,
)
from imaginairy.vendored.refiners.fluxion.layers.module import Module
class Activation(Module):
def __init__(self) -> None:
super().__init__()
class SiLU(Activation):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
return silu(x) # type: ignore
class ReLU(Activation):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
return x.relu()
class GeLU(Activation):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
return gelu(x) # type: ignore
class ApproximateGeLU(Activation):
"""
The approximate form of Gaussian Error Linear Unit (GELU)
For more details, see section 2: https://arxiv.org/abs/1606.08415
"""
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
return x * sigmoid(1.702 * x)
class Sigmoid(Activation):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
return x.sigmoid()
class GLU(Activation):
"""
Gated Linear Unit activation layer.
See https://arxiv.org/abs/2002.05202v1 for details.
"""
def __init__(self, activation: Activation) -> None:
super().__init__()
self.activation = activation
def __repr__(self):
return f"{self.__class__.__name__}(activation={self.activation})"
def forward(self, x: Tensor) -> Tensor:
assert x.shape[-1] % 2 == 0, "Non-batch input dimension must be divisible by 2"
output, gate = x.chunk(2, dim=-1)
return output * self.activation(gate)

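A quick sketch of the `GLU` layer above: the last dimension is split into an output half and a gate half, so the feature dimension is halved.

```python
import torch
import imaginairy.vendored.refiners.fluxion.layers as fl

glu = fl.GLU(activation=fl.GeLU())
x = torch.randn(2, 8)          # last dim must be even: it is split into (output, gate)
assert glu(x).shape == (2, 4)  # output * GeLU(gate)
```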
@@ -0,0 +1,246 @@
import math
import torch
from jaxtyping import Float
from torch import Tensor, device as Device, dtype as DType
from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention # type: ignore
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.layers.basics import Identity
from imaginairy.vendored.refiners.fluxion.layers.chain import Chain, Distribute, Lambda, Parallel
from imaginairy.vendored.refiners.fluxion.layers.linear import Linear
from imaginairy.vendored.refiners.fluxion.layers.module import Module
def scaled_dot_product_attention(
query: Float[Tensor, "batch source_sequence_length dim"],
key: Float[Tensor, "batch target_sequence_length dim"],
value: Float[Tensor, "batch target_sequence_length dim"],
is_causal: bool = False,
) -> Float[Tensor, "batch source_sequence_length dim"]:
return _scaled_dot_product_attention(query, key, value, is_causal=is_causal) # type: ignore
def sparse_dot_product_attention_non_optimized(
query: Float[Tensor, "batch source_sequence_length dim"],
key: Float[Tensor, "batch target_sequence_length dim"],
value: Float[Tensor, "batch target_sequence_length dim"],
is_causal: bool = False,
) -> Float[Tensor, "batch source_sequence_length dim"]:
if is_causal:
# TODO: implement causal attention
raise NotImplementedError("Causal attention for non_optimized attention is not yet implemented")
_, _, _, dim = query.shape
attention = query @ key.permute(0, 1, 3, 2)
attention = attention / math.sqrt(dim)
attention = torch.softmax(input=attention, dim=-1)
return attention @ value
class ScaledDotProductAttention(Module):
def __init__(
self,
num_heads: int = 1,
is_causal: bool | None = None,
is_optimized: bool = True,
slice_size: int | None = None,
) -> None:
super().__init__()
self.num_heads = num_heads
self.is_causal = is_causal
self.is_optimized = is_optimized
self.slice_size = slice_size
self.dot_product = (
scaled_dot_product_attention if self.is_optimized else sparse_dot_product_attention_non_optimized
)
def forward(
self,
query: Float[Tensor, "batch num_queries embedding_dim"],
key: Float[Tensor, "batch num_keys embedding_dim"],
value: Float[Tensor, "batch num_values embedding_dim"],
is_causal: bool | None = None,
) -> Float[Tensor, "batch num_queries dim"]:
if self.slice_size is None:
return self._process_attention(query, key, value, is_causal)
return self._sliced_attention(query, key, value, is_causal=is_causal, slice_size=self.slice_size)
def _sliced_attention(
self,
query: Float[Tensor, "batch num_queries embedding_dim"],
key: Float[Tensor, "batch num_keys embedding_dim"],
value: Float[Tensor, "batch num_values embedding_dim"],
slice_size: int,
is_causal: bool | None = None,
) -> Float[Tensor, "batch num_queries dim"]:
_, num_queries, _ = query.shape
output = torch.zeros_like(query)
for start_idx in range(0, num_queries, slice_size):
end_idx = min(start_idx + slice_size, num_queries)
output[:, start_idx:end_idx, :] = self._process_attention(
query[:, start_idx:end_idx, :], key, value, is_causal
)
return output
def _process_attention(
self,
query: Float[Tensor, "batch num_queries embedding_dim"],
key: Float[Tensor, "batch num_keys embedding_dim"],
value: Float[Tensor, "batch num_values embedding_dim"],
is_causal: bool | None = None,
) -> Float[Tensor, "batch num_queries dim"]:
return self.merge_multi_head(
x=self.dot_product(
query=self.split_to_multi_head(query),
key=self.split_to_multi_head(key),
value=self.split_to_multi_head(value),
is_causal=(
is_causal if is_causal is not None else (self.is_causal if self.is_causal is not None else False)
),
)
)
def split_to_multi_head(
self, x: Float[Tensor, "batch_size sequence_length embedding_dim"]
) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]:
assert (
len(x.shape) == 3
), f"Expected tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}"
assert (
x.shape[-1] % self.num_heads == 0
), f"Embedding dim (x.shape[-1]={x.shape[-1]}) must be divisible by num heads"
return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2)
def merge_multi_head(
self, x: Float[Tensor, "batch_size num_heads sequence_length heads_dim"]
) -> Float[Tensor, "batch_size sequence_length heads_dim * num_heads"]:
return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], self.num_heads * x.shape[-1])
class Attention(Chain):
def __init__(
self,
embedding_dim: int,
num_heads: int = 1,
key_embedding_dim: int | None = None,
value_embedding_dim: int | None = None,
inner_dim: int | None = None,
use_bias: bool = True,
is_causal: bool | None = None,
is_optimized: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
assert (
embedding_dim % num_heads == 0
), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.heads_dim = embedding_dim // num_heads
self.key_embedding_dim = key_embedding_dim or embedding_dim
self.value_embedding_dim = value_embedding_dim or embedding_dim
self.inner_dim = inner_dim or embedding_dim
self.use_bias = use_bias
self.is_causal = is_causal
self.is_optimized = is_optimized
super().__init__(
Distribute(
Linear(
in_features=self.embedding_dim,
out_features=self.inner_dim,
bias=self.use_bias,
device=device,
dtype=dtype,
),
Linear(
in_features=self.key_embedding_dim,
out_features=self.inner_dim,
bias=self.use_bias,
device=device,
dtype=dtype,
),
Linear(
in_features=self.value_embedding_dim,
out_features=self.inner_dim,
bias=self.use_bias,
device=device,
dtype=dtype,
),
),
ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal, is_optimized=is_optimized),
Linear(
in_features=self.inner_dim,
out_features=self.embedding_dim,
bias=True,
device=device,
dtype=dtype,
),
)
class SelfAttention(Attention):
def __init__(
self,
embedding_dim: int,
inner_dim: int | None = None,
num_heads: int = 1,
use_bias: bool = True,
is_causal: bool | None = None,
is_optimized: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
embedding_dim=embedding_dim,
inner_dim=inner_dim,
num_heads=num_heads,
use_bias=use_bias,
is_causal=is_causal,
is_optimized=is_optimized,
device=device,
dtype=dtype,
)
self.insert(0, Parallel(Identity(), Identity(), Identity()))
class SelfAttention2d(SelfAttention):
def __init__(
self,
channels: int,
num_heads: int = 1,
use_bias: bool = True,
is_causal: bool | None = None,
is_optimized: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
assert channels % num_heads == 0, f"channels {channels} must be divisible by num_heads {num_heads}"
self.channels = channels
super().__init__(
embedding_dim=channels,
num_heads=num_heads,
use_bias=use_bias,
is_causal=is_causal,
is_optimized=is_optimized,
device=device,
dtype=dtype,
)
self.insert(0, Lambda(self.tensor_2d_to_sequence))
self.append(Lambda(self.sequence_to_tensor_2d))
def init_context(self) -> Contexts:
return {"reshape": {"height": None, "width": None}}
def tensor_2d_to_sequence(
self, x: Float[Tensor, "batch channels height width"]
) -> Float[Tensor, "batch height*width channels"]:
height, width = x.shape[-2:]
self.set_context(context="reshape", value={"height": height, "width": width})
return x.reshape(x.shape[0], x.shape[1], height * width).transpose(1, 2)
def sequence_to_tensor_2d(
self, x: Float[Tensor, "batch sequence_length channels"]
) -> Float[Tensor, "batch channels height width"]:
height, width = self.use_context("reshape").values()
return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], height, width)

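A usage sketch of `SelfAttention2d` (toy shapes, for illustration): the layer flattens a feature map into a sequence, runs multi-head self-attention, then reshapes back using the `reshape` context it sets on itself.

```python
import torch
from imaginairy.vendored.refiners.fluxion.layers.attentions import SelfAttention2d

attn = SelfAttention2d(channels=64, num_heads=8)
feature_map = torch.randn(1, 64, 32, 32)
out = attn(feature_map)                 # (B, C, H, W) -> sequence -> attention -> (B, C, H, W)
assert out.shape == feature_map.shape
```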
@@ -0,0 +1,207 @@
import torch
from torch import Size, Tensor, device as Device, dtype as DType, randn
from torch.nn import Parameter as TorchParameter
from imaginairy.vendored.refiners.fluxion.layers.module import Module, WeightedModule
class Identity(Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
return x
class View(Module):
def __init__(self, *shape: int) -> None:
super().__init__()
self.shape = shape
def forward(self, x: Tensor) -> Tensor:
return x.view(*self.shape)
class GetArg(Module):
def __init__(self, index: int) -> None:
super().__init__()
self.index = index
def forward(self, *args: Tensor) -> Tensor:
return args[self.index]
class Flatten(Module):
def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
super().__init__()
self.start_dim = start_dim
self.end_dim = end_dim
def forward(self, x: Tensor) -> Tensor:
return x.flatten(self.start_dim, self.end_dim)
class Unflatten(Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
def forward(self, x: Tensor, sizes: Size) -> Tensor:
return x.unflatten(self.dim, sizes) # type: ignore
class Reshape(Module):
"""
Reshape the input tensor to the given shape. The shape must be compatible with the input tensor shape. The batch
dimension is preserved.
"""
def __init__(self, *shape: int) -> None:
super().__init__()
self.shape = shape
def forward(self, x: Tensor) -> Tensor:
return x.reshape(x.shape[0], *self.shape)
class Transpose(Module):
def __init__(self, dim0: int, dim1: int) -> None:
super().__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x: Tensor) -> Tensor:
return x.transpose(self.dim0, self.dim1)
class Permute(Module):
def __init__(self, *dims: int) -> None:
super().__init__()
self.dims = dims
def forward(self, x: Tensor) -> Tensor:
return x.permute(*self.dims)
class Slicing(Module):
def __init__(self, dim: int = 0, start: int = 0, end: int | None = None, step: int = 1) -> None:
super().__init__()
self.dim = dim
self.start = start
self.end = end
self.step = step
def forward(self, x: Tensor) -> Tensor:
dim_size = x.shape[self.dim]
start = self.start if self.start >= 0 else dim_size + self.start
end = self.end or dim_size
end = end if end >= 0 else dim_size + end
start = max(min(start, dim_size), 0)
end = max(min(end, dim_size), 0)
if start >= end:
return self.get_empty_slice(x)
indices = torch.arange(start=start, end=end, step=self.step, device=x.device)
return x.index_select(self.dim, indices)
def get_empty_slice(self, x: Tensor) -> Tensor:
"""
Return an empty slice of the same shape as the input tensor to mimic PyTorch's slicing behavior.
"""
shape = list(x.shape)
shape[self.dim] = 0
return torch.empty(*shape, device=x.device)
class Squeeze(Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
return x.squeeze(self.dim)
class Unsqueeze(Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
return x.unsqueeze(self.dim)
class Unbind(Module):
def __init__(self, dim: int = 0) -> None:
self.dim = dim
super().__init__()
def forward(self, x: Tensor) -> tuple[Tensor, ...]:
return x.unbind(dim=self.dim) # type: ignore
class Chunk(Module):
def __init__(self, chunks: int, dim: int = 0) -> None:
self.chunks = chunks
self.dim = dim
super().__init__()
def forward(self, x: Tensor) -> tuple[Tensor, ...]:
return x.chunk(chunks=self.chunks, dim=self.dim) # type: ignore
class Sin(Module):
def forward(self, x: Tensor) -> Tensor:
return torch.sin(input=x)
class Cos(Module):
def forward(self, x: Tensor) -> Tensor:
return torch.cos(input=x)
class Multiply(Module):
def __init__(self, scale: float = 1.0, bias: float = 0.0) -> None:
super().__init__()
self.scale = scale
self.bias = bias
def forward(self, x: Tensor) -> Tensor:
return self.scale * x + self.bias
class Parameter(WeightedModule):
"""
A layer that wraps a tensor as a parameter. This is useful to create a parameter that is not a weight or a bias.
"""
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__()
self.dims = dims
self.weight = TorchParameter(randn(*dims, device=device, dtype=dtype))
def forward(self, x: Tensor) -> Tensor:
return self.weight.expand(x.shape[0], *self.dims)
class Buffer(WeightedModule):
"""
A layer that wraps a tensor as a buffer. This is useful to create a buffer that is not a weight or a bias.
Buffers are not trainable.
"""
def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__()
self.dims = dims
self.register_buffer("buffer", randn(*dims, device=device, dtype=dtype))
@property
def device(self) -> Device:
return self.buffer.device
@property
def dtype(self) -> DType:
return self.buffer.dtype
def forward(self, _: Tensor) -> Tensor:
return self.buffer

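These basics are small stateless wrappers, so tensor plumbing can be written declaratively inside a `Chain`. A toy sketch with arbitrary shapes:

```python
import torch
import imaginairy.vendored.refiners.fluxion.layers as fl

net = fl.Chain(
    fl.Flatten(start_dim=1),            # (B, 1, 8, 8) -> (B, 64)
    fl.Reshape(4, 16),                  # batch dim preserved: (B, 64) -> (B, 4, 16)
    fl.Slicing(dim=1, start=0, end=2),  # keep the first two rows
)
assert net(torch.randn(2, 1, 8, 8)).shape == (2, 2, 16)
```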
@@ -0,0 +1,586 @@
import inspect
import re
import sys
import traceback
from collections import defaultdict
from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
import torch
from torch import Tensor, cat, device as Device, dtype as DType
from imaginairy.vendored.refiners.fluxion.context import ContextProvider, Contexts
from imaginairy.vendored.refiners.fluxion.layers.module import ContextModule, Module, ModuleTree, WeightedModule
from imaginairy.vendored.refiners.fluxion.utils import summarize_tensor
T = TypeVar("T", bound=Module)
TChain = TypeVar("TChain", bound="Chain") # because Self (PEP 673) is not in 3.10
class Lambda(Module):
"""Lambda is a wrapper around a callable object that allows it to be used as a PyTorch module."""
def __init__(self, func: Callable[..., Any]) -> None:
super().__init__()
self.func = func
def forward(self, *args: Any) -> Any:
return self.func(*args)
def __str__(self) -> str:
func_name = getattr(self.func, "__name__", "partial_function")
return f"Lambda({func_name}{str(inspect.signature(self.func))})"
def generate_unique_names(
modules: tuple[Module, ...],
) -> dict[str, Module]:
class_counts: dict[str, int] = {}
unique_names: list[tuple[str, Module]] = []
for module in modules:
class_name = module.__class__.__name__
class_counts[class_name] = class_counts.get(class_name, 0) + 1
name_counter: dict[str, int] = {}
for module in modules:
class_name = module.__class__.__name__
name_counter[class_name] = name_counter.get(class_name, 0) + 1
unique_name = f"{class_name}_{name_counter[class_name]}" if class_counts[class_name] > 1 else class_name
unique_names.append((unique_name, module))
return dict(unique_names)
class UseContext(ContextModule):
def __init__(self, context: str, key: str) -> None:
super().__init__()
self.context = context
self.key = key
self.func: Callable[[Any], Any] = lambda x: x
def __call__(self, *args: Any) -> Any:
context = self.use_context(self.context)
assert context, f"context {self.context} is unset"
value = context.get(self.key)
assert value is not None, f"context entry {self.context}.{self.key} is unset"
return self.func(value)
def __repr__(self):
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
def compose(self, func: Callable[[Any], Any]) -> "UseContext":
self.func = func
return self
class SetContext(ContextModule):
"""A Module that sets a context value when executed.
The context needs to pre-exist in the context provider.
#TODO Is there a way to create the context if it doesn't exist?
"""
def __init__(self, context: str, key: str, callback: Callable[[Any, Any], Any] | None = None) -> None:
super().__init__()
self.context = context
self.key = key
self.callback = callback
def __call__(self, x: Tensor) -> Tensor:
if context := self.use_context(self.context):
if not self.callback:
context.update({self.key: x})
else:
self.callback(context[self.key], x)
return x
def __repr__(self):
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
class ReturnException(Exception):
"""Exception raised when a Return module is encountered."""
def __init__(self, value: Tensor):
self.value = value
class Return(Module):
"""A Module that stops the execution of a Chain when encountered."""
def forward(self, x: Tensor):
raise ReturnException(x)
def structural_copy(m: T) -> T:
return m.structural_copy() if isinstance(m, ContextModule) else m
class ChainError(RuntimeError):
"""Exception raised when an error occurs during the execution of a Chain."""
def __init__(self, message: str, /) -> None:
super().__init__(message)
class Chain(ContextModule):
_modules: dict[str, Module]
_provider: ContextProvider
_tag = "CHAIN"
def __init__(self, *args: Module | Iterable[Module]) -> None:
super().__init__()
self._provider = ContextProvider()
modules = cast(
tuple[Module],
(
tuple(args[0])
if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)
else tuple(args)
),
)
for module in modules:
# Violating this would mean a ContextModule ends up in two chains,
# with a single one correctly set as its parent.
assert (
(not isinstance(module, ContextModule))
or (not module._can_refresh_parent)
or (module.parent is None)
or (module.parent == self)
), f"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}"
self._regenerate_keys(modules)
self._reset_context()
for module in self:
if isinstance(module, ContextModule) and module.parent != self:
module._set_parent(self)
def __setattr__(self, name: str, value: Any) -> None:
if isinstance(value, torch.nn.Module):
raise ValueError(
"Chain does not support setting modules by attribute. Instead, use a mutation method like `append` or"
" wrap it within a single element list to prevent pytorch from registering it as a submodule."
)
super().__setattr__(name, value)
@property
def provider(self) -> ContextProvider:
return self._provider
def init_context(self) -> Contexts:
return {}
def _register_provider(self, context: Contexts | None = None) -> None:
if context:
self._provider.update_contexts(context)
for module in self:
if isinstance(module, Chain):
module._register_provider(context=self._provider.contexts)
def _reset_context(self) -> None:
self._register_provider(self.init_context())
def set_context(self, context: str, value: Any) -> None:
self._provider.set_context(context, value)
self._register_provider()
def _show_error_in_tree(self, name: str, /, max_lines: int = 20) -> str:
tree = ModuleTree(module=self)
classname_counter: dict[str, int] = defaultdict(int)
first_ancestor = self.get_parents()[-1] if self.get_parents() else self
def find_state_dict_key(module: Module, /) -> str | None:
for key, layer in module.named_modules():
if layer == self:
return ".".join((key, name))
return None
for child in tree:
classname, count = name.rsplit(sep="_", maxsplit=1) if "_" in name else (name, "1")
if child["class_name"] == classname:
classname_counter[classname] += 1
if classname_counter[classname] == int(count):
state_dict_key = find_state_dict_key(first_ancestor)
child["value"] = f">>> {child['value']} | {state_dict_key}"
break
tree_repr = tree._generate_tree_repr(tree.root, depth=3) # type: ignore[reportPrivateUsage]
lines = tree_repr.split(sep="\n")
error_line_idx = next((idx for idx, line in enumerate(iterable=lines) if line.startswith(">>>")), 0)
return ModuleTree.shorten_tree_repr(tree_repr, line_index=error_line_idx, max_lines=max_lines)
@staticmethod
def _pretty_print_args(*args: Any) -> str:
"""
Flatten nested tuples and print tensors with their shape and other information.
"""
def _flatten_tuple(t: Tensor | tuple[Any, ...], /) -> list[Any]:
if isinstance(t, tuple):
return [item for subtuple in t for item in _flatten_tuple(subtuple)]
else:
return [t]
flat_args = _flatten_tuple(args)
return "\n".join(
[
f"{idx}: {summarize_tensor(arg) if isinstance(arg, Tensor) else arg}"
for idx, arg in enumerate(iterable=flat_args)
]
)
def _filter_traceback(self, *frames: traceback.FrameSummary) -> list[traceback.FrameSummary]:
patterns_to_exclude = [
(r"torch/nn/modules/", r"^_call_impl$"),
(r"torch/nn/functional\.py", r""),
(r"refiners/fluxion/layers/", r"^_call_layer$"),
(r"refiners/fluxion/layers/", r"^forward$"),
(r"refiners/fluxion/layers/chain\.py", r""),
(r"", r"^_"),
]
def should_exclude(frame: traceback.FrameSummary, /) -> bool:
for filename_pattern, name_pattern in patterns_to_exclude:
if re.search(pattern=filename_pattern, string=frame.filename) and re.search(
pattern=name_pattern, string=frame.name
):
return True
return False
return [frame for frame in frames if not should_exclude(frame)]
def _call_layer(self, layer: Module, name: str, /, *args: Any) -> Any:
try:
return layer(*args)
except Exception as e:
exc_type, _, exc_traceback = sys.exc_info()
assert exc_type
tb_list = traceback.extract_tb(tb=exc_traceback)
filtered_tb_list = self._filter_traceback(*tb_list)
formatted_tb = "".join(traceback.format_list(extracted_list=filtered_tb_list))
pretty_args = Chain._pretty_print_args(args)
error_tree = self._show_error_in_tree(name)
exception_str = re.sub(pattern=r"\n\s*\n", repl="\n", string=str(object=e))
message = f"{formatted_tb}\n{exception_str}\n---------------\n{error_tree}\n{pretty_args}"
if "Error" not in exception_str:
message = f"{exc_type.__name__}:\n {message}"
raise ChainError(message) from None
def forward(self, *args: Any) -> Any:
result: tuple[Any] | Any = None
intermediate_args: tuple[Any, ...] = args
for name, layer in self._modules.items():
result = self._call_layer(layer, name, *intermediate_args)
intermediate_args = (result,) if not isinstance(result, tuple) else result
self._reset_context()
return result
def _regenerate_keys(self, modules: Iterable[Module]) -> None:
self._modules = generate_unique_names(tuple(modules)) # type: ignore
def __add__(self, other: "Chain | Module | list[Module]") -> "Chain":
if isinstance(other, Module):
other = Chain(other)
if isinstance(other, list):
other = Chain(*other)
return Chain(*self, *other)
@overload
def __getitem__(self, key: int) -> Module:
...
@overload
def __getitem__(self, key: str) -> Module:
...
@overload
def __getitem__(self, key: slice) -> "Chain":
...
def __getitem__(self, key: int | str | slice) -> Module:
if isinstance(key, slice):
copy = self.structural_copy()
copy._regenerate_keys(modules=list(copy)[key])
return copy
elif isinstance(key, str):
return self._modules[key]
else:
return list(self)[key]
def __iter__(self) -> Iterator[Module]:
return iter(self._modules.values())
def __len__(self) -> int:
return len(self._modules)
@property
def device(self) -> Device | None:
wm = self.find(WeightedModule)
return None if wm is None else wm.device
@property
def dtype(self) -> DType | None:
wm = self.find(WeightedModule)
return None if wm is None else wm.dtype
def _walk(
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
) -> Iterator[tuple[Module, "Chain"]]:
if predicate is None:
predicate = lambda _m, _p: True
for module in self:
try:
p = predicate(module, self)
except StopIteration:
continue
if p:
yield (module, self)
if not recurse:
continue
if isinstance(module, Chain):
yield from module.walk(predicate, recurse)
@overload
def walk(
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
) -> Iterator[tuple[Module, "Chain"]]:
...
@overload
def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]:
...
def walk(
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
if isinstance(predicate, type):
return self._walk(lambda m, _: isinstance(m, predicate), recurse)
else:
return self._walk(predicate, recurse)
def layers(self, layer_type: type[T], recurse: bool = False) -> Iterator[T]:
for module, _ in self.walk(layer_type, recurse):
yield module
def find(self, layer_type: type[T]) -> T | None:
return next(self.layers(layer_type=layer_type), None)
def ensure_find(self, layer_type: type[T]) -> T:
r = self.find(layer_type)
assert r is not None, f"could not find {layer_type} in {self}"
return r
def find_parent(self, module: Module) -> "Chain | None":
if module in self: # avoid DFS-crawling the whole tree
return self
for _, parent in self.walk(lambda m, _: m == module):
return parent
return None
def ensure_find_parent(self, module: Module) -> "Chain":
r = self.find_parent(module)
assert r is not None, f"could not find {module} in {self}"
return r
def insert(self, index: int, module: Module) -> None:
if index < 0:
index = max(0, len(self._modules) + index + 1)
modules = list(self)
modules.insert(index, module)
self._regenerate_keys(modules)
if isinstance(module, ContextModule):
module._set_parent(self)
self._register_provider()
def insert_before_type(self, module_type: type[Module], new_module: Module) -> None:
for i, module in enumerate(self):
if isinstance(module, module_type):
self.insert(i, new_module)
return
raise ValueError(f"No module of type {module_type.__name__} found in the chain.")
def insert_after_type(self, module_type: type[Module], new_module: Module) -> None:
for i, module in enumerate(self):
if isinstance(module, module_type):
self.insert(i + 1, new_module)
return
raise ValueError(f"No module of type {module_type.__name__} found in the chain.")
def append(self, module: Module) -> None:
self.insert(-1, module)
def pop(self, index: int = -1) -> Module | tuple[Module]:
modules = list(self)
if index < 0:
index = len(modules) + index
if index < 0 or index >= len(modules):
raise IndexError("Index out of range.")
removed_module = modules.pop(index)
if isinstance(removed_module, ContextModule):
removed_module._set_parent(None)
self._regenerate_keys(modules)
return removed_module
def remove(self, module: Module) -> None:
"""Remove a module from the chain."""
modules = list(self)
try:
modules.remove(module)
except ValueError:
raise ValueError(f"{module} is not in {self}")
self._regenerate_keys(modules)
if isinstance(module, ContextModule):
module._set_parent(None)
def replace(
self,
old_module: Module,
new_module: Module,
old_module_parent: "Chain | None" = None,
) -> None:
"""Replace a module in the chain with a new module."""
modules = list(self)
try:
modules[modules.index(old_module)] = new_module
except ValueError:
raise ValueError(f"{old_module} is not in {self}")
self._regenerate_keys(modules)
if isinstance(new_module, ContextModule):
new_module._set_parent(self)
if isinstance(old_module, ContextModule):
old_module._set_parent(old_module_parent)
def structural_copy(self: TChain) -> TChain:
"""Copy the structure of the Chain tree.
This method returns a recursive copy of the Chain tree where all inner nodes
(instances of Chain and its subclasses) are duplicated and all leaves
(regular Modules) are not.
Such copies can be adapted without disrupting the base model, but do not
require extra GPU memory since the weights are in the leaves and hence not copied.
"""
if hasattr(self, "_pre_structural_copy"):
self._pre_structural_copy()
modules = [structural_copy(m) for m in self]
clone = super().structural_copy()
clone._provider = ContextProvider.create(clone.init_context())
for module in modules:
clone.append(module=module)
if hasattr(clone, "_post_structural_copy"):
clone._post_structural_copy(self)
return clone
def _show_only_tag(self) -> bool:
return self.__class__ == Chain
class Parallel(Chain):
_tag = "PAR"
def forward(self, *args: Any) -> tuple[Tensor, ...]:
return tuple([self._call_layer(module, name, *args) for name, module in self._modules.items()])
def _show_only_tag(self) -> bool:
return self.__class__ == Parallel
class Distribute(Chain):
_tag = "DISTR"
def forward(self, *args: Any) -> tuple[Tensor, ...]:
n, m = len(args), len(self._modules)
assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})."
return tuple([self._call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])
def _show_only_tag(self) -> bool:
return self.__class__ == Distribute
class Passthrough(Chain):
_tag = "PASS"
def forward(self, *inputs: Any) -> Any:
super().forward(*inputs)
return inputs
def _show_only_tag(self) -> bool:
return self.__class__ == Passthrough
class Sum(Chain):
_tag = "SUM"
def forward(self, *inputs: Any) -> Any:
output = None
for layer in self:
layer_output: Any = layer(*inputs)
if isinstance(layer_output, tuple):
layer_output = sum(layer_output) # type: ignore
output = layer_output if output is None else output + layer_output
return output
def _show_only_tag(self) -> bool:
return self.__class__ == Sum
class Residual(Chain):
_tag = "RES"
def forward(self, *inputs: Any) -> Any:
assert len(inputs) == 1, "Residual connection can only be used with a single input."
return super().forward(*inputs) + inputs[0]
class Breakpoint(ContextModule):
def __init__(self, vscode: bool = True):
super().__init__()
self.vscode = vscode
def forward(self, *args: Any):
if self.vscode:
import debugpy # type: ignore
debugpy.breakpoint() # type: ignore
else:
breakpoint()
return args[0] if len(args) == 1 else args
class Concatenate(Chain):
_tag = "CAT"
def __init__(self, *modules: Module, dim: int = 0) -> None:
super().__init__(*modules)
self.dim = dim
def forward(self, *args: Any) -> Tensor:
outputs = [module(*args) for module in self]
return cat([output for output in outputs if output is not None], dim=self.dim)
def _show_only_tag(self) -> bool:
return self.__class__ == Concatenate
class Matmul(Chain):
_tag = "MATMUL"
def __init__(self, input: Module, other: Module) -> None:
super().__init__(
input,
other,
)
def forward(self, *args: Tensor) -> Tensor:
return torch.matmul(input=self[0](*args), other=self[1](*args))

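A short sketch of working with `Chain` containers (toy layers, illustrative only): sub-chains like `Residual` and `Sum` combine their children's outputs, and chains can be inspected and mutated by layer type.

```python
import torch
import imaginairy.vendored.refiners.fluxion.layers as fl

block = fl.Chain(
    fl.Linear(in_features=8, out_features=8),
    fl.Residual(fl.Linear(in_features=8, out_features=8)),  # sub-chain output + its input
    fl.Sum(fl.Identity(), fl.Multiply(scale=2.0)),          # x + 2*x
)
y = block(torch.randn(3, 8))

residual = block.ensure_find(fl.Residual)  # declarative lookup by layer type
block.remove(residual)                     # chains are mutable containers
assert len(block) == 2
```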
@@ -0,0 +1,96 @@
from torch import device as Device, dtype as DType, nn
from imaginairy.vendored.refiners.fluxion.layers.module import WeightedModule
class Conv2d(nn.Conv2d, WeightedModule):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int],
stride: int | tuple[int, int] = (1, 1),
padding: int | tuple[int, int] | str = (0, 0),
groups: int = 1,
use_bias: bool = True,
dilation: int | tuple[int, int] = (1, 1),
padding_mode: str = "zeros",
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__( # type: ignore
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
use_bias,
padding_mode,
device,
dtype,
)
self.use_bias = use_bias
class Conv1d(nn.Conv1d, WeightedModule):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int],
stride: int | tuple[int] = 1,
padding: int | tuple[int] | str = 0,
groups: int = 1,
use_bias: bool = True,
dilation: int | tuple[int] = 1,
padding_mode: str = "zeros",
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__( # type: ignore
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
use_bias,
padding_mode,
device,
dtype,
)
class ConvTranspose2d(nn.ConvTranspose2d, WeightedModule):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int],
stride: int | tuple[int, int] = 1,
padding: int | tuple[int, int] = 0,
output_padding: int | tuple[int, int] = 0,
groups: int = 1,
use_bias: bool = True,
dilation: int | tuple[int, int] = 1,
padding_mode: str = "zeros",
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__( # type: ignore
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
dilation=dilation,
groups=groups,
bias=use_bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)

@@ -0,0 +1,45 @@
from torch import Tensor
from imaginairy.vendored.refiners.fluxion.layers.module import ContextModule
class Converter(ContextModule):
"""
A Converter class that adjusts tensor properties based on a parent module's settings.
This class inherits from `ContextModule` and provides functionality to adjust
the device and dtype of input tensor(s) to match the parent module's attributes.
Attributes:
set_device (bool): If True, matches the device of the input tensor(s) to the parent's device.
set_dtype (bool): If True, matches the dtype of the input tensor(s) to the parent's dtype.
Note:
Ensure the parent module has `device` and `dtype` attributes if `set_device` or `set_dtype` are set to True.
"""
def __init__(self, set_device: bool = True, set_dtype: bool = True) -> None:
super().__init__()
self.set_device = set_device
self.set_dtype = set_dtype
def forward(self, *inputs: Tensor) -> tuple[Tensor, ...]:
parent = self.ensure_parent
converted_tensors: list[Tensor] = []
for x in inputs:
if self.set_device:
device = parent.device
assert device is not None, "parent has no device"
x = x.to(device=device)
if self.set_dtype:
dtype = parent.dtype
assert dtype is not None, "parent has no dtype"
x = x.to(dtype=dtype)
converted_tensors.append(x)
return tuple(converted_tensors)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(set_device={self.set_device}, set_dtype={self.set_dtype})"

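A sketch of `Converter` in use (bfloat16 chosen arbitrarily): placed at the front of a chain, it casts incoming tensors to the dtype and/or device of the chain's weights.

```python
import torch
import imaginairy.vendored.refiners.fluxion.layers as fl

net = fl.Chain(
    fl.Converter(set_device=False, set_dtype=True),
    fl.Linear(in_features=4, out_features=4, dtype=torch.bfloat16),
)
out = net(torch.randn(2, 4))        # float32 input is cast to bfloat16 before the Linear
assert out.dtype == torch.bfloat16
```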
@@ -0,0 +1,21 @@
from jaxtyping import Float, Int
from torch import Tensor, device as Device, dtype as DType
from torch.nn import Embedding as _Embedding
from imaginairy.vendored.refiners.fluxion.layers.module import WeightedModule
class Embedding(_Embedding, WeightedModule): # type: ignore
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
):
_Embedding.__init__( # type: ignore
self, num_embeddings=num_embeddings, embedding_dim=embedding_dim, device=device, dtype=dtype
)
def forward(self, x: Int[Tensor, "batch length"]) -> Float[Tensor, "batch length embedding_dim"]: # type: ignore
return super().forward(x)

@@ -0,0 +1,49 @@
from jaxtyping import Float
from torch import Tensor, device as Device, dtype as DType
from torch.nn import Linear as _Linear
from imaginairy.vendored.refiners.fluxion.layers.activations import ReLU
from imaginairy.vendored.refiners.fluxion.layers.chain import Chain
from imaginairy.vendored.refiners.fluxion.layers.module import Module, WeightedModule
class Linear(_Linear, WeightedModule):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_features = in_features
self.out_features = out_features
super().__init__( # type: ignore
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
def forward(self, x: Float[Tensor, "batch in_features"]) -> Float[Tensor, "batch out_features"]: # type: ignore
return super().forward(x)
class MultiLinear(Chain):
def __init__(
self,
input_dim: int,
output_dim: int,
inner_dim: int,
num_layers: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
layers: list[Module] = []
for i in range(num_layers - 1):
layers.append(Linear(input_dim if i == 0 else inner_dim, inner_dim, device=device, dtype=dtype))
layers.append(ReLU())
layers.append(Linear(inner_dim, output_dim, device=device, dtype=dtype))
super().__init__(layers)

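`MultiLinear` builds a small ReLU MLP; a sketch with arbitrary dimensions:

```python
import torch
import imaginairy.vendored.refiners.fluxion.layers as fl

mlp = fl.MultiLinear(input_dim=320, output_dim=1280, inner_dim=1280, num_layers=2)
assert mlp(torch.randn(4, 320)).shape == (4, 1280)  # Linear -> ReLU -> Linear
```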
@@ -0,0 +1,43 @@
from torch import nn
from imaginairy.vendored.refiners.fluxion.layers.module import Module
class MaxPool1d(nn.MaxPool1d, Module):
def __init__(
self,
kernel_size: int,
stride: int | None = None,
padding: int = 0,
dilation: int = 1,
return_indices: bool = False,
ceil_mode: bool = False,
) -> None:
super().__init__(
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
return_indices=return_indices,
ceil_mode=ceil_mode,
)
class MaxPool2d(nn.MaxPool2d, Module):
def __init__(
self,
kernel_size: int | tuple[int, int],
stride: int | tuple[int, int] | None = None,
padding: int | tuple[int, int] = (0, 0),
dilation: int | tuple[int, int] = (1, 1),
return_indices: bool = False,
ceil_mode: bool = False,
) -> None:
super().__init__(
kernel_size=kernel_size,
stride=stride,
padding=padding, # type: ignore
dilation=dilation,
return_indices=return_indices,
ceil_mode=ceil_mode,
)

@ -0,0 +1,264 @@
import sys
from collections import defaultdict
from inspect import Parameter, signature
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any, DefaultDict, Generator, Sequence, TypedDict, TypeVar, cast
from torch import device as Device, dtype as DType
from torch.nn.modules.module import Module as TorchModule
from imaginairy.vendored.refiners.fluxion.context import Context, ContextProvider
from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors
if TYPE_CHECKING:
from imaginairy.vendored.refiners.fluxion.layers.chain import Chain
T = TypeVar("T", bound="Module")
TContextModule = TypeVar("TContextModule", bound="ContextModule")
BasicType = str | float | int | bool
class Module(TorchModule):
_parameters: dict[str, Any]
_buffers: dict[str, Any]
_tag: str = ""
def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)  # type: ignore[reportUnknownMemberType]
def __getattr__(self, name: str) -> Any:
return super().__getattr__(name=name)
def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name=name, value=value)
def load_from_safetensors(self, tensors_path: str | Path, strict: bool = True) -> "Module":
state_dict = load_from_safetensors(tensors_path)
self.load_state_dict(state_dict, strict=strict)
return self
def named_modules(self, *args: Any, **kwargs: Any) -> "Generator[tuple[str, Module], None, None]": # type: ignore
return super().named_modules(*args) # type: ignore
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
return super().to(device=device, dtype=dtype) # type: ignore
def __str__(self) -> str:
basic_attributes_str = ", ".join(
f"{key}={value}" for key, value in self.basic_attributes(init_attrs_only=True).items()
)
result = f"{self.__class__.__name__}({basic_attributes_str})"
return result
def __repr__(self) -> str:
tree = ModuleTree(module=self)
return repr(tree)
def pretty_print(self, depth: int = -1) -> None:
tree = ModuleTree(module=self)
print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth)) # type: ignore[reportPrivateUsage]
def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
"""Return a dictionary of basic attributes of the module.
Basic attributes are public attributes made of basic types (int, float, str, bool) or a sequence of basic types.
"""
sig = signature(obj=self.__init__)
init_params = set(sig.parameters.keys()) - {"self"}
default_values = {k: v.default for k, v in sig.parameters.items() if v.default is not Parameter.empty}
def is_basic_attribute(key: str, value: Any) -> bool:
if key.startswith("_"):
return False
if isinstance(value, BasicType):
return True
if isinstance(value, Sequence) and all(isinstance(y, BasicType) for y in cast(Sequence[Any], value)):
return True
return False
return {
key: str(object=value)
for key, value in self.__dict__.items()
if is_basic_attribute(key=key, value=value)
and (not init_attrs_only or (key in init_params and value != default_values.get(key)))
}
def _show_only_tag(self) -> bool:
"""Whether to show only the tag when printing the module.
This is useful to distinguish between Chain subclasses that override their forward from one another.
"""
return False
class ContextModule(Module):
# we store parent into a one element list to avoid pytorch thinking it's a submodule
_parent: "list[Chain]"
_can_refresh_parent: bool = True # see usage in Adapter and Chain
def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
self._parent = []
@property
def parent(self) -> "Chain | None":
return self._parent[0] if self._parent else None
@property
def ensure_parent(self) -> "Chain":
assert self._parent, "module is not bound to a Chain"
return self._parent[0]
def _set_parent(self, parent: "Chain | None") -> None:
if not self._can_refresh_parent:
return
if parent is None:
self._parent = []
return
# Always insert the module in the Chain first to avoid inconsistencies.
assert self in iter(parent), f"{self} not in {parent}"
self._parent = [parent]
@property
def provider(self) -> ContextProvider:
return self.ensure_parent.provider
def get_parents(self) -> "list[Chain]":
return self._parent + self._parent[0].get_parents() if self._parent else []
def use_context(self, context_name: str) -> Context:
"""Retrieve the context object from the module's context provider."""
context = self.provider.get_context(context_name)
assert context is not None, f"Context {context_name} not found."
return context
def structural_copy(self: TContextModule) -> TContextModule:
clone = object.__new__(self.__class__)
not_torch_attributes = [
key
for key, value in self.__dict__.items()
if not key.startswith("_")
and isinstance(sys.modules.get(type(value).__module__), ModuleType)
and "torch" not in sys.modules[type(value).__module__].__name__
]
for k in not_torch_attributes:
setattr(clone, k, getattr(self, k))
ContextModule.__init__(self=clone)
return clone
class WeightedModule(Module):
@property
def device(self) -> Device:
return self.weight.device
@property
def dtype(self) -> DType:
return self.weight.dtype
class TreeNode(TypedDict):
value: str
class_name: str
children: list["TreeNode"]
class ModuleTree:
def __init__(self, module: Module) -> None:
self.root: TreeNode = self._module_to_tree(module=module)
self._fold_successive_identical(node=self.root)
def __str__(self) -> str:
return f"{self.__class__.__name__}(root={self.root['value']})"
def __repr__(self) -> str:
return self._generate_tree_repr(self.root, is_root=True, depth=7)
def __iter__(self) -> Generator[TreeNode, None, None]:
for child in self.root["children"]:
yield child
@classmethod
def shorten_tree_repr(cls, tree_repr: str, /, line_index: int = 0, max_lines: int = 20) -> str:
"""Shorten the tree representation to a given number of lines around a given line index."""
lines = tree_repr.split(sep="\n")
start_idx = max(0, line_index - max_lines // 2)
end_idx = min(len(lines), line_index + max_lines // 2 + 1)
return "\n".join(lines[start_idx:end_idx])
def _generate_tree_repr(
self, node: TreeNode, /, *, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
) -> str:
if depth == 0 and node["children"]:
return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."
if depth > 0:
depth -= 1
tree_icon: str = "" if is_root else ("└── " if is_last else "├── ")
counts: DefaultDict[str, int] = defaultdict(int)
for child in node["children"]:
counts[child["class_name"]] += 1
instance_counts: DefaultDict[str, int] = defaultdict(int)
lines = [f"{prefix}{tree_icon}{node['value']}"]
        new_prefix: str = "    " if is_last else "│   "
for i, child in enumerate(iterable=node["children"]):
instance_counts[child["class_name"]] += 1
if counts[child["class_name"]] > 1:
child_value = f"{child['value']} #{instance_counts[child['class_name']]}"
else:
child_value = child["value"]
child_str = self._generate_tree_repr(
{"value": child_value, "class_name": child["class_name"], "children": child["children"]},
prefix=prefix + new_prefix,
is_last=i == len(node["children"]) - 1,
is_root=False,
depth=depth,
)
if child_str:
lines.append(child_str)
return "\n".join(lines)
def _module_to_tree(self, module: Module) -> TreeNode:
match (module._tag, module._show_only_tag()): # pyright: ignore[reportPrivateUsage]
case ("", False):
value = str(module)
case (_, True):
value = f"({module._tag})" # pyright: ignore[reportPrivateUsage]
case (_, False):
value = f"({module._tag}) {module}" # pyright: ignore[reportPrivateUsage]
class_name = module.__class__.__name__
node: TreeNode = {"value": value, "class_name": class_name, "children": []}
for child in module.children():
node["children"].append(self._module_to_tree(module=child)) # type: ignore
return node
def _fold_successive_identical(self, node: TreeNode) -> None:
i = 0
while i < len(node["children"]):
j = i
while j < len(node["children"]) and node["children"][i] == node["children"][j]:
j += 1
count = j - i
if count > 1:
node["children"][i]["value"] += f" (x{count})"
del node["children"][i + 1 : j]
self._fold_successive_identical(node=node["children"][i])
i += 1
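# Minimal sketch of the introspection helpers above: __str__ shows the basic init attributes,
# while __repr__/pretty_print render a ModuleTree in which identical successive children are
# folded into "(xN)". Assumes the vendored fl.Chain/fl.Linear wrappers used elsewhere in this package.
if __name__ == "__main__":
    import imaginairy.vendored.refiners.fluxion.layers as fl
    layer = fl.Linear(in_features=3, out_features=3)
    print(str(layer))  # Linear(in_features=3, out_features=3) -- built from basic_attributes()
    net = fl.Chain(layer, fl.Linear(in_features=3, out_features=3))
    net.pretty_print()  # ModuleTree rendering; the two identical Linear children fold into "(x2)"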

@ -0,0 +1,88 @@
from jaxtyping import Float
from torch import Tensor, device as Device, dtype as DType, nn, ones, sqrt, zeros
from imaginairy.vendored.refiners.fluxion.layers.module import Module, WeightedModule
class LayerNorm(nn.LayerNorm, WeightedModule):
def __init__(
self,
normalized_shape: int | list[int],
eps: float = 0.00001,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__( # type: ignore
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=True, # otherwise not a WeightedModule
device=device,
dtype=dtype,
)
class GroupNorm(nn.GroupNorm, WeightedModule):
def __init__(
self,
channels: int,
num_groups: int,
eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__( # type: ignore
num_groups=num_groups,
num_channels=channels,
eps=eps,
affine=True, # otherwise not a WeightedModule
device=device,
dtype=dtype,
)
self.channels = channels
self.num_groups = num_groups
self.eps = eps
class LayerNorm2d(WeightedModule):
"""
2D Layer Normalization module.
Parameters:
channels (int): Number of channels in the input tensor.
eps (float, optional): A small constant for numerical stability. Default: 1e-6.
"""
def __init__(
self,
channels: int,
eps: float = 1e-6,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.weight = nn.Parameter(ones(channels, device=device, dtype=dtype))
self.bias = nn.Parameter(zeros(channels, device=device, dtype=dtype))
self.eps = eps
def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
x_mean = x.mean(1, keepdim=True)
x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
x_norm = (x - x_mean) / sqrt(x_var + self.eps)
x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
return x_out
class InstanceNorm2d(nn.InstanceNorm2d, Module):
def __init__(
self,
num_features: int,
eps: float = 1e-05,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__( # type: ignore
num_features=num_features,
eps=eps,
device=device,
dtype=dtype,
)
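# Minimal numerical sketch: LayerNorm2d normalizes every spatial position across the channel
# dimension, so with the default weight=1 / bias=0 each (batch, h, w) slice ends up zero-mean.
if __name__ == "__main__":
    import torch
    ln = LayerNorm2d(channels=8)
    y = ln(torch.randn(2, 8, 4, 4))
    print(y.mean(dim=1).abs().max())  # ~0 (up to floating-point error)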

@ -0,0 +1,8 @@
from torch import nn
from imaginairy.vendored.refiners.fluxion.layers.module import Module
class ReflectionPad2d(nn.ReflectionPad2d, Module):
def __init__(self, padding: int) -> None:
super().__init__(padding=padding)

@ -0,0 +1,8 @@
from torch.nn import PixelUnshuffle as _PixelUnshuffle
from imaginairy.vendored.refiners.fluxion.layers.module import Module
class PixelUnshuffle(_PixelUnshuffle, Module):
def __init__(self, downscale_factor: int):
_PixelUnshuffle.__init__(self, downscale_factor=downscale_factor)

@ -0,0 +1,99 @@
from typing import Callable
from torch import Size, Tensor, device as Device, dtype as DType
from torch.nn.functional import pad
from imaginairy.vendored.refiners.fluxion.layers.basics import Identity
from imaginairy.vendored.refiners.fluxion.layers.chain import Chain, Lambda, Parallel, SetContext, UseContext
from imaginairy.vendored.refiners.fluxion.layers.conv import Conv2d
from imaginairy.vendored.refiners.fluxion.layers.module import Module
from imaginairy.vendored.refiners.fluxion.utils import interpolate
class Downsample(Chain):
def __init__(
self,
channels: int,
scale_factor: int,
padding: int = 0,
register_shape: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
):
"""Downsamples the input by the given scale factor.
        If register_shape is True, the input shape is registered in the "sampling" context. It will throw an error
        if the "sampling" context is not set or if it does not contain a list.
"""
self.channels = channels
self.in_channels = channels
self.out_channels = channels
self.scale_factor = scale_factor
self.padding = padding
super().__init__(
Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=3,
stride=scale_factor,
padding=padding,
device=device,
dtype=dtype,
),
)
if padding == 0:
zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
self.insert(0, Lambda(zero_pad))
if register_shape:
self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape))
def register_shape(self, shapes: list[Size], x: Tensor) -> None:
shapes.append(x.shape[2:])
class Interpolate(Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor, shape: Size) -> Tensor:
return interpolate(x, shape)
class Upsample(Chain):
def __init__(
self,
channels: int,
upsample_factor: int | None = None,
device: Device | str | None = None,
dtype: DType | None = None,
):
"""Upsamples the input by the given scale factor.
        If upsample_factor is None, the target shape is taken from the "sampling" context. It will throw an error
        if the "sampling" context is not set or if it is empty (in that case, either pass an explicit upsample_factor
        or make sure the matching Downsample was created with register_shape=True so that it records the shape).
"""
self.channels = channels
self.upsample_factor = upsample_factor
super().__init__(
Parallel(
Identity(),
(
Lambda(self._get_static_shape)
if upsample_factor is not None
else UseContext(context="sampling", key="shapes").compose(lambda x: x.pop())
),
),
Interpolate(),
Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
)
def _get_static_shape(self, x: Tensor) -> Size:
assert self.upsample_factor is not None
return Size([size * self.upsample_factor for size in x.shape[2:]])
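# Minimal shape sketch: Downsample(scale_factor=2, padding=0) pads right/bottom by one and applies
# a stride-2 3x3 conv, halving H and W; a static Upsample(upsample_factor=2) doubles them again.
# register_shape=False keeps the "sampling" context machinery out of the example.
if __name__ == "__main__":
    import torch
    x = torch.randn(1, 4, 32, 32)
    down = Downsample(channels=4, scale_factor=2, register_shape=False)
    up = Upsample(channels=4, upsample_factor=2)
    h = down(x)
    print(h.shape, up(h).shape)  # torch.Size([1, 4, 16, 16]) torch.Size([1, 4, 32, 32])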

@ -0,0 +1,644 @@
from collections import defaultdict
from enum import Enum, auto
from pathlib import Path
from typing import Any, DefaultDict, TypedDict
import torch
from torch import Tensor, nn
from torch.utils.hooks import RemovableHandle
from imaginairy.vendored.refiners.fluxion.utils import no_grad, norm, save_to_safetensors
TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
nn.Linear,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.LayerNorm,
nn.GroupNorm,
nn.Embedding,
nn.MaxPool2d,
nn.AvgPool2d,
nn.AdaptiveAvgPool2d,
]
ModelTypeShape = tuple[type[nn.Module], tuple[torch.Size, ...]]
class ModuleArgsDict(TypedDict):
"""Represents positional and keyword arguments passed to a module.
- `positional`: A tuple of positional arguments.
- `keyword`: A dictionary of keyword arguments.
"""
positional: tuple[Any, ...]
keyword: dict[str, Any]
class ConversionStage(Enum):
"""Represents the current stage of the conversion process.
- `INIT`: The conversion process has not started.
- `BASIC_LAYERS_MATCH`: The source and target models have the same number of basic layers.
"""
INIT = auto()
BASIC_LAYERS_MATCH = auto()
SHAPE_AND_LAYERS_MATCH = auto()
MODELS_OUTPUT_AGREE = auto()
class ModelConverter:
ModuleArgs = tuple[Any, ...] | dict[str, Any] | ModuleArgsDict
stage: ConversionStage = ConversionStage.INIT
_stored_mapping: dict[str, str] | None = None
def __init__(
self,
source_model: nn.Module,
target_model: nn.Module,
source_keys_to_skip: list[str] | None = None,
target_keys_to_skip: list[str] | None = None,
custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None,
threshold: float = 1e-5,
skip_output_check: bool = False,
skip_init_check: bool = False,
verbose: bool = True,
) -> None:
"""
Create a ModelConverter.
- `source_model`: The model to convert from.
- `target_model`: The model to convert to.
- `source_keys_to_skip`: A list of keys to skip when tracing the source model.
- `target_keys_to_skip`: A list of keys to skip when tracing the target model.
- `custom_layer_mapping`: A dictionary mapping custom layer types between the source and target models.
- `threshold`: The threshold for comparing outputs between the source and target models.
- `skip_output_check`: Whether to skip comparing the outputs of the source and target models.
- `skip_init_check`: Whether to skip checking that the source and target models have the same number of basic
layers.
- `verbose`: Whether to print messages during the conversion process.
        The conversion process consists of four stages:
1. Verify that the source and target models have the same number of basic layers.
2. Find matching shapes and layers between the source and target models.
3. Convert the source model's state_dict to match the target model's state_dict.
4. Compare the outputs of the source and target models.
The conversion process can be run multiple times, and will resume from the last stage.
### Example:
```
converter = ModelConverter(source_model=source, target_model=target, threshold=0.1, verbose=False)
        is_converted = converter.run(source_args=args)
        if is_converted:
            converter.save_to_safetensors(path="test.safetensors")
```
"""
self.source_model = source_model
self.target_model = target_model
self.source_keys_to_skip = source_keys_to_skip or []
self.target_keys_to_skip = target_keys_to_skip or []
self.custom_layer_mapping = custom_layer_mapping or {}
self.threshold = threshold
self.skip_output_check = skip_output_check
self.skip_init_check = skip_init_check
self.verbose = verbose
def __repr__(self) -> str:
return (
f"ModelConverter(source_model={self.source_model.__class__.__name__},"
f" target_model={self.target_model.__class__.__name__}, stage={self.stage})"
)
def __bool__(self) -> bool:
return self.stage.value >= 2 if self.skip_output_check else self.stage.value >= 3
def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
"""
Run the conversion process.
        - `source_args`: The arguments to pass to the source model; it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
        - `target_args`: The arguments to pass to the target model; it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
### Returns:
- `True` if the conversion process is done and the models agree.
        The conversion process consists of four stages:
1. Verify that the source and target models have the same number of basic layers.
2. Find matching shapes and layers between the source and target models.
3. Convert the source model's state_dict to match the target model's state_dict.
4. Compare the outputs of the source and target models.
The conversion process can be run multiple times, and will resume from the last stage.
"""
if target_args is None:
target_args = source_args
match self.stage:
case ConversionStage.MODELS_OUTPUT_AGREE:
self._increment_stage()
return True
case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_shape_and_layers_match_stage(
source_args=source_args, target_args=target_args
):
self._increment_stage()
return True
case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage(
source_args=source_args, target_args=target_args
):
self._increment_stage()
return self.run(source_args=source_args, target_args=target_args)
case ConversionStage.INIT if self._run_init_stage():
self._increment_stage()
return self.run(source_args=source_args, target_args=target_args)
case _:
self._log(message=f"Conversion failed at stage {self.stage.value}")
return False
def _increment_stage(self) -> None:
"""Increment the stage of the conversion process."""
match self.stage:
case ConversionStage.INIT:
self.stage = ConversionStage.BASIC_LAYERS_MATCH
self._log(
message=(
"Stage 0 -> 1 - Models have the same number of basic layers. Finding matching shapes and"
" layers..."
)
)
case ConversionStage.BASIC_LAYERS_MATCH:
self.stage = ConversionStage.SHAPE_AND_LAYERS_MATCH
self._log(
message=(
"Stage 1 -> 2 - Shape of both models agree. Applying state_dict to target model. Comparing"
" models..."
)
)
case ConversionStage.SHAPE_AND_LAYERS_MATCH:
if self.skip_output_check:
self._log(
message=(
"Stage 2 - Nothing to do. Skipping output check. If you want to compare the outputs, set"
" `skip_output_check` to `False`"
)
)
else:
self.stage = ConversionStage.MODELS_OUTPUT_AGREE
self._log(
message=(
"Stage 2 -> 3 - Conversion is done and source and target models agree: you can export the"
" converted model using `save_to_safetensors`"
)
)
case ConversionStage.MODELS_OUTPUT_AGREE:
self._log(
message=(
"Stage 3 - Nothing to do. Conversion is done and source and target models agree: you can export"
" the converted model using `save_to_safetensors`"
)
)
def get_state_dict(self) -> dict[str, Tensor]:
"""Get the converted state_dict."""
if not self:
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
return self.target_model.state_dict()
def get_mapping(self) -> dict[str, str]:
"""Get the mapping between the source and target models' state_dicts."""
if not self:
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
assert self._stored_mapping is not None, "Mapping is not stored"
return self._stored_mapping
def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None:
"""Save the converted model to a SafeTensors file.
This method can only be called after the conversion process is done.
- `path`: The path to save the converted model to.
- `metadata`: Metadata to save with the converted model.
- `half`: Whether to save the converted model as half precision.
### Raises:
        - `ValueError` if the conversion process is not done yet. Run `converter.run(...)` first.
"""
if not self:
raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
state_dict = self.get_state_dict()
if half:
state_dict = {key: value.half() for key, value in state_dict.items()}
save_to_safetensors(path=path, tensors=state_dict, metadata=metadata)
def map_state_dicts(
self,
source_args: ModuleArgs,
target_args: ModuleArgs | None = None,
) -> dict[str, str] | None:
"""
Find a mapping between the source and target models' state_dicts.
        - `source_args`: The arguments to pass to the source model; it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
        - `target_args`: The arguments to pass to the target model; it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
### Returns:
- A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
"""
if target_args is None:
target_args = source_args
source_order = self._trace_module_execution_order(
module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
)
target_order = self._trace_module_execution_order(
module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
)
if not self._assert_shapes_aligned(source_order=source_order, target_order=target_order):
return None
mapping: dict[str, str] = {}
for source_type_shape in source_order:
source_keys = source_order[source_type_shape]
target_type_shape = source_type_shape
if not self._is_torch_basic_layer(module_type=source_type_shape[0]):
for source_custom_type, target_custom_type in self.custom_layer_mapping.items():
if source_custom_type == source_type_shape[0]:
target_type_shape = (target_custom_type, source_type_shape[1])
break
target_keys = target_order[target_type_shape]
mapping.update(zip(target_keys, source_keys))
return mapping
def compare_models(
self,
source_args: ModuleArgs,
target_args: ModuleArgs | None = None,
threshold: float = 1e-5,
) -> bool:
"""
Compare the outputs of the source and target models.
        - `source_args`: The arguments to pass to the source model; it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
is not provided, these arguments will also be passed to the target model.
        - `target_args`: The arguments to pass to the target model; it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
- `threshold`: The threshold for comparing outputs between the source and target models.
"""
if target_args is None:
target_args = source_args
source_outputs = self._collect_layers_outputs(
module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
)
target_outputs = self._collect_layers_outputs(
module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
)
diff, prev_source_key, prev_target_key = None, None, None
for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs):
diff = norm(source_output - target_output.reshape(shape=source_output.shape)).item()
if diff > threshold:
self._log(
f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and"
f" {target_key}, difference in norm: {diff}"
)
return False
prev_source_key, prev_target_key = source_key, target_key
self._log(message=f"Models agree. Difference in norm: {diff}")
return True
def _run_init_stage(self) -> bool:
"""Run the init stage of the conversion process."""
if self.skip_init_check:
self._log(
message=(
"Skipping init check. If you want to check the number of basic layers, set `skip_init_check` to"
" `False`"
)
)
return True
is_count_correct = self._verify_basic_layers_count()
is_not_missing_layers = self._verify_missing_basic_layers()
return is_count_correct and is_not_missing_layers
def _run_basic_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
"""Run the basic layers match stage of the conversion process."""
mapping = self.map_state_dicts(source_args=source_args, target_args=target_args)
self._stored_mapping = mapping
if mapping is None:
self._log(message="Models do not have matching shapes.")
return False
source_state_dict = self.source_model.state_dict()
target_state_dict = self.target_model.state_dict()
converted_state_dict = self._convert_state_dict(
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
)
self.target_model.load_state_dict(state_dict=converted_state_dict)
return True
def _run_shape_and_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
"""Run the shape and layers match stage of the conversion process."""
if self.skip_output_check:
self._log(
message="Skipping output check. If you want to compare the outputs, set `skip_output_check` to `False`"
)
return True
try:
if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold):
self._log(message="Models agree. You can export the converted model using `save_to_safetensors`")
return True
else:
self._log(message="Models do not agree. Try to increase the threshold or modify the models.")
return False
except Exception as e:
self._log(message=f"An error occurred while comparing the models: {e}")
return False
def _log(self, message: str) -> None:
"""Print a message if `verbose` is `True`."""
if self.verbose:
print(message)
def _debug_print_shapes(
self,
shape: ModelTypeShape,
source_keys: list[str],
target_keys: list[str],
) -> None:
"""Print the shapes of the sub-modules in `source_keys` and `target_keys`."""
self._log(message=f"{shape}")
max_len = max(len(source_keys), len(target_keys))
for i in range(max_len):
source_key = source_keys[i] if i < len(source_keys) else "---"
target_key = target_keys[i] if i < len(target_keys) else "---"
self._log(f"\t{source_key}\t{target_key}")
@staticmethod
def _unpack_module_args(module_args: ModuleArgs) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""Unpack the positional and keyword arguments passed to a module."""
match module_args:
case tuple(positional_args):
keyword_args: dict[str, Any] = {}
case {"positional": positional_args, "keyword": keyword_args}:
pass
case _:
positional_args = ()
keyword_args = dict(**module_args)
return positional_args, keyword_args
def _is_torch_basic_layer(self, module_type: type[nn.Module]) -> bool:
"""Check if a module type is a subclass of a torch basic layer."""
return any(issubclass(module_type, torch_basic_layer) for torch_basic_layer in TORCH_BASIC_LAYERS)
def _infer_basic_layer_type(self, module: nn.Module) -> type[nn.Module] | None:
"""Infer the type of a basic layer."""
layer_types = (
set(self.custom_layer_mapping.keys()) | set(self.custom_layer_mapping.values()) | set(TORCH_BASIC_LAYERS)
)
for layer_type in layer_types:
if isinstance(module, layer_type):
return layer_type
return None
def get_module_signature(self, module: nn.Module) -> ModelTypeShape:
"""Get the signature of a module."""
layer_type = self._infer_basic_layer_type(module=module)
assert layer_type is not None, f"Module {module} is not a basic layer"
param_shapes = [p.shape for p in module.parameters()]
return (layer_type, tuple(param_shapes))
def _count_basic_layers(self, module: nn.Module) -> dict[type[nn.Module], int]:
"""Count the number of basic layers in a module."""
count: DefaultDict[type[nn.Module], int] = defaultdict(int)
for submodule in module.modules():
layer_type = self._infer_basic_layer_type(module=submodule)
if layer_type is not None:
count[layer_type] += 1
return count
def _verify_basic_layers_count(self) -> bool:
"""Verify that the source and target models have the same number of basic layers."""
source_layers = self._count_basic_layers(module=self.source_model)
target_layers = self._count_basic_layers(module=self.target_model)
reverse_mapping = {v: k for k, v in self.custom_layer_mapping.items()}
diff: dict[type[nn.Module], tuple[int, int]] = {}
for layer_type, source_count in source_layers.items():
target_type = self.custom_layer_mapping.get(layer_type, layer_type)
target_count = target_layers.get(target_type, 0)
if source_count != target_count:
diff[layer_type] = (source_count, target_count)
for layer_type, target_count in target_layers.items():
source_type = reverse_mapping.get(layer_type, layer_type)
source_count = source_layers.get(source_type, 0)
if source_count != target_count:
diff[layer_type] = (source_count, target_count)
if diff:
message = "Models do not have the same number of basic layers:\n"
for layer_type, counts in diff.items():
message += f" {layer_type}: Source {counts[0]} - Target {counts[1]}\n"
self._log(message=message.strip())
return False
return True
def _is_weighted_leaf_module(self, module: nn.Module) -> bool:
"""Check if a module is a leaf module with weights."""
return next(module.parameters(), None) is not None and next(module.children(), None) is None
def _check_for_missing_basic_layers(self, module: nn.Module) -> list[type[nn.Module]]:
"""Check if a module has weighted leaf modules that are not basic layers."""
return [
type(submodule)
for submodule in module.modules()
if self._is_weighted_leaf_module(module=submodule) and not self._infer_basic_layer_type(module=submodule)
]
def _verify_missing_basic_layers(self) -> bool:
"""Verify that the source and target models do not have missing basic layers."""
missing_source_layers = self._check_for_missing_basic_layers(module=self.source_model)
missing_target_layers = self._check_for_missing_basic_layers(module=self.target_model)
if missing_source_layers or missing_target_layers:
self._log(
message=(
"Models might have missing basic layers. If you want to skip this check, set"
f" `skip_init_check` to `True`: {missing_source_layers}, {missing_target_layers}"
)
)
return False
return True
@no_grad()
def _trace_module_execution_order(
self,
module: nn.Module,
args: ModuleArgs,
keys_to_skip: list[str],
) -> dict[ModelTypeShape, list[str]]:
"""
Execute a forward pass and store the order of execution of specific sub-modules.
- `module`: The module to trace.
        - `args`: The arguments to pass to the module; it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
- `keys_to_skip`: A list of keys to skip when tracing the module.
### Returns:
- A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`
"""
submodule_to_key: dict[nn.Module, str] = {}
execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
def collect_execution_order_hook(layer: nn.Module, *_: Any) -> None:
layer_signature = self.get_module_signature(module=layer)
execution_order[layer_signature].append(submodule_to_key[layer])
hooks: list[RemovableHandle] = []
named_modules: list[tuple[str, nn.Module]] = module.named_modules() # type: ignore
for name, submodule in named_modules:
if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
submodule_to_key[submodule] = name # type: ignore
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
hooks.append(hook)
positional_args, keyword_args = self._unpack_module_args(module_args=args)
module(*positional_args, **keyword_args)
for hook in hooks:
hook.remove()
return dict(execution_order)
def _assert_shapes_aligned(
self, source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
) -> bool:
"""Assert that the shapes of the sub-modules in `source_order` and `target_order` are aligned."""
model_type_shapes = set(source_order.keys()) | set(target_order.keys())
default_type_shapes = [
type_shape for type_shape in model_type_shapes if self._is_torch_basic_layer(module_type=type_shape[0])
]
shape_mismatched = False
for model_type_shape in default_type_shapes:
source_keys = source_order.get(model_type_shape, [])
target_keys = target_order.get(model_type_shape, [])
if len(source_keys) != len(target_keys):
shape_mismatched = True
self._debug_print_shapes(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
for source_custom_type in self.custom_layer_mapping.keys():
# iterate over all type_shapes that have the same type as source_custom_type
for source_type_shape in [
type_shape for type_shape in model_type_shapes if type_shape[0] == source_custom_type
]:
source_keys = source_order.get(source_type_shape, [])
target_custom_type = self.custom_layer_mapping[source_custom_type]
target_type_shape = (target_custom_type, source_type_shape[1])
target_keys = target_order.get(target_type_shape, [])
if len(source_keys) != len(target_keys):
shape_mismatched = True
self._debug_print_shapes(shape=source_type_shape, source_keys=source_keys, target_keys=target_keys)
return not shape_mismatched
@staticmethod
def _convert_state_dict(
source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
) -> dict[str, Tensor]:
"""Convert the source model's state_dict to match the target model's state_dict."""
converted_state_dict: dict[str, Tensor] = {}
for target_key in target_state_dict:
target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
source_prefix = state_dict_mapping[target_prefix]
source_key = ".".join([source_prefix, suffix])
converted_state_dict[target_key] = source_state_dict[source_key]
return converted_state_dict
@no_grad()
def _collect_layers_outputs(
self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
) -> list[tuple[str, Tensor]]:
"""
Execute a forward pass and store the output of specific sub-modules.
- `module`: The module to trace.
        - `args`: The arguments to pass to the module; it can be either a tuple of positional arguments,
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
- `keys_to_skip`: A list of keys to skip when tracing the module.
### Returns:
- A list of tuples containing the key of each sub-module and its output.
### Note:
- The output of each sub-module is cloned to avoid memory leaks.
"""
submodule_to_key: dict[nn.Module, str] = {}
execution_order: list[tuple[str, Tensor]] = []
def collect_execution_order_hook(layer: nn.Module, _: Any, output: Tensor) -> None:
execution_order.append((submodule_to_key[layer], output.clone()))
hooks: list[RemovableHandle] = []
named_modules: list[tuple[str, nn.Module]] = module.named_modules() # type: ignore
for name, submodule in named_modules:
if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
submodule_to_key[submodule] = name # type: ignore
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
hooks.append(hook)
positional_args, keyword_args = self._unpack_module_args(module_args=args)
module(*positional_args, **keyword_args)
for hook in hooks:
hook.remove()
return execution_order
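# Minimal end-to-end sketch on a toy pair of equivalent models (same basic layers, different
# nesting): the layer counts and shapes line up, the state_dict is remapped onto the target,
# and the traced outputs agree within the threshold. Illustrative only.
if __name__ == "__main__":
    import torch
    from torch import nn
    source = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    target = nn.Sequential(nn.Sequential(nn.Linear(4, 8), nn.ReLU()), nn.Linear(8, 2))
    converter = ModelConverter(source_model=source, target_model=target, verbose=False)
    x = torch.randn(1, 4)
    if converter.run(source_args=(x,)):
        print(converter.get_mapping())  # {'0.0': '0', '1': '2'}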

@ -0,0 +1,206 @@
from pathlib import Path
from typing import Any, Iterable, Literal, TypeVar
import torch
from jaxtyping import Float
from numpy import array, float32
from PIL import Image
from safetensors import safe_open as _safe_open # type: ignore
from safetensors.torch import save_file as _save_file # type: ignore
from torch import (
Tensor,
device as Device,
dtype as DType,
manual_seed as _manual_seed, # type: ignore
no_grad as _no_grad, # type: ignore
norm as _norm, # type: ignore
)
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore
T = TypeVar("T")
E = TypeVar("E")
def norm(x: Tensor) -> Tensor:
return _norm(x) # type: ignore
def manual_seed(seed: int) -> None:
_manual_seed(seed)
class no_grad(_no_grad):
def __new__(cls, orig_func: Any | None = None) -> "no_grad": # type: ignore
return object.__new__(cls)
def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant") -> Tensor:
return _pad(input=x, pad=pad, value=value, mode=mode) # type: ignore
def interpolate(x: Tensor, factor: float | torch.Size, mode: str = "nearest") -> Tensor:
return (
_interpolate(x, scale_factor=factor, mode=mode)
if isinstance(factor, float | int)
else _interpolate(x, size=factor, mode=mode)
) # type: ignore
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
def normalize(
tensor: Float[Tensor, "*batch channels height width"], mean: list[float], std: list[float]
) -> Float[Tensor, "*batch channels height width"]:
assert tensor.is_floating_point()
assert tensor.ndim >= 3
dtype = tensor.dtype
pixel_mean = torch.tensor(mean, dtype=dtype, device=tensor.device).view(-1, 1, 1)
pixel_std = torch.tensor(std, dtype=dtype, device=tensor.device).view(-1, 1, 1)
if (pixel_std == 0).any():
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
return (tensor - pixel_mean) / pixel_std
# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
def gaussian_blur(
tensor: Float[Tensor, "*batch channels height width"],
kernel_size: int | tuple[int, int],
sigma: float | tuple[float, float] | None = None,
) -> Float[Tensor, "*batch channels height width"]:
assert torch.is_floating_point(tensor)
def get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Float[Tensor, "kernel_size"]:
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
return kernel1d
def get_gaussian_kernel2d(
kernel_size_x: int, kernel_size_y: int, sigma_x: float, sigma_y: float, dtype: DType, device: Device
) -> Float[Tensor, "kernel_size_y kernel_size_x"]:
kernel1d_x = get_gaussian_kernel1d(kernel_size_x, sigma_x).to(device, dtype=dtype)
kernel1d_y = get_gaussian_kernel1d(kernel_size_y, sigma_y).to(device, dtype=dtype)
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
return kernel2d
def default_sigma(kernel_size: int) -> float:
return kernel_size * 0.15 + 0.35
if isinstance(kernel_size, int):
kx, ky = kernel_size, kernel_size
else:
kx, ky = kernel_size
if sigma is None:
sx, sy = default_sigma(kx), default_sigma(ky)
elif isinstance(sigma, float):
sx, sy = sigma, sigma
else:
assert isinstance(sigma, tuple)
sx, sy = sigma
channels = tensor.shape[-3]
kernel = get_gaussian_kernel2d(kx, ky, sx, sy, dtype=tensor.dtype, device=tensor.device)
kernel = kernel.expand(channels, 1, kernel.shape[0], kernel.shape[1])
# pad = (left, right, top, bottom)
tensor = pad(tensor, pad=(kx // 2, kx // 2, ky // 2, ky // 2), mode="reflect")
tensor = conv2d(tensor, weight=kernel, groups=channels)
return tensor
def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
"""
Convert a PIL Image to a Tensor.
If the image is in mode `RGB` the tensor will have shape `[3, H, W]`, otherwise
`[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`.
    Values are scaled from `[0, 255]` to the range `[0, 1]`.
"""
image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)
match image.mode:
case "L":
image_tensor = image_tensor.unsqueeze(0)
case "RGBA" | "RGB":
image_tensor = image_tensor.permute(2, 0, 1)
case _:
raise ValueError(f"Unsupported image mode: {image.mode}")
return image_tensor.unsqueeze(0)
def tensor_to_image(tensor: Tensor) -> Image.Image:
"""
Convert a Tensor to a PIL Image.
The tensor must have shape `[1, channels, height, width]` where the number of
channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA).
Expected values are in the range `[0, 1]` and are clamped to this range.
"""
assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}"
num_channels = tensor.shape[1]
tensor = tensor.clamp(0, 1).squeeze(0)
match num_channels:
case 1:
tensor = tensor.squeeze(0)
case 3 | 4:
tensor = tensor.permute(1, 2, 0)
case _:
raise ValueError(f"Unsupported number of channels: {num_channels}")
return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8")) # type: ignore[reportUnknownType]
def safe_open(
path: Path | str,
framework: Literal["pytorch", "tensorflow", "flax", "numpy"],
device: Device | str = "cpu",
) -> dict[str, Tensor]:
framework_mapping = {
"pytorch": "pt",
"tensorflow": "tf",
"flax": "flax",
"numpy": "numpy",
}
return _safe_open(str(path), framework=framework_mapping[framework], device=str(device)) # type: ignore
def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dict[str, Tensor]:
with safe_open(path=path, framework="pytorch", device=device) as tensors: # type: ignore
return {key: tensors.get_tensor(key) for key in tensors.keys()} # type: ignore
def load_metadata_from_safetensors(path: Path | str) -> dict[str, str] | None:
with safe_open(path=path, framework="pytorch") as tensors: # type: ignore
return tensors.metadata() # type: ignore
def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
_save_file(tensors, path, metadata) # type: ignore
def summarize_tensor(tensor: torch.Tensor, /) -> str:
return (
"Tensor("
+ ", ".join(
[
f"shape=({', '.join(map(str, tensor.shape))})",
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
f"device={tensor.device}",
f"min={tensor.min():.2f}", # type: ignore
f"max={tensor.max():.2f}", # type: ignore
f"mean={tensor.mean():.2f}",
f"std={tensor.std():.2f}",
f"norm={norm(x=tensor):.2f}",
f"grad={tensor.requires_grad}",
]
)
+ ")"
)
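# Minimal round-trip sketch for the image helpers above: an RGB PIL image becomes a
# [1, 3, H, W] tensor scaled to [0, 1], and tensor_to_image converts it back.
if __name__ == "__main__":
    im = Image.new("RGB", (8, 8), color=(255, 0, 0))
    t = image_to_tensor(im)
    print(t.shape, float(t.max()))  # torch.Size([1, 3, 8, 8]) 1.0
    print(tensor_to_image(t).getpixel((0, 0)))  # (255, 0, 0)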

@ -0,0 +1,48 @@
from torch import Tensor, arange, device as Device, dtype as DType
import imaginairy.vendored.refiners.fluxion.layers as fl
class PositionalEncoder(fl.Chain):
def __init__(
self,
max_sequence_length: int,
embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.max_sequence_length = max_sequence_length
self.embedding_dim = embedding_dim
super().__init__(
fl.Lambda(func=self.get_position_ids),
fl.Embedding(
num_embeddings=max_sequence_length,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
)
@property
def position_ids(self) -> Tensor:
return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)
def get_position_ids(self, x: Tensor) -> Tensor:
return self.position_ids[:, : x.shape[1]]
class FeedForward(fl.Chain):
def __init__(
self,
embedding_dim: int,
feedforward_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.feedforward_dim = feedforward_dim
super().__init__(
fl.Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
fl.GeLU(),
fl.Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
)
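# Minimal shape sketch: PositionalEncoder only looks at the sequence length of its input and
# returns one learned position embedding per position; in the text encoder this (1, seq, dim)
# tensor is broadcast against the token embeddings inside fl.Sum. Assumes, as elsewhere in this
# package, that the chain reports the device of its Embedding child. Toy sizes, illustrative only.
if __name__ == "__main__":
    import torch
    pe = PositionalEncoder(max_sequence_length=77, embedding_dim=16)
    token_ids = torch.zeros(2, 5, dtype=torch.long)  # only the length (5) matters here
    print(pe(token_ids).shape)  # torch.Size([1, 5, 16])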

@ -0,0 +1,148 @@
import re
from typing import cast
import torch.nn.functional as F
from torch import Tensor, cat, zeros
from torch.nn import Parameter
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
from imaginairy.vendored.refiners.foundationals.clip.tokenizer import CLIPTokenizer
class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
old_weight: Parameter
new_weight: Parameter
def __init__(
self,
target: TokenEncoder,
) -> None:
with self.setup_adapter(target):
super().__init__(fl.Lambda(func=self.lookup))
self.old_weight = cast(Parameter, target.weight)
self.new_weight = Parameter(
zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
) # requires_grad=True by default
# Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
def lookup(self, x: Tensor) -> Tensor:
# Concatenate old and new weights for dynamic embedding updates during training
return F.embedding(x, cat([self.old_weight, self.new_weight]))
def add_embedding(self, embedding: Tensor) -> None:
assert embedding.shape == (self.old_weight.shape[1],)
self.new_weight = Parameter(
cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)])
)
@property
def num_embeddings(self) -> int:
return self.old_weight.shape[0] + self.new_weight.shape[0]
class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
def __init__(self, target: CLIPTokenizer) -> None:
with self.setup_adapter(target):
super().__init__(
CLIPTokenizer(
vocabulary_path=target.vocabulary_path,
sequence_length=target.sequence_length,
start_of_text_token_id=target.start_of_text_token_id,
end_of_text_token_id=target.end_of_text_token_id,
pad_token_id=target.pad_token_id,
)
)
def add_token(self, token: str, token_id: int) -> None:
token = token.lower()
tokenizer = self.ensure_find(CLIPTokenizer)
assert token_id not in tokenizer.token_to_id_mapping.values()
tokenizer.token_to_id_mapping[token] = token_id
current_pattern = tokenizer.token_pattern.pattern
new_pattern = re.escape(token) + "|" + current_pattern
tokenizer.token_pattern = re.compile(new_pattern, re.IGNORECASE)
# Define the keyword as its own smallest subtoken
tokenizer.byte_pair_encoding_cache[token] = token
class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
"""
Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.
Example:
import torch
from imaginairy.vendored.refiners.foundationals.clip.concepts import ConceptExtender
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors
encoder = CLIPTextEncoderL(device="cuda")
tensors = load_from_safetensors("CLIPTextEncoderL.safetensors")
encoder.load_state_dict(tensors)
cat_embedding = torch.load("cat_embedding.bin")["<this-cat>"]
dog_embedding = torch.load("dog_embedding.bin")["<that-dog>"]
extender = ConceptExtender(encoder)
extender.add_concept(token="<this-cat>", embedding=cat_embedding)
extender.inject()
# New concepts can be added at any time
extender.add_concept(token="<that-dog>", embedding=dog_embedding)
# Now the encoder can be used with the new concepts
"""
def __init__(self, target: CLIPTextEncoder) -> None:
with self.setup_adapter(target):
super().__init__(target)
try:
token_encoder, token_encoder_parent = next(target.walk(TokenEncoder))
self._token_encoder_parent = [token_encoder_parent]
except StopIteration:
raise RuntimeError("TokenEncoder not found.")
try:
clip_tokenizer, clip_tokenizer_parent = next(target.walk(CLIPTokenizer))
self._clip_tokenizer_parent = [clip_tokenizer_parent]
except StopIteration:
raise RuntimeError("Tokenizer not found.")
self._embedding_extender = [EmbeddingExtender(token_encoder)]
self._token_extender = [TokenExtender(clip_tokenizer)]
@property
def embedding_extender(self) -> EmbeddingExtender:
assert len(self._embedding_extender) == 1, "EmbeddingExtender not found."
return self._embedding_extender[0]
@property
def token_extender(self) -> TokenExtender:
assert len(self._token_extender) == 1, "TokenExtender not found."
return self._token_extender[0]
@property
def token_encoder_parent(self) -> fl.Chain:
assert len(self._token_encoder_parent) == 1, "TokenEncoder parent not found."
return self._token_encoder_parent[0]
@property
def clip_tokenizer_parent(self) -> fl.Chain:
assert len(self._clip_tokenizer_parent) == 1, "Tokenizer parent not found."
return self._clip_tokenizer_parent[0]
def add_concept(self, token: str, embedding: Tensor) -> None:
self.embedding_extender.add_embedding(embedding)
self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)
def inject(self: "ConceptExtender", parent: fl.Chain | None = None) -> "ConceptExtender":
self.embedding_extender.inject(self.token_encoder_parent)
self.token_extender.inject(self.clip_tokenizer_parent)
return super().inject(parent)
def eject(self) -> None:
self.embedding_extender.eject()
self.token_extender.eject()
super().eject()

@ -0,0 +1,179 @@
from typing import Callable
from torch import Tensor, device as Device, dtype as DType
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.foundationals.clip.common import FeedForward, PositionalEncoder
class ClassToken(fl.Chain):
def __init__(self, embedding_dim: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.embedding_dim = embedding_dim
super().__init__(fl.Parameter(1, embedding_dim, device=device, dtype=dtype))
class PatchEncoder(fl.Chain):
def __init__(
self,
in_channels: int,
out_channels: int,
patch_size: int = 16,
use_bias: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_size = patch_size
self.use_bias = use_bias
super().__init__(
fl.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=(self.patch_size, self.patch_size),
stride=(self.patch_size, self.patch_size),
use_bias=self.use_bias,
device=device,
dtype=dtype,
),
fl.Permute(0, 2, 3, 1),
)
class TransformerLayer(fl.Chain):
def __init__(
self,
embedding_dim: int = 768,
feedforward_dim: int = 3072,
num_attention_heads: int = 12,
layer_norm_eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.feedforward_dim = feedforward_dim
self.num_attention_heads = num_attention_heads
self.layer_norm_eps = layer_norm_eps
super().__init__(
fl.Residual(
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
fl.SelfAttention(
embedding_dim=embedding_dim, num_heads=num_attention_heads, device=device, dtype=dtype
),
),
fl.Residual(
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
FeedForward(embedding_dim=embedding_dim, feedforward_dim=feedforward_dim, device=device, dtype=dtype),
),
)
class ViTEmbeddings(fl.Chain):
def __init__(
self,
image_size: int = 224,
embedding_dim: int = 768,
patch_size: int = 32,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.image_size = image_size
self.embedding_dim = embedding_dim
self.patch_size = patch_size
super().__init__(
fl.Concatenate(
ClassToken(embedding_dim, device=device, dtype=dtype),
fl.Chain(
PatchEncoder(
in_channels=3,
out_channels=embedding_dim,
patch_size=patch_size,
use_bias=False,
device=device,
dtype=dtype,
),
fl.Reshape((image_size // patch_size) ** 2, embedding_dim),
),
dim=1,
),
fl.Residual(
PositionalEncoder(
max_sequence_length=(image_size // patch_size) ** 2 + 1,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
),
)
class CLIPImageEncoder(fl.Chain):
def __init__(
self,
image_size: int = 224,
embedding_dim: int = 768,
output_dim: int = 512,
patch_size: int = 32,
num_layers: int = 12,
num_attention_heads: int = 12,
feedforward_dim: int = 3072,
layer_norm_eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.image_size = image_size
self.embedding_dim = embedding_dim
self.output_dim = output_dim
self.patch_size = patch_size
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.feedforward_dim = feedforward_dim
cls_token_pooling: Callable[[Tensor], Tensor] = lambda x: x[:, 0, :]
super().__init__(
ViTEmbeddings(
image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype
),
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
fl.Chain(
TransformerLayer(
embedding_dim=embedding_dim,
feedforward_dim=feedforward_dim,
num_attention_heads=num_attention_heads,
layer_norm_eps=layer_norm_eps,
device=device,
dtype=dtype,
)
for _ in range(num_layers)
),
fl.Lambda(func=cls_token_pooling),
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype),
)
class CLIPImageEncoderH(CLIPImageEncoder):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
embedding_dim=1280,
output_dim=1024,
patch_size=14,
num_layers=32,
num_attention_heads=16,
feedforward_dim=5120,
device=device,
dtype=dtype,
)
class CLIPImageEncoderG(CLIPImageEncoder):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
embedding_dim=1664,
output_dim=1280,
patch_size=14,
num_layers=48,
num_attention_heads=16,
feedforward_dim=8192,
device=device,
dtype=dtype,
)
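# Minimal shape sketch: with the default ViT-B/32-style hyperparameters, a randomly initialized
# CLIPImageEncoder maps a 224x224 RGB batch to one 512-d embedding per image (untrained weights,
# so only the shapes are meaningful here).
if __name__ == "__main__":
    import torch
    encoder = CLIPImageEncoder()
    print(encoder(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 512])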

@ -0,0 +1,195 @@
from torch import device as Device, dtype as DType
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.foundationals.clip.common import FeedForward, PositionalEncoder
from imaginairy.vendored.refiners.foundationals.clip.tokenizer import CLIPTokenizer
class TokenEncoder(fl.Embedding):
def __init__(
self,
vocabulary_size: int,
embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.vocabulary_size = vocabulary_size
self.embedding_dim = embedding_dim
super().__init__(
num_embeddings=vocabulary_size,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
)
class TransformerLayer(fl.Chain):
def __init__(
self,
embedding_dim: int,
feedforward_dim: int,
num_attention_heads: int = 1,
layer_norm_eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_attention_heads = num_attention_heads
self.feedforward_dim = feedforward_dim
self.layer_norm_eps = layer_norm_eps
super().__init__(
fl.Residual(
fl.LayerNorm(
normalized_shape=embedding_dim,
eps=layer_norm_eps,
device=device,
dtype=dtype,
),
fl.SelfAttention(
embedding_dim=embedding_dim,
num_heads=num_attention_heads,
is_causal=True,
device=device,
dtype=dtype,
),
),
fl.Residual(
fl.LayerNorm(
normalized_shape=embedding_dim,
eps=layer_norm_eps,
device=device,
dtype=dtype,
),
FeedForward(
embedding_dim=embedding_dim,
feedforward_dim=feedforward_dim,
device=device,
dtype=dtype,
),
),
)
class CLIPTextEncoder(fl.Chain):
def __init__(
self,
embedding_dim: int = 768,
max_sequence_length: int = 77,
vocabulary_size: int = 49408,
num_layers: int = 12,
num_attention_heads: int = 12,
feedforward_dim: int = 3072,
layer_norm_eps: float = 1e-5,
use_quick_gelu: bool = False,
tokenizer: CLIPTokenizer | None = None,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.max_sequence_length = max_sequence_length
self.vocabulary_size = vocabulary_size
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.feedforward_dim = feedforward_dim
self.layer_norm_eps = layer_norm_eps
self.use_quick_gelu = use_quick_gelu
super().__init__(
tokenizer or CLIPTokenizer(sequence_length=max_sequence_length),
fl.Converter(set_dtype=False),
fl.Sum(
TokenEncoder(
vocabulary_size=vocabulary_size,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
PositionalEncoder(
max_sequence_length=max_sequence_length,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
),
*(
TransformerLayer(
embedding_dim=embedding_dim,
num_attention_heads=num_attention_heads,
feedforward_dim=feedforward_dim,
layer_norm_eps=layer_norm_eps,
device=device,
dtype=dtype,
)
for _ in range(num_layers)
),
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
)
if use_quick_gelu:
for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
class CLIPTextEncoderL(CLIPTextEncoder):
"""
CLIPTextEncoderL is the CLIP text encoder with the following parameters:
embedding_dim=768
num_layers=12
num_attention_heads=12
feedforward_dim=3072
use_quick_gelu=True
We replace the GeLU activation function with an approximate GeLU to comply with the original CLIP implementation
of OpenAI (https://github.com/openai/CLIP/blob/main/clip/model.py#L166)
"""
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
embedding_dim=768,
num_layers=12,
num_attention_heads=12,
feedforward_dim=3072,
use_quick_gelu=True,
device=device,
dtype=dtype,
)
class CLIPTextEncoderH(CLIPTextEncoder):
"""
CLIPTextEncoderH is the CLIP text encoder with the following parameters:
embedding_dim=1024
num_layers=23
num_attention_heads=16
feedforward_dim=4096
"""
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
embedding_dim=1024,
num_layers=23,
num_attention_heads=16,
feedforward_dim=4096,
device=device,
dtype=dtype,
)
class CLIPTextEncoderG(CLIPTextEncoder):
"""
CLIPTextEncoderG is the CLIP text encoder with the following parameters:
embedding_dim=1280
num_layers=32
num_attention_heads=20
feedforward_dim=5120
"""
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
tokenizer = CLIPTokenizer(pad_token_id=0)
super().__init__(
embedding_dim=1280,
num_layers=32,
num_attention_heads=20,
feedforward_dim=5120,
tokenizer=tokenizer,
device=device,
dtype=dtype,
)
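# A minimal usage sketch (illustrative only, not upstream refiners code): since
# the chain starts with a CLIPTokenizer, these encoders accept a raw prompt
# string and return per-token embeddings padded to max_sequence_length:
#
#   import torch
#   encoder = CLIPTextEncoderL(device="cpu", dtype=torch.float32)
#   embedding = encoder("a photo of a cat")   # shape (1, 77, 768)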

@ -0,0 +1,121 @@
import gzip
import re
from functools import lru_cache
from itertools import islice
from pathlib import Path
from torch import Tensor, tensor
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion import pad
class CLIPTokenizer(fl.Module):
def __init__(
self,
vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
sequence_length: int = 77,
start_of_text_token_id: int = 49406,
end_of_text_token_id: int = 49407,
pad_token_id: int = 49407,
) -> None:
super().__init__()
self.vocabulary_path = vocabulary_path
self.sequence_length = sequence_length
self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
merge_tuples = [
tuple(merge.split())
for merge in gzip.open(filename=vocabulary_path)
.read()
.decode(encoding="utf-8")
.split(sep="\n")[1 : 49152 - 256 - 2 + 1]
]
vocabulary = (
list(self.byte_to_unicode_mapping.values())
+ [v + "</w>" for v in self.byte_to_unicode_mapping.values()]
+ ["".join(merge) for merge in merge_tuples]
+ ["", ""]
)
self.token_to_id_mapping = {token: i for i, token in enumerate(iterable=vocabulary)}
self.byte_pair_encoding_ranks = {merge: i for i, merge in enumerate(iterable=merge_tuples)}
self.byte_pair_encoding_cache = {"<|endoftext|>": "<|endoftext|>"}
# Note: this regular expression does not support Unicode. It was changed so as
# to get rid of the dependence on the `regex` module. Unicode support could
# potentially be added back by leveraging the `\w` character class.
self.token_pattern = re.compile(
pattern=r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
flags=re.IGNORECASE,
)
self.start_of_text_token_id: int = start_of_text_token_id
self.end_of_text_token_id: int = end_of_text_token_id
self.pad_token_id: int = pad_token_id
def forward(self, text: str) -> Tensor:
tokens = self.encode(text=text, max_length=self.sequence_length).unsqueeze(dim=0)
assert (
tokens.shape[1] <= self.sequence_length
), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
return pad(x=tokens, pad=(0, self.sequence_length - tokens.shape[1]), value=self.pad_token_id)
@lru_cache()
def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
initial_byte_values = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
extra_unicode_values = (byte for byte in range(2**8) if byte not in initial_byte_values)
byte_values = initial_byte_values + list(extra_unicode_values)
unicode_values = [chr(value) for value in byte_values]
return dict(zip(byte_values, unicode_values))
def byte_pair_encoding(self, token: str) -> str:
if token in self.byte_pair_encoding_cache:
return self.byte_pair_encoding_cache[token]
def recursive_bpe(word: tuple[str, ...]) -> tuple[str, ...]:
if len(word) < 2:
return word
pairs = {(i, (word[i], word[i + 1])) for i in range(len(word) - 1)}
min_pair = min(
pairs,
key=lambda pair: self.byte_pair_encoding_ranks.get(pair[1], float("inf")),
)
if min_pair[1] not in self.byte_pair_encoding_ranks:
return word
new_word: list[str] = []
i = 0
while i < len(word):
if i == min_pair[0]:
new_word.append(min_pair[1][0] + min_pair[1][1])
i += 2
else:
new_word.append(word[i])
i += 1
return recursive_bpe(tuple(new_word))
word = tuple(token[:-1]) + (token[-1] + "</w>",)
result = " ".join(recursive_bpe(word=word))
self.byte_pair_encoding_cache[token] = result
return result
def encode(self, text: str, max_length: int | None = None) -> Tensor:
text = re.sub(pattern=r"\s+", repl=" ", string=text.lower())
tokens = re.findall(pattern=self.token_pattern, string=text)
upper_bound = None
if max_length:
assert max_length >= 2
upper_bound = max_length - 2
encoded_tokens = islice(
(
self.token_to_id_mapping[subtoken]
for token in tokens
for subtoken in self.byte_pair_encoding(
token="".join(self.byte_to_unicode_mapping[character] for character in token.encode("utf-8"))
).split(sep=" ")
),
0,
upper_bound,
)
return tensor(data=[self.start_of_text_token_id, *encoded_tokens, self.end_of_text_token_id])
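# A minimal usage sketch (illustrative only, not upstream refiners code): calling
# the module pads the ids to sequence_length with pad_token_id, while encode()
# returns only the start-of-text id, the BPE ids and the end-of-text id:
#
#   tokenizer = CLIPTokenizer()
#   tokenizer.encode("a cat")     # tensor([49406, ..., 49407]), unpadded
#   tokenizer("a cat").shape      # torch.Size([1, 77])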

@ -0,0 +1,29 @@
from .dinov2 import (
DINOv2_base,
DINOv2_base_reg,
DINOv2_large,
DINOv2_large_reg,
DINOv2_small,
DINOv2_small_reg,
)
from .vit import (
ViT,
ViT_base,
ViT_large,
ViT_small,
ViT_tiny,
)
__all__ = [
"DINOv2_base",
"DINOv2_base_reg",
"DINOv2_large",
"DINOv2_large_reg",
"DINOv2_small",
"DINOv2_small_reg",
"ViT",
"ViT_base",
"ViT_large",
"ViT_small",
"ViT_tiny",
]

@ -0,0 +1,148 @@
import torch
from imaginairy.vendored.refiners.foundationals.dinov2.vit import ViT
# TODO: add preprocessing logic like
# https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/data/transforms.py#L77
class DINOv2_small(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=384,
patch_size=14,
image_size=518,
num_layers=12,
num_heads=6,
device=device,
dtype=dtype,
)
class DINOv2_base(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=768,
patch_size=14,
image_size=518,
num_layers=12,
num_heads=12,
device=device,
dtype=dtype,
)
class DINOv2_large(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=1024,
patch_size=14,
image_size=518,
num_layers=24,
num_heads=16,
device=device,
dtype=dtype,
)
# TODO: implement SwiGLU layer
# class DINOv2_giant2(ViT):
# def __init__(
# self,
# device: torch.device | str | None = None,
# dtype: torch.dtype | None = None,
# ) -> None:
# super().__init__(
# embedding_dim=1536,
# patch_size=14,
# image_size=518,
# num_layers=40,
# num_heads=24,
# device=device,
# dtype=dtype,
# )
class DINOv2_small_reg(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=384,
patch_size=14,
image_size=518,
num_layers=12,
num_heads=6,
num_registers=4,
device=device,
dtype=dtype,
)
class DINOv2_base_reg(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=768,
patch_size=14,
image_size=518,
num_layers=12,
num_heads=12,
num_registers=4,
device=device,
dtype=dtype,
)
class DINOv2_large_reg(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=1024,
patch_size=14,
image_size=518,
num_layers=24,
num_heads=16,
num_registers=4,
device=device,
dtype=dtype,
)
# TODO: implement SwiGLU layer
# class DINOv2_giant2_reg(ViT):
# def __init__(
# self,
# device: torch.device | str | None = None,
# dtype: torch.dtype | None = None,
# ) -> None:
# super().__init__(
# embedding_dim=1536,
# patch_size=14,
# image_size=518,
# num_layers=40,
# num_heads=24,
# num_registers=4,
# device=device,
# dtype=dtype,
# )
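# A minimal usage sketch (illustrative only, not upstream refiners code). Inputs
# are expected to already be resized to 518x518 and normalized (see the TODO on
# preprocessing above); the model returns the full token sequence, CLS token
# first, then one token per 14x14 patch:
#
#   import torch
#   model = DINOv2_small(device="cpu")
#   tokens = model(torch.randn(1, 3, 518, 518))   # shape (1, 1 + 37 * 37, 384)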

@ -0,0 +1,373 @@
import torch
from torch import Tensor
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.layers.activations import Activation
class ClassToken(fl.Chain):
"""Learnable token representing the class of the input."""
def __init__(
self,
embedding_dim: int,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
self.embedding_dim = embedding_dim
super().__init__(
fl.Parameter(
*(1, embedding_dim),
device=device,
dtype=dtype,
),
)
class PositionalEncoder(fl.Residual):
"""Encode the position of each patch in the input."""
def __init__(
self,
sequence_length: int,
embedding_dim: int,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
self.num_patches = sequence_length
self.embedding_dim = embedding_dim
super().__init__(
fl.Parameter(
*(sequence_length, embedding_dim),
device=device,
dtype=dtype,
),
)
class LayerScale(fl.WeightedModule):
"""Scale the input tensor by a learnable parameter."""
def __init__(
self,
embedding_dim: int,
init_value: float = 1.0,
dtype: torch.dtype | None = None,
device: torch.device | str | None = None,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.register_parameter(
name="weight",
param=torch.nn.Parameter(
torch.full(
size=(embedding_dim,),
fill_value=init_value,
dtype=dtype,
device=device,
),
),
)
def forward(self, x: Tensor) -> Tensor:
return x * self.weight
class FeedForward(fl.Chain):
"""Apply two linear transformations interleaved by an activation function."""
def __init__(
self,
embedding_dim: int,
feedforward_dim: int,
activation: Activation = fl.GeLU, # type: ignore
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.feedforward_dim = feedforward_dim
super().__init__(
fl.Linear(
in_features=embedding_dim,
out_features=feedforward_dim,
device=device,
dtype=dtype,
),
activation(),
fl.Linear(
in_features=feedforward_dim,
out_features=embedding_dim,
device=device,
dtype=dtype,
),
)
class PatchEncoder(fl.Chain):
"""Encode an image into a sequence of patches."""
def __init__(
self,
in_channels: int,
out_channels: int,
patch_size: int,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_size = patch_size
super().__init__(
fl.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=patch_size,
stride=patch_size,
device=device,
dtype=dtype,
), # (N,3,H,W) -> (N,D,P,P)
fl.Reshape(out_channels, -1), # (N,D,P,P) -> (N,D,P²)
fl.Transpose(1, 2), # (N,D,P²) -> (N,P²,D)
)
class TransformerLayer(fl.Chain):
"""Apply a multi-head self-attention mechanism to the input tensor."""
def __init__(
self,
embedding_dim: int,
num_heads: int,
norm_eps: float,
mlp_ratio: int,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.norm_eps = norm_eps
self.mlp_ratio = mlp_ratio
super().__init__(
fl.Residual(
fl.LayerNorm(
normalized_shape=embedding_dim,
eps=norm_eps,
device=device,
dtype=dtype,
),
fl.SelfAttention(
embedding_dim=embedding_dim,
num_heads=num_heads,
device=device,
dtype=dtype,
),
LayerScale(
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
),
fl.Residual(
fl.LayerNorm(
normalized_shape=embedding_dim,
eps=norm_eps,
device=device,
dtype=dtype,
),
FeedForward(
embedding_dim=embedding_dim,
feedforward_dim=embedding_dim * mlp_ratio,
device=device,
dtype=dtype,
),
LayerScale(
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
),
)
class Transformer(fl.Chain):
"""Alias for a Chain of TransformerLayer."""
class Registers(fl.Concatenate):
"""Insert register tokens between CLS token and patches."""
def __init__(
self,
num_registers: int,
embedding_dim: int,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
self.num_registers = num_registers
self.embedding_dim = embedding_dim
super().__init__(
fl.Slicing(dim=1, end=1),
fl.Parameter(
*(num_registers, embedding_dim),
device=device,
dtype=dtype,
),
fl.Slicing(dim=1, start=1),
dim=1,
)
class ViT(fl.Chain):
"""Vision Transformer (ViT).
see https://arxiv.org/abs/2010.11929v2
"""
def __init__(
self,
embedding_dim: int = 768,
patch_size: int = 16,
image_size: int = 224,
num_layers: int = 12,
num_heads: int = 12,
norm_eps: float = 1e-6,
mlp_ratio: int = 4,
num_registers: int = 0,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
num_patches = image_size // patch_size
self.embedding_dim = embedding_dim
self.patch_size = patch_size
self.image_size = image_size
self.num_layers = num_layers
self.num_heads = num_heads
self.norm_eps = norm_eps
self.mlp_ratio = mlp_ratio
self.num_registers = num_registers
super().__init__(
fl.Concatenate(
ClassToken(
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
PatchEncoder(
in_channels=3,
out_channels=embedding_dim,
patch_size=patch_size,
device=device,
dtype=dtype,
),
dim=1,
),
# TODO: support https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
PositionalEncoder(
sequence_length=num_patches**2 + 1,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
Transformer(
TransformerLayer(
embedding_dim=embedding_dim,
num_heads=num_heads,
norm_eps=norm_eps,
mlp_ratio=mlp_ratio,
device=device,
dtype=dtype,
)
for _ in range(num_layers)
),
fl.LayerNorm(
normalized_shape=embedding_dim,
eps=norm_eps,
device=device,
dtype=dtype,
),
)
if self.num_registers > 0:
registers = Registers(
num_registers=num_registers,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
)
self.insert_before_type(Transformer, registers)
class ViT_tiny(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=192,
patch_size=16,
image_size=224,
num_layers=12,
num_heads=3,
device=device,
dtype=dtype,
)
class ViT_small(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=384,
patch_size=16,
image_size=224,
num_layers=12,
num_heads=6,
device=device,
dtype=dtype,
)
class ViT_base(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=768,
patch_size=16,
image_size=224,
num_layers=12,
num_heads=12,
device=device,
dtype=dtype,
)
class ViT_large(ViT):
def __init__(
self,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
embedding_dim=1024,
patch_size=16,
image_size=224,
num_layers=24,
num_heads=16,
device=device,
dtype=dtype,
)

@ -0,0 +1,40 @@
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
CLIPTextEncoderL,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.auto_encoder import (
LatentDiffusionAutoencoder,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers import DPMSolver, Scheduler
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
SD1ControlnetAdapter,
SD1IPAdapter,
SD1T2IAdapter,
SD1UNet,
StableDiffusion_1,
StableDiffusion_1_Inpainting,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
DoubleTextEncoder,
SDXLIPAdapter,
SDXLT2IAdapter,
SDXLUNet,
)
__all__ = [
"StableDiffusion_1",
"StableDiffusion_1_Inpainting",
"SD1UNet",
"SD1ControlnetAdapter",
"SD1IPAdapter",
"SD1T2IAdapter",
"SDXLUNet",
"DoubleTextEncoder",
"SDXLIPAdapter",
"SDXLT2IAdapter",
"DPMSolver",
"Scheduler",
"CLIPTextEncoderL",
"LatentDiffusionAutoencoder",
"SDFreeUAdapter",
]

@ -0,0 +1,221 @@
from PIL import Image
from torch import Tensor, device as Device, dtype as DType
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.layers import (
Chain,
Conv2d,
Downsample,
GroupNorm,
Identity,
Residual,
SelfAttention2d,
SiLU,
Slicing,
Sum,
Upsample,
)
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, tensor_to_image
class Resnet(Sum):
def __init__(
self,
in_channels: int,
out_channels: int,
num_groups: int = 32,
device: Device | str | None = None,
dtype: DType | None = None,
):
self.in_channels = in_channels
self.out_channels = out_channels
shortcut = (
Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
if in_channels != out_channels
else Identity()
)
super().__init__(
shortcut,
Chain(
GroupNorm(channels=in_channels, num_groups=num_groups, device=device, dtype=dtype),
SiLU(),
Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
GroupNorm(channels=out_channels, num_groups=num_groups, device=device, dtype=dtype),
SiLU(),
Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
),
)
class Encoder(Chain):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
resnet_sizes: list[int] = [128, 256, 512, 512, 512]
input_channels: int = 3
latent_dim: int = 8
resnet_layers: list[Chain] = [
Chain(
[
Resnet(
in_channels=resnet_sizes[i - 1] if i > 0 else resnet_sizes[0],
out_channels=resnet_sizes[i],
device=device,
dtype=dtype,
),
Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
]
)
for i in range(len(resnet_sizes))
]
for _, layer in zip(range(3), resnet_layers):
channels: int = layer[-1].out_channels # type: ignore
layer.append(Downsample(channels=channels, scale_factor=2, device=device, dtype=dtype))
attention_layer = Residual(
GroupNorm(channels=resnet_sizes[-1], num_groups=32, eps=1e-6, device=device, dtype=dtype),
SelfAttention2d(channels=resnet_sizes[-1], device=device, dtype=dtype),
)
resnet_layers[-1].insert_after_type(Resnet, attention_layer)
super().__init__(
Conv2d(
in_channels=input_channels,
out_channels=resnet_sizes[0],
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
Chain(*resnet_layers),
Chain(
GroupNorm(channels=resnet_sizes[-1], num_groups=32, eps=1e-6, device=device, dtype=dtype),
SiLU(),
Conv2d(
in_channels=resnet_sizes[-1],
out_channels=latent_dim,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
),
Chain(
Conv2d(in_channels=8, out_channels=8, kernel_size=1, device=device, dtype=dtype),
Slicing(dim=1, end=4),
),
)
def init_context(self) -> Contexts:
return {"sampling": {"shapes": []}}
class Decoder(Chain):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.resnet_sizes: list[int] = [128, 256, 512, 512, 512]
self.latent_dim: int = 4
self.output_channels: int = 3
resnet_sizes = self.resnet_sizes[::-1]
resnet_layers: list[Chain] = [
(
Chain(
[
Resnet(
in_channels=resnet_sizes[i - 1] if i > 0 else resnet_sizes[0],
out_channels=resnet_sizes[i],
device=device,
dtype=dtype,
),
Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
]
)
if i > 0
else Chain(
[
Resnet(in_channels=resnet_sizes[0], out_channels=resnet_sizes[i], device=device, dtype=dtype),
Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
]
)
)
for i in range(len(resnet_sizes))
]
attention_layer = Residual(
GroupNorm(channels=resnet_sizes[0], num_groups=32, eps=1e-6, device=device, dtype=dtype),
SelfAttention2d(channels=resnet_sizes[0], device=device, dtype=dtype),
)
resnet_layers[0].insert(1, attention_layer)
for _, layer in zip(range(3), resnet_layers[1:]):
channels: int = layer[-1].out_channels
layer.insert(-1, Upsample(channels=channels, upsample_factor=2, device=device, dtype=dtype))
super().__init__(
Conv2d(
in_channels=self.latent_dim, out_channels=self.latent_dim, kernel_size=1, device=device, dtype=dtype
),
Conv2d(
in_channels=self.latent_dim,
out_channels=resnet_sizes[0],
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
Chain(*resnet_layers),
Chain(
GroupNorm(channels=resnet_sizes[-1], num_groups=32, eps=1e-6, device=device, dtype=dtype),
SiLU(),
Conv2d(
in_channels=resnet_sizes[-1],
out_channels=self.output_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
),
)
class LatentDiffusionAutoencoder(Chain):
encoder_scale = 0.18125
def __init__(
self,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
Encoder(device=device, dtype=dtype),
Decoder(device=device, dtype=dtype),
)
def encode(self, x: Tensor) -> Tensor:
encoder = self[0]
x = self.encoder_scale * encoder(x)
return x
def decode(self, x: Tensor) -> Tensor:
decoder = self[1]
x = decoder(x / self.encoder_scale)
return x
def encode_image(self, image: Image.Image) -> Tensor:
x = image_to_tensor(image, device=self.device, dtype=self.dtype)
x = 2 * x - 1
return self.encode(x)
def decode_latents(self, x: Tensor) -> Image.Image:
x = self.decode(x)
x = (x + 1) / 2
return tensor_to_image(x)
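# A minimal round-trip sketch (illustrative only, not upstream refiners code;
# "photo_512.png" is a hypothetical 512x512 RGB file, and converted autoencoder
# weights must be loaded before the output is meaningful):
#
#   from PIL import Image
#   lda = LatentDiffusionAutoencoder(device="cpu")
#   latents = lda.encode_image(Image.open("photo_512.png").convert("RGB"))   # shape (1, 4, 64, 64)
#   reconstruction = lda.decode_latents(latents)                             # a 512x512 PIL image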

@ -0,0 +1,175 @@
from torch import Size, Tensor, device as Device, dtype as DType
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.layers import (
GLU,
Attention,
Chain,
Conv2d,
Flatten,
GeLU,
GroupNorm,
Identity,
LayerNorm,
Linear,
Parallel,
Residual,
SelfAttention,
SetContext,
Transpose,
Unflatten,
UseContext,
)
class CrossAttentionBlock(Chain):
def __init__(
self,
embedding_dim: int,
context_embedding_dim: int,
context_key: str,
num_heads: int = 1,
use_bias: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.context_embedding_dim = context_embedding_dim
self.context = "cross_attention_block"
self.context_key = context_key
self.num_heads = num_heads
self.use_bias = use_bias
super().__init__(
Residual(
LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
SelfAttention(
embedding_dim=embedding_dim, num_heads=num_heads, use_bias=use_bias, device=device, dtype=dtype
),
),
Residual(
LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
Parallel(
Identity(),
UseContext(context=self.context, key=context_key),
UseContext(context=self.context, key=context_key),
),
Attention(
embedding_dim=embedding_dim,
num_heads=num_heads,
key_embedding_dim=context_embedding_dim,
value_embedding_dim=context_embedding_dim,
use_bias=use_bias,
device=device,
dtype=dtype,
),
),
Residual(
LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
Linear(in_features=embedding_dim, out_features=2 * 4 * embedding_dim, device=device, dtype=dtype),
GLU(GeLU()),
Linear(in_features=4 * embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
),
)
class StatefulFlatten(Chain):
def __init__(self, context: str, key: str, start_dim: int = 0, end_dim: int = -1) -> None:
self.start_dim = start_dim
self.end_dim = end_dim
super().__init__(
SetContext(context=context, key=key, callback=self.push),
Flatten(start_dim=start_dim, end_dim=end_dim),
)
def push(self, sizes: list[Size], x: Tensor) -> None:
sizes.append(
x.shape[slice(self.start_dim, self.end_dim + 1 if self.end_dim >= 0 else x.ndim + self.end_dim + 1)]
)
class CrossAttentionBlock2d(Residual):
def __init__(
self,
channels: int,
context_embedding_dim: int,
context_key: str,
num_attention_heads: int = 1,
num_attention_layers: int = 1,
num_groups: int = 32,
use_bias: bool = True,
use_linear_projection: bool = False,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
assert channels % num_attention_heads == 0, "in_channels must be divisible by num_attention_heads"
self.channels = channels
self.in_channels = channels
self.out_channels = channels
self.context_embedding_dim = context_embedding_dim
self.num_attention_heads = num_attention_heads
self.num_attention_layers = num_attention_layers
self.num_groups = num_groups
self.use_bias = use_bias
self.context_key = context_key
self.use_linear_projection = use_linear_projection
self.projection_type = "Linear" if use_linear_projection else "Conv2d"
in_block = (
Chain(
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, device=device, dtype=dtype),
StatefulFlatten(context="flatten", key="sizes", start_dim=2),
Transpose(1, 2),
Linear(in_features=channels, out_features=channels, device=device, dtype=dtype),
)
if use_linear_projection
else Chain(
GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, device=device, dtype=dtype),
Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
StatefulFlatten(context="flatten", key="sizes", start_dim=2),
Transpose(1, 2),
)
)
out_block = (
Chain(
Linear(in_features=channels, out_features=channels, device=device, dtype=dtype),
Transpose(1, 2),
Parallel(
Identity(),
UseContext(context="flatten", key="sizes").compose(lambda x: x.pop()),
),
Unflatten(dim=2),
)
if use_linear_projection
else Chain(
Transpose(1, 2),
Parallel(
Identity(),
UseContext(context="flatten", key="sizes").compose(lambda x: x.pop()),
),
Unflatten(dim=2),
Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
)
)
super().__init__(
in_block,
Chain(
CrossAttentionBlock(
embedding_dim=channels,
context_embedding_dim=context_embedding_dim,
context_key=context_key,
num_heads=num_attention_heads,
use_bias=use_bias,
device=device,
dtype=dtype,
)
for _ in range(num_attention_layers)
),
out_block,
)
def init_context(self) -> Contexts:
return {"flatten": {"sizes": []}}

@ -0,0 +1,94 @@
import math
from typing import Any, Callable, Generic, TypeVar
import torch
from torch import Tensor
from torch.fft import fftn, fftshift, ifftn, ifftshift # type: ignore
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualConcatenator, SD1UNet
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TSDFreeUAdapter = TypeVar("TSDFreeUAdapter", bound="SDFreeUAdapter[Any]") # Self (see PEP 673)
def fourier_filter(x: Tensor, scale: float = 1, threshold: int = 1) -> Tensor:
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
This version of the method comes from here:
https://github.com/ChenyangSi/FreeU/blob/main/demo/free_lunch_utils.py#L23
"""
batch, channels, height, width = x.shape
dtype = x.dtype
device = x.device
if not (math.log2(height).is_integer() and math.log2(width).is_integer()):
x = x.to(dtype=torch.float32)
x_freq = fftn(x, dim=(-2, -1)) # type: ignore
x_freq = fftshift(x_freq, dim=(-2, -1)) # type: ignore
mask = torch.ones((batch, channels, height, width), device=device) # type: ignore
center_row, center_col = height // 2, width // 2 # type: ignore
mask[..., center_row - threshold : center_row + threshold, center_col - threshold : center_col + threshold] = scale
x_freq = x_freq * mask # type: ignore
x_freq = ifftshift(x_freq, dim=(-2, -1)) # type: ignore
x_filtered = ifftn(x_freq, dim=(-2, -1)).real # type: ignore
return x_filtered.to(dtype=dtype) # type: ignore
class FreeUBackboneFeatures(fl.Module):
def __init__(self, backbone_scale: float) -> None:
super().__init__()
self.backbone_scale = backbone_scale
def forward(self, x: Tensor) -> Tensor:
num_half_channels = x.shape[1] // 2
x[:, :num_half_channels] = x[:, :num_half_channels] * self.backbone_scale
return x
class FreeUSkipFeatures(fl.Chain):
def __init__(self, n: int, skip_scale: float) -> None:
apply_filter: Callable[[Tensor], Tensor] = lambda x: fourier_filter(x, scale=skip_scale)
super().__init__(
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[n]),
fl.Lambda(apply_filter),
)
class FreeUResidualConcatenator(fl.Concatenate):
def __init__(self, n: int, backbone_scale: float, skip_scale: float) -> None:
super().__init__(
FreeUBackboneFeatures(backbone_scale),
FreeUSkipFeatures(n, skip_scale),
dim=1,
)
class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]):
def __init__(self, target: T, backbone_scales: list[float], skip_scales: list[float]) -> None:
assert len(backbone_scales) == len(skip_scales)
assert len(backbone_scales) <= len(target.UpBlocks)
self.backbone_scales = backbone_scales
self.skip_scales = skip_scales
with self.setup_adapter(target):
super().__init__(target)
def inject(self: TSDFreeUAdapter, parent: fl.Chain | None = None) -> TSDFreeUAdapter:
for n, (backbone_scale, skip_scale) in enumerate(zip(self.backbone_scales, self.skip_scales)):
block = self.target.UpBlocks[n]
concat = block.ensure_find(ResidualConcatenator)
block.replace(concat, FreeUResidualConcatenator(-n - 2, backbone_scale, skip_scale))
return super().inject(parent)
def eject(self) -> None:
for n in range(len(self.backbone_scales)):
block = self.target.UpBlocks[n]
concat = block.ensure_find(FreeUResidualConcatenator)
block.replace(concat, ResidualConcatenator(-n - 2))
super().eject()
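# A minimal usage sketch (illustrative only, not upstream refiners code): FreeU
# rescales the backbone half of the concatenated features and Fourier-filters the
# skip features in the first up blocks. The scales below are the values suggested
# for SD 1.x in the FreeU paper, and `unet` is assumed to be an existing SD1UNet:
#
#   freeu = SDFreeUAdapter(unet, backbone_scales=[1.2, 1.4], skip_scales=[0.9, 0.2])
#   freeu.inject()   # run diffusion as usual; freeu.eject() restores the plain UNet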

@ -0,0 +1,465 @@
import math
from enum import IntEnum
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
from jaxtyping import Float
from PIL import Image
from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.layers.attentions import ScaledDotProductAttention
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, normalize
from imaginairy.vendored.refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
if TYPE_CHECKING:
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TIPAdapter = TypeVar("TIPAdapter", bound="IPAdapter[Any]") # Self (see PEP 673)
class ImageProjection(fl.Chain):
def __init__(
self,
clip_image_embedding_dim: int = 1024,
clip_text_embedding_dim: int = 768,
num_tokens: int = 4,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.clip_image_embedding_dim = clip_image_embedding_dim
self.clip_text_embedding_dim = clip_text_embedding_dim
self.num_tokens = num_tokens
super().__init__(
fl.Linear(
in_features=clip_image_embedding_dim,
out_features=clip_text_embedding_dim * num_tokens,
device=device,
dtype=dtype,
),
fl.Reshape(num_tokens, clip_text_embedding_dim),
fl.LayerNorm(normalized_shape=clip_text_embedding_dim, device=device, dtype=dtype),
)
class FeedForward(fl.Chain):
def __init__(
self,
embedding_dim: int,
feedforward_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.feedforward_dim = feedforward_dim
super().__init__(
fl.Linear(
in_features=self.embedding_dim,
out_features=self.feedforward_dim,
bias=False,
device=device,
dtype=dtype,
),
fl.GeLU(),
fl.Linear(
in_features=self.feedforward_dim,
out_features=self.embedding_dim,
bias=False,
device=device,
dtype=dtype,
),
)
# Adapted from https://github.com/tencent-ailab/IP-Adapter/blob/6212981/ip_adapter/resampler.py
# See also:
# - https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# - https://github.com/lucidrains/flamingo-pytorch
class PerceiverScaledDotProductAttention(fl.Module):
def __init__(self, head_dim: int, num_heads: int) -> None:
super().__init__()
self.num_heads = num_heads
# See https://github.com/tencent-ailab/IP-Adapter/blob/6212981/ip_adapter/resampler.py#L69
# -> "More stable with f16 than dividing afterwards"
self.scale = 1 / math.sqrt(math.sqrt(head_dim))
def forward(
self,
key_value: Float[Tensor, "batch sequence_length 2*head_dim*num_heads"],
query: Float[Tensor, "batch num_tokens head_dim*num_heads"],
) -> Float[Tensor, "batch num_tokens head_dim*num_heads"]:
bs, length, _ = query.shape
key, value = key_value.chunk(2, dim=-1)
q = self.reshape_tensor(query)
k = self.reshape_tensor(key)
v = self.reshape_tensor(value)
attention = (q * self.scale) @ (k * self.scale).transpose(-2, -1)
attention = softmax(input=attention.float(), dim=-1).type(attention.dtype)
attention = attention @ v
return attention.permute(0, 2, 1, 3).reshape(bs, length, -1)
def reshape_tensor(
self, x: Float[Tensor, "batch length head_dim*num_heads"]
) -> Float[Tensor, "batch num_heads length head_dim"]:
bs, length, _ = x.shape
x = x.view(bs, length, self.num_heads, -1)
x = x.transpose(1, 2)
x = x.reshape(bs, self.num_heads, length, -1)
return x
class PerceiverAttention(fl.Chain):
def __init__(
self,
embedding_dim: int,
head_dim: int = 64,
num_heads: int = 8,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.head_dim = head_dim
self.inner_dim = head_dim * num_heads
super().__init__(
fl.Distribute(
fl.LayerNorm(normalized_shape=self.embedding_dim, device=device, dtype=dtype),
fl.LayerNorm(normalized_shape=self.embedding_dim, device=device, dtype=dtype),
),
fl.Parallel(
fl.Chain(
fl.Lambda(func=self.to_kv),
fl.Linear(
in_features=self.embedding_dim,
out_features=2 * self.inner_dim,
bias=False,
device=device,
dtype=dtype,
), # Wkv
),
fl.Chain(
fl.GetArg(index=1),
fl.Linear(
in_features=self.embedding_dim,
out_features=self.inner_dim,
bias=False,
device=device,
dtype=dtype,
), # Wq
),
),
PerceiverScaledDotProductAttention(head_dim=head_dim, num_heads=num_heads),
fl.Linear(
in_features=self.inner_dim, out_features=self.embedding_dim, bias=False, device=device, dtype=dtype
),
)
def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
return cat((x, latents), dim=-2)
class LatentsToken(fl.Chain):
def __init__(
self, num_tokens: int, latents_dim: int, device: Device | str | None = None, dtype: DType | None = None
) -> None:
self.num_tokens = num_tokens
self.latents_dim = latents_dim
super().__init__(fl.Parameter(num_tokens, latents_dim, device=device, dtype=dtype))
class Transformer(fl.Chain):
pass
class TransformerLayer(fl.Chain):
pass
class PerceiverResampler(fl.Chain):
def __init__(
self,
latents_dim: int = 1024,
num_attention_layers: int = 8,
num_attention_heads: int = 16,
head_dim: int = 64,
num_tokens: int = 8,
input_dim: int = 768,
output_dim: int = 1024,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.latents_dim = latents_dim
self.num_attention_layers = num_attention_layers
self.head_dim = head_dim
self.num_attention_heads = num_attention_heads
self.num_tokens = num_tokens
self.input_dim = input_dim
self.output_dim = output_dim
self.feedforward_dim = 4 * self.latents_dim
super().__init__(
fl.Linear(in_features=input_dim, out_features=latents_dim, device=device, dtype=dtype),
fl.SetContext(context="perceiver_resampler", key="x"),
LatentsToken(num_tokens, latents_dim, device=device, dtype=dtype),
Transformer(
TransformerLayer(
fl.Residual(
fl.Parallel(fl.UseContext(context="perceiver_resampler", key="x"), fl.Identity()),
PerceiverAttention(
embedding_dim=latents_dim,
head_dim=head_dim,
num_heads=num_attention_heads,
device=device,
dtype=dtype,
),
),
fl.Residual(
fl.LayerNorm(normalized_shape=latents_dim, device=device, dtype=dtype),
FeedForward(
embedding_dim=latents_dim, feedforward_dim=self.feedforward_dim, device=device, dtype=dtype
),
),
)
for _ in range(num_attention_layers)
),
fl.Linear(in_features=latents_dim, out_features=output_dim, device=device, dtype=dtype),
fl.LayerNorm(normalized_shape=output_dim, device=device, dtype=dtype),
)
def init_context(self) -> Contexts:
return {"perceiver_resampler": {"x": None}}
class _CrossAttnIndex(IntEnum):
TXT_CROSS_ATTN = 0 # text cross-attention
IMG_CROSS_ATTN = 1 # image cross-attention
class InjectionPoint(fl.Chain):
pass
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
def __init__(
self,
target: fl.Attention,
text_sequence_length: int = 77,
image_sequence_length: int = 4,
scale: float = 1.0,
) -> None:
self.text_sequence_length = text_sequence_length
self.image_sequence_length = image_sequence_length
self.scale = scale
with self.setup_adapter(target):
super().__init__(
fl.Distribute(
# Note: the same query is used for image cross-attention as for text cross-attention
InjectionPoint(), # Wq
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wk
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
bias=self.target.use_bias,
device=target.device,
dtype=target.dtype,
), # Wk'
),
),
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wv
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
bias=self.target.use_bias,
device=target.device,
dtype=target.dtype,
), # Wv'
),
),
),
fl.Sum(
fl.Chain(
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
),
fl.Chain(
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
fl.Lambda(func=self.scale_outputs),
),
),
InjectionPoint(), # proj
)
def select_qkv(
self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
) -> tuple[Tensor, Tensor, Tensor]:
return (query, keys[index.value], values[index.value])
def scale_outputs(self, x: Tensor) -> Tensor:
return x * self.scale
def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
def f(m: fl.Module, _: fl.Chain) -> bool:
if isinstance(m, Lora): # do not adapt LoRAs
raise StopIteration
return isinstance(m, k)
return f
def _target_linears(self) -> list[fl.Linear]:
return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)]
def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
linears = self._target_linears()
assert len(linears) == 4 # Wq, Wk, Wv and Proj
injection_points = list(self.layers(InjectionPoint))
assert len(injection_points) == 4
for linear, ip in zip(linears, injection_points):
ip.append(linear)
assert len(ip) == 1
return super().inject(parent)
def eject(self) -> None:
injection_points = list(self.layers(InjectionPoint))
assert len(injection_points) == 4
for ip in injection_points:
ip.pop()
assert len(ip) == 0
super().eject()
class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
# Prevent PyTorch module registration
_clip_image_encoder: list[CLIPImageEncoderH]
_grid_image_encoder: list[CLIPImageEncoderH]
_image_proj: list[fl.Module]
def __init__(
self,
target: T,
clip_image_encoder: CLIPImageEncoderH,
image_proj: fl.Module,
scale: float = 1.0,
fine_grained: bool = False,
weights: dict[str, Tensor] | None = None,
) -> None:
with self.setup_adapter(target):
super().__init__(target)
self.fine_grained = fine_grained
self._clip_image_encoder = [clip_image_encoder]
if fine_grained:
self._grid_image_encoder = [self.convert_to_grid_features(clip_image_encoder)]
self._image_proj = [image_proj]
self.sub_adapters = [
CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens)
for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
]
if weights is not None:
image_proj_state_dict: dict[str, Tensor] = {
k.removeprefix("image_proj."): v for k, v in weights.items() if k.startswith("image_proj.")
}
self.image_proj.load_state_dict(image_proj_state_dict)
for i, cross_attn in enumerate(self.sub_adapters):
cross_attn_state_dict: dict[str, Tensor] = {}
for k, v in weights.items():
prefix = f"ip_adapter.{i:03d}."
if not k.startswith(prefix):
continue
cross_attn_state_dict[k.removeprefix(prefix)] = v
cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
@property
def clip_image_encoder(self) -> CLIPImageEncoderH:
return self._clip_image_encoder[0]
@property
def grid_image_encoder(self) -> CLIPImageEncoderH:
assert hasattr(self, "_grid_image_encoder")
return self._grid_image_encoder[0]
@property
def image_proj(self) -> fl.Module:
return self._image_proj[0]
def inject(self: "TIPAdapter", parent: fl.Chain | None = None) -> "TIPAdapter":
for adapter in self.sub_adapters:
adapter.inject()
return super().inject(parent)
def eject(self) -> None:
for adapter in self.sub_adapters:
adapter.eject()
super().eject()
def set_scale(self, scale: float) -> None:
for cross_attn in self.sub_adapters:
cross_attn.scale = scale
# These should be concatenated to the CLIP text embedding before setting the UNet context
def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
clip_embedding = image_encoder(image_prompt)
conditional_embedding = self.image_proj(clip_embedding)
if not self.fine_grained:
negative_embedding = self.image_proj(zeros_like(clip_embedding))
else:
# See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
clip_embedding = image_encoder(zeros_like(image_prompt))
negative_embedding = self.image_proj(clip_embedding)
return cat((negative_embedding, conditional_embedding))
def preprocess_image(
self,
image: Image.Image,
size: tuple[int, int] = (224, 224),
mean: list[float] | None = None,
std: list[float] | None = None,
) -> Tensor:
# Default mean and std are parameters from https://github.com/openai/CLIP
return normalize(
image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype),
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,
)
@staticmethod
def convert_to_grid_features(clip_image_encoder: CLIPImageEncoderH) -> CLIPImageEncoderH:
encoder_clone = clip_image_encoder.structural_copy()
assert isinstance(encoder_clone[-1], fl.Linear) # final proj
assert isinstance(encoder_clone[-2], fl.LayerNorm) # final normalization
assert isinstance(encoder_clone[-3], fl.Lambda) # pooling (classif token)
for _ in range(3):
encoder_clone.pop()
transformer_layers = encoder_clone[-1]
assert isinstance(transformer_layers, fl.Chain) and len(transformer_layers) == 32
transformer_layers.pop()
return encoder_clone
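# A minimal conditioning sketch (illustrative only, not upstream refiners code).
# A concrete subclass such as SD1IPAdapter wires in the CLIP image encoder and
# the projection; `ip_adapter`, `reference_image` and `clip_text_embedding` are
# assumed to already exist:
#
#   import torch
#   ip_adapter.inject()
#   image_embedding = ip_adapter.compute_clip_image_embedding(
#       ip_adapter.preprocess_image(reference_image)
#   )
#   # per the note above, concatenate to the CLIP text embedding before denoising
#   clip_text_embedding = torch.cat((clip_text_embedding, image_embedding), dim=1)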

@ -0,0 +1,146 @@
from enum import Enum
from pathlib import Path
from typing import Callable, Iterator
from torch import Tensor
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora, LoraAdapter
from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
CLIPTextEncoderL,
LatentDiffusionAutoencoder,
SD1UNet,
StableDiffusion_1,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
MODELS = ["unet", "text_encoder", "lda"]
class LoraTarget(str, Enum):
Self = "self"
Attention = "Attention"
SelfAttention = "SelfAttention"
CrossAttention = "CrossAttentionBlock2d"
FeedForward = "FeedForward"
TransformerLayer = "TransformerLayer"
def get_class(self) -> type[fl.Chain]:
match self:
case LoraTarget.Self:
return fl.Chain
case LoraTarget.Attention:
return fl.Attention
case LoraTarget.SelfAttention:
return fl.SelfAttention
case LoraTarget.CrossAttention:
return CrossAttentionBlock2d
case LoraTarget.FeedForward:
return FeedForward
case LoraTarget.TransformerLayer:
return TransformerLayer
def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
def f(m: fl.Module, _: fl.Chain) -> bool:
if isinstance(m, Lora): # do not adapt other LoRAs
raise StopIteration
if isinstance(m, Controlnet): # do not adapt Controlnet linears
raise StopIteration
return isinstance(m, k)
return f
def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]:
for m, p in module.walk(_predicate(fl.Linear)):
assert isinstance(m, fl.Linear)
yield (m, p)
def lora_targets(
module: fl.Chain,
target: LoraTarget | list[LoraTarget],
) -> Iterator[tuple[fl.Linear, fl.Chain]]:
if isinstance(target, list):
for t in target:
yield from lora_targets(module, t)
return
if target == LoraTarget.Self:
yield from _iter_linears(module)
return
for layer, _ in module.walk(_predicate(target.get_class())):
assert isinstance(layer, fl.Chain)
yield from _iter_linears(layer)
class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
metadata: dict[str, str] | None
tensors: dict[str, Tensor]
def __init__(
self,
target: StableDiffusion_1,
sub_targets: dict[str, list[LoraTarget]],
scale: float = 1.0,
weights: dict[str, Tensor] | None = None,
):
with self.setup_adapter(target):
super().__init__(target)
self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = []
for model_name in MODELS:
if not (model_targets := sub_targets.get(model_name, [])):
continue
model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)
lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
self.sub_adapters.append(
LoraAdapter[type(model)](
model,
sub_targets=lora_targets(model, model_targets),
scale=scale,
weights=lora_weights,
)
)
@classmethod
def from_safetensors(
cls,
target: StableDiffusion_1,
checkpoint_path: Path | str,
scale: float = 1.0,
):
metadata = load_metadata_from_safetensors(checkpoint_path)
assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
tensors = load_from_safetensors(checkpoint_path, device=target.device)
sub_targets: dict[str, list[LoraTarget]] = {}
for model_name in MODELS:
if not (v := metadata.get(f"{model_name}_targets", "")):
continue
sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")]
return cls(
target,
sub_targets,
scale=scale,
weights=tensors,
)
def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter":
for adapter in self.sub_adapters:
adapter.inject()
return super().inject(parent)
def eject(self) -> None:
for adapter in self.sub_adapters:
adapter.eject()
super().eject()
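# A minimal usage sketch (illustrative only, not upstream refiners code). The
# checkpoint is a hypothetical refiners-format safetensors file whose metadata
# lists the per-model targets ("unet_targets", "text_encoder_targets",
# "lda_targets"); `sd15` is assumed to be an existing StableDiffusion_1:
#
#   lora = SD1LoraAdapter.from_safetensors(
#       target=sd15,
#       checkpoint_path="pixel-art-lora.safetensors",
#       scale=1.0,
#   )
#   lora.inject()   # lora.eject() removes the adapter again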

@ -0,0 +1,115 @@
from abc import ABC, abstractmethod
from typing import TypeVar
import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
T = TypeVar("T", bound="fl.Module")
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
class LatentDiffusionModel(fl.Module, ABC):
def __init__(
self,
unet: fl.Module,
lda: LatentDiffusionAutoencoder,
clip_text_encoder: fl.Module,
scheduler: Scheduler,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
super().__init__()
self.device: Device = device if isinstance(device, Device) else Device(device=device)
self.dtype = dtype
self.unet = unet.to(device=self.device, dtype=self.dtype)
self.lda = lda.to(device=self.device, dtype=self.dtype)
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
def set_num_inference_steps(self, num_inference_steps: int) -> None:
initial_diffusion_rate = self.scheduler.initial_diffusion_rate
final_diffusion_rate = self.scheduler.final_diffusion_rate
device, dtype = self.scheduler.device, self.scheduler.dtype
self.scheduler = self.scheduler.__class__(
num_inference_steps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
).to(device=device, dtype=dtype)
def init_latents(
self,
size: tuple[int, int],
init_image: Image.Image | None = None,
first_step: int = 0,
noise: Tensor | None = None,
) -> Tensor:
height, width = size
if noise is None:
noise = torch.randn(1, 4, height // 8, width // 8, device=self.device)
assert list(noise.shape[2:]) == [
height // 8,
width // 8,
], f"noise shape is not compatible: {noise.shape}, with size: {size}"
if init_image is None:
return noise
encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
return self.scheduler.add_noise(x=encoded_image, noise=noise, step=self.steps[first_step])
@property
def steps(self) -> list[int]:
return self.scheduler.steps
@abstractmethod
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
...
@abstractmethod
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
...
@abstractmethod
def has_self_attention_guidance(self) -> bool:
...
@abstractmethod
def compute_self_attention_guidance(
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
) -> Tensor:
...
def forward(
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
) -> Tensor:
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
# classifier-free guidance
noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
if self.has_self_attention_guidance():
noise += self.compute_self_attention_guidance(
x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
)
return self.scheduler(x, noise=noise, step=step)
def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
return self.__class__(
unet=self.unet.structural_copy(),
lda=self.lda.structural_copy(),
clip_text_encoder=self.clip_text_encoder.structural_copy(),
scheduler=self.scheduler,
device=self.device,
dtype=self.dtype,
)
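# The forward pass above implements classifier-free guidance: the latent is
# duplicated, the UNet produces an unconditional and a conditional prediction,
# and they are combined as noise_uncond + condition_scale * (noise_cond - noise_uncond)
# before the scheduler consumes it. A minimal denoising loop (illustrative only,
# not upstream refiners code; `sd15` is an existing StableDiffusion_1 and
# `clip_text_embedding` holds the stacked negative + positive text embeddings):
#
#   sd15.set_num_inference_steps(30)
#   x = sd15.init_latents(size=(512, 512))
#   for step in sd15.steps:
#       x = sd15(x, step=step, clip_text_embedding=clip_text_embedding)
#   image = sd15.lda.decode_latents(x)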

@ -0,0 +1,98 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar
import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
MAX_STEPS = 1000
@dataclass
class DiffusionTarget:
size: tuple[int, int]
offset: tuple[int, int]
clip_text_embedding: Tensor
init_latents: Tensor | None = None
mask_latent: Tensor | None = None
weight: int = 1
condition_scale: float = 7.5
start_step: int = 0
end_step: int = MAX_STEPS
def crop(self, tensor: Tensor, /) -> Tensor:
height, width = self.size
top_offset, left_offset = self.offset
return tensor[:, :, top_offset : top_offset + height, left_offset : left_offset + width]
def paste(self, tensor: Tensor, /, crop: Tensor) -> Tensor:
height, width = self.size
top_offset, left_offset = self.offset
tensor[:, :, top_offset : top_offset + height, left_offset : left_offset + width] = crop
return tensor
T = TypeVar("T", bound=LatentDiffusionModel)
D = TypeVar("D", bound=DiffusionTarget)
@dataclass
class MultiDiffusion(Generic[T, D], ABC):
ldm: T
def __call__(self, x: Tensor, /, noise: Tensor, step: int, targets: list[D]) -> Tensor:
num_updates = torch.zeros_like(input=x)
cumulative_values = torch.zeros_like(input=x)
for target in targets:
match step:
case step if step == target.start_step and target.init_latents is not None:
noise_view = target.crop(noise)
view = self.ldm.scheduler.add_noise(
x=target.init_latents,
noise=noise_view,
step=step,
)
case step if target.start_step <= step <= target.end_step:
view = target.crop(x)
case _:
continue
view = self.diffuse_target(x=view, step=step, target=target)
weight = target.weight * target.mask_latent if target.mask_latent is not None else target.weight
num_updates = target.paste(num_updates, crop=target.crop(num_updates) + weight)
cumulative_values = target.paste(cumulative_values, crop=target.crop(cumulative_values) + weight * view)
return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x)
@abstractmethod
def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor:
...
@property
def steps(self) -> list[int]:
return self.ldm.steps
@property
def device(self) -> Device:
return self.ldm.device
@property
def dtype(self) -> DType:
return self.ldm.dtype
def decode_latents(self, x: Tensor) -> Image.Image:
return self.ldm.lda.decode_latents(x=x)
@staticmethod
def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]:
height, width = size
return [
(y, x)
for y in range(0, height, stride)
for x in range(0, width, stride)
if y + 64 <= height and x + 64 <= width
]
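# An illustrative note (not upstream refiners code): each step diffuses every
# target's crop independently, and overlapping views are blended by the weighted
# average cumulative_values / num_updates computed above. generate_offset_grid
# yields the top-left offsets of 64x64 latent tiles on a regular grid, e.g.
#
#   MultiDiffusion.generate_offset_grid(size=(96, 96), stride=32)
#   # -> [(0, 0), (0, 32), (32, 0), (32, 32)]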

@ -0,0 +1,107 @@
# Adapted from https://github.com/carolineec/informative-drawings, MIT License
from torch import device as Device, dtype as DType
import imaginairy.vendored.refiners.fluxion.layers as fl
class InformativeDrawings(fl.Chain):
"""Model typically used as the preprocessor for the Lineart ControlNet.
Implements the paper "Learning to generate line drawings that convey
geometry and semantics" published in 2022 by Caroline Chan, Frédo Durand
and Phillip Isola - https://arxiv.org/abs/2203.12691
For use as a preprocessor it is recommended to use the weights for "Style 2".
"""
def __init__(
self,
in_channels: int = 3, # RGB
out_channels: int = 1, # Grayscale
n_residual_blocks: int = 3,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.Chain( # Initial convolution
fl.ReflectionPad2d(3),
fl.Conv2d(
in_channels=in_channels,
out_channels=64,
kernel_size=7,
device=device,
dtype=dtype,
),
fl.InstanceNorm2d(64, device=device, dtype=dtype),
fl.ReLU(),
),
*( # Downsampling
fl.Chain(
fl.Conv2d(
in_channels=64 * (2**i),
out_channels=128 * (2**i),
kernel_size=3,
stride=2,
padding=1,
device=device,
dtype=dtype,
),
fl.InstanceNorm2d(128 * (2**i), device=device, dtype=dtype),
fl.ReLU(),
)
for i in range(2)
),
*( # Residual blocks
fl.Residual(
fl.ReflectionPad2d(1),
fl.Conv2d(
in_channels=256,
out_channels=256,
kernel_size=3,
device=device,
dtype=dtype,
),
fl.InstanceNorm2d(256, device=device, dtype=dtype),
fl.ReLU(),
fl.ReflectionPad2d(1),
fl.Conv2d(
in_channels=256,
out_channels=256,
kernel_size=3,
device=device,
dtype=dtype,
),
fl.InstanceNorm2d(256, device=device, dtype=dtype),
)
for _ in range(n_residual_blocks)
),
*( # Upsampling
fl.Chain(
fl.ConvTranspose2d(
in_channels=128 * (2**i),
out_channels=64 * (2**i),
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
device=device,
dtype=dtype,
),
fl.InstanceNorm2d(64 * (2**i), device=device, dtype=dtype),
fl.ReLU(),
)
for i in reversed(range(2))
),
fl.Chain( # Output layer
fl.ReflectionPad2d(3),
fl.Conv2d(
in_channels=64,
out_channels=out_channels,
kernel_size=7,
device=device,
dtype=dtype,
),
fl.Sigmoid(),
),
)
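A hedged usage sketch for the InformativeDrawings preprocessor above, with randomly initialized weights and a dummy image; the vendored module path is an assumption, and real use would load the "Style 2" checkpoint before calling the model.

import torch

# Module path is an assumption for illustration; adjust to wherever the class is vendored.
from imaginairy.vendored.refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import (
    InformativeDrawings,
)

model = InformativeDrawings()           # untrained here; load "Style 2" weights for real use
image = torch.rand(1, 3, 512, 512)      # one RGB image, values in [0, 1]
with torch.no_grad():
    lineart = model(image)              # -> (1, 1, 512, 512), Sigmoid keeps values in (0, 1)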

@ -0,0 +1,68 @@
import math
from jaxtyping import Float, Int
from torch import Tensor, arange, cat, cos, device as Device, dtype as DType, exp, float32, sin
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
def compute_sinusoidal_embedding(
x: Int[Tensor, "*batch 1"],
embedding_dim: int,
) -> Float[Tensor, "*batch 1 embedding_dim"]:
half_dim = embedding_dim // 2
# Note: it is important that this computation is done in float32.
# The result can be cast to lower precision later if necessary.
exponent = -math.log(10000) * arange(start=0, end=half_dim, dtype=float32, device=x.device)
exponent /= half_dim
embedding = x.unsqueeze(1).float() * exp(exponent).unsqueeze(0)
embedding = cat([cos(embedding), sin(embedding)], dim=-1)
return embedding
class RangeEncoder(fl.Chain):
def __init__(
self,
sinuosidal_embedding_dim: int,
embedding_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.sinuosidal_embedding_dim = sinuosidal_embedding_dim
self.embedding_dim = embedding_dim
super().__init__(
fl.Lambda(self.compute_sinuosoidal_embedding),
fl.Converter(set_device=False, set_dtype=True),
fl.Linear(in_features=sinuosidal_embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
fl.SiLU(),
fl.Linear(in_features=embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
)
def compute_sinuosoidal_embedding(self, x: Int[Tensor, "*batch 1"]) -> Float[Tensor, "*batch 1 embedding_dim"]:
return compute_sinusoidal_embedding(x, embedding_dim=self.sinuosidal_embedding_dim)
class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
def __init__(
self,
target: fl.Conv2d,
channels: int,
embedding_dim: int,
context_key: str,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.channels = channels
self.embedding_dim = embedding_dim
self.context_key = context_key
with self.setup_adapter(target):
super().__init__(
target,
fl.Chain(
fl.UseContext("range_adapter", context_key),
fl.SiLU(),
fl.Linear(in_features=embedding_dim, out_features=channels, device=device, dtype=dtype),
fl.View(-1, channels, 1, 1),
),
)
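A quick shape and precision check for compute_sinusoidal_embedding above (the import path follows the range_adapter module referenced elsewhere in this diff): timesteps go in, a float32 [cos | sin] embedding comes out, which RangeEncoder then projects through two Linear layers.

import torch

from imaginairy.vendored.refiners.foundationals.latent_diffusion.range_adapter import (
    compute_sinusoidal_embedding,
)

timesteps = torch.tensor([0, 250, 500, 999])                # example diffusion timesteps
embedding = compute_sinusoidal_embedding(timesteps, embedding_dim=320)
print(embedding.shape, embedding.dtype)                     # torch.Size([4, 320]) torch.float32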

@ -0,0 +1,143 @@
from typing import Callable
from torch import Tensor
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.layers import (
Chain,
Concatenate,
Identity,
Lambda,
Parallel,
Passthrough,
SelfAttention,
SetContext,
UseContext,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
class SaveLayerNormAdapter(Chain, Adapter[SelfAttention]):
def __init__(self, target: SelfAttention, context: str) -> None:
self.context = context
with self.setup_adapter(target):
super().__init__(SetContext(self.context, "norm"), target)
class SelfAttentionInjectionAdapter(Chain, Adapter[SelfAttention]):
def __init__(
self,
target: SelfAttention,
context: str,
style_cfg: float = 0.5,
) -> None:
self.context = context
self.style_cfg = style_cfg
sa_guided = target.structural_copy()
assert isinstance(sa_guided[0], Parallel)
sa_guided.replace(
sa_guided[0],
Parallel(
Identity(),
Concatenate(Identity(), UseContext(self.context, "norm"), dim=1),
Concatenate(Identity(), UseContext(self.context, "norm"), dim=1),
),
)
with self.setup_adapter(target):
slice_tensor: Callable[[Tensor], Tensor] = lambda x: x[:1]
super().__init__(
Parallel(sa_guided, Chain(Lambda(slice_tensor), target)),
Lambda(self.compute_averaged_unconditioned_x),
)
def compute_averaged_unconditioned_x(self, x: Tensor, unguided_unconditioned_x: Tensor) -> Tensor:
x[0] = self.style_cfg * x[0] + (1.0 - self.style_cfg) * unguided_unconditioned_x
return x
class SelfAttentionInjectionPassthrough(Passthrough):
def __init__(self, target: SD1UNet) -> None:
guide_unet = target.structural_copy()
for i, attention_block in enumerate(guide_unet.layers(CrossAttentionBlock)):
sa = attention_block.ensure_find(SelfAttention)
assert sa.parent is not None
SaveLayerNormAdapter(sa, context=f"self_attention_context_{i}").inject()
super().__init__(
Lambda(self._copy_diffusion_context),
UseContext("reference_only_control", "guide"),
guide_unet,
Lambda(self._restore_diffusion_context),
)
def _copy_diffusion_context(self, x: Tensor) -> Tensor:
# This function avoids disrupting the accumulation of residuals in the UNet (when controlnets are used)
self.set_context(
"self_attention_residuals_buffer",
{"buffer": self.use_context("unet")["residuals"]},
)
self.set_context(
"unet",
{"residuals": [0.0] * 13},
)
return x
def _restore_diffusion_context(self, x: Tensor) -> Tensor:
self.set_context(
"unet",
{
"residuals": self.use_context("self_attention_residuals_buffer")["buffer"],
},
)
return x
class ReferenceOnlyControlAdapter(Chain, Adapter[SD1UNet]):
# TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
def __init__(self, target: SD1UNet, style_cfg: float = 0.5) -> None:
# style_cfg is the weight of the guide in unconditioned diffusion.
# The sdwebui repo recommends a value of 0.5.
self.sub_adapters: list[SelfAttentionInjectionAdapter] = []
self._passthrough: list[SelfAttentionInjectionPassthrough] = [
SelfAttentionInjectionPassthrough(target)
] # not registered by PyTorch
with self.setup_adapter(target):
super().__init__(target)
for i, attention_block in enumerate(target.layers(CrossAttentionBlock)):
self.set_context(f"self_attention_context_{i}", {"norm": None})
sa = attention_block.ensure_find(SelfAttention)
assert sa.parent is not None
self.sub_adapters.append(
SelfAttentionInjectionAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)
)
def inject(self: "ReferenceOnlyControlAdapter", parent: Chain | None = None) -> "ReferenceOnlyControlAdapter":
passthrough = self._passthrough[0]
assert passthrough not in self.target, f"{passthrough} is already injected"
for adapter in self.sub_adapters:
adapter.inject()
self.target.insert(0, passthrough)
return super().inject(parent)
def eject(self) -> None:
passthrough = self._passthrough[0]
assert self.target[0] == passthrough, f"{passthrough} is not the first element of target UNet"
for adapter in self.sub_adapters:
adapter.eject()
self.target.pop(0)
super().eject()
def set_controlnet_condition(self, condition: Tensor) -> None:
self.set_context("reference_only_control", {"guide": condition})
def structural_copy(self: "ReferenceOnlyControlAdapter") -> "ReferenceOnlyControlAdapter":
raise RuntimeError("ReferenceOnlyControlAdapter cannot be copied, eject it first.")
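A hedged sketch of wiring ReferenceOnlyControlAdapter above into an SD1UNet; the module path, the guide-latent shape, and the surrounding denoising loop are assumptions for illustration (normally the guide comes from encoding a reference image with the autoencoder).

import torch

# Module path assumed for illustration.
from imaginairy.vendored.refiners.foundationals.latent_diffusion.reference_only_control import (
    ReferenceOnlyControlAdapter,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet

unet = SD1UNet(in_channels=4)
adapter = ReferenceOnlyControlAdapter(unet, style_cfg=0.5).inject()
guide_latents = torch.randn(1, 4, 64, 64)      # placeholder; normally sd.lda.encode(reference_image)
adapter.set_controlnet_condition(guide_latents)
# ... run the usual denoising loop with the adapted unet, then adapter.eject() to restore it.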

@ -0,0 +1,110 @@
from dataclasses import dataclass
from functools import cached_property
from typing import Generic, TypeVar
import torch
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
T = TypeVar("T", bound=LatentDiffusionModel)
def add_noise_interval(
scheduler: Scheduler,
/,
x: torch.Tensor,
noise: torch.Tensor,
initial_timestep: torch.Tensor,
target_timestep: torch.Tensor,
) -> torch.Tensor:
initial_cumulative_scale_factors = scheduler.cumulative_scale_factors[initial_timestep]
target_cumulative_scale_factors = scheduler.cumulative_scale_factors[target_timestep]
factor = target_cumulative_scale_factors / initial_cumulative_scale_factors
noised_x = factor * x + torch.sqrt(1 - factor**2) * noise
return noised_x
@dataclass
class Restart(Generic[T]):
"""
Implements the restart sampling strategy from the paper "Restart Sampling for Improving Generative Processes"
(https://arxiv.org/pdf/2306.14878.pdf)
Works only with the DDIM scheduler for now.
"""
ldm: T
num_steps: int = 10
num_iterations: int = 2
start_time: float = 0.1
end_time: float = 2
def __post_init__(self) -> None:
assert isinstance(self.ldm.scheduler, DDIM), "Restart sampling only works with DDIM scheduler"
def __call__(
self,
x: torch.Tensor,
/,
clip_text_embedding: torch.Tensor,
condition_scale: float = 7.5,
**kwargs: torch.Tensor,
) -> torch.Tensor:
original_scheduler = self.ldm.scheduler
new_scheduler = DDIM(self.ldm.scheduler.num_inference_steps, device=self.device, dtype=self.dtype)
new_scheduler.timesteps = self.timesteps
self.ldm.scheduler = new_scheduler
for _ in range(self.num_iterations):
noise = torch.randn_like(input=x, device=self.device, dtype=self.dtype)
x = add_noise_interval(
new_scheduler,
x=x,
noise=noise,
initial_timestep=self.timesteps[-1],
target_timestep=self.timesteps[0],
)
for step in range(len(self.timesteps) - 1):
x = self.ldm(
x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=condition_scale, **kwargs
)
self.ldm.scheduler = original_scheduler
return x
@cached_property
def start_step(self) -> int:
sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.scheduler.timesteps] - self.start_time)))
@cached_property
def end_timestep(self) -> int:
sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
return int(torch.argmin(input=torch.abs(input=sigmas - self.end_time)))
@cached_property
def timesteps(self) -> torch.Tensor:
return (
torch.round(
torch.linspace(
start=int(self.ldm.scheduler.timesteps[self.start_step]),
end=self.end_timestep,
steps=self.num_steps,
)
)
.flip(0)
.to(device=self.device, dtype=torch.int64)
)
@property
def device(self) -> torch.device:
return self.ldm.device
@property
def dtype(self) -> torch.dtype:
return self.ldm.dtype
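A hedged sketch of applying Restart above inside a DDIM-based denoising loop; the module path, the placement of the restart call, and the prompt/latent placeholders are assumptions, and real use needs real model weights.

import torch

# Module path assumed for illustration.
from imaginairy.vendored.refiners.foundationals.latent_diffusion.restart import Restart
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.model import StableDiffusion_1

sd = StableDiffusion_1(scheduler=DDIM(num_inference_steps=30))
restart = Restart(ldm=sd, num_steps=10, num_iterations=2, start_time=0.1, end_time=2)
clip_text_embedding = sd.compute_clip_text_embedding("a photo of a cat")
x = torch.randn(1, 4, 64, 64)
for step in sd.steps:
    x = sd(x, step=step, clip_text_embedding=clip_text_embedding)
    if step == restart.start_step:   # re-noise and re-denoise once this noise level is reached
        x = restart(x, clip_text_embedding=clip_text_embedding)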

@ -0,0 +1,11 @@
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
__all__ = [
"Scheduler",
"DPMSolver",
"DDPM",
"DDIM",
]

@ -0,0 +1,57 @@
from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
class DDIM(Scheduler):
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: Dtype = float32,
) -> None:
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
device=device,
dtype=dtype,
)
self.timesteps = self._generate_timesteps()
def _generate_timesteps(self) -> Tensor:
"""
Generates decreasing timesteps with "leading" spacing and an offset of 1,
matching the diffusers settings for the DDIM scheduler in Stable Diffusion 1.5.
"""
step_ratio = self.num_train_timesteps // self.num_inference_steps
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
timestep, previous_timestep = (
self.timesteps[step],
(
self.timesteps[step + 1]
if step < self.num_inference_steps - 1
else tensor(data=[0], device=self.device, dtype=self.dtype)
),
)
current_scale_factor, previous_scale_factor = (
self.cumulative_scale_factors[timestep],
(
self.cumulative_scale_factors[previous_timestep]
if previous_timestep > 0
else self.cumulative_scale_factors[0]
),
)
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
return denoised_x
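A concrete illustration of the "leading" spacing with an offset of 1 that _generate_timesteps above describes, using the vendored DDIM scheduler directly.

from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM

ddim = DDIM(num_inference_steps=10)
print(ddim.timesteps.tolist())
# [901, 801, 701, 601, 501, 401, 301, 201, 101, 1]   (step ratio 1000 // 10 = 100, offset 1, reversed)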

@ -0,0 +1,34 @@
from torch import Tensor, arange, device as Device
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
class DDPM(Scheduler):
"""
Denoising Diffusion Probabilistic Models (DDPM) is a type of diffusion model that uses
its own strategy to generate timesteps and to apply the diffusion process.
"""
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
device: Device | str = "cpu",
) -> None:
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
device=device,
)
def _generate_timesteps(self) -> Tensor:
step_ratio = self.num_train_timesteps // self.num_inference_steps
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
raise NotImplementedError

@ -0,0 +1,117 @@
from collections import deque
import numpy as np
from torch import Tensor, device as Device, dtype as Dtype, exp, float32, tensor
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
class DPMSolver(Scheduler):
"""Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
We only support noise prediction for now.
"""
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: Dtype = float32,
):
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
device=device,
dtype=dtype,
)
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
self.initial_steps = 0
def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because:
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# ...and we want the same result as the original codebase.
return tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
device=self.device,
).flip(0)
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
timestep, previous_timestep = (
self.timesteps[step],
self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0],
)
previous_ratio, current_ratio = (
self.signal_to_noise_ratios[previous_timestep],
self.signal_to_noise_ratios[timestep],
)
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_noise_std, current_noise_std = (
self.noise_std[previous_timestep],
self.noise_std[timestep],
)
factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
return denoised_x
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
previous_timestep, current_timestep, next_timestep = (
self.timesteps[step + 1] if step < len(self.timesteps) - 1 else tensor([0]),
self.timesteps[step],
self.timesteps[step - 1],
)
current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2]
previous_ratio, current_ratio, next_ratio = (
self.signal_to_noise_ratios[previous_timestep],
self.signal_to_noise_ratios[current_timestep],
self.signal_to_noise_ratios[next_timestep],
)
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_std, current_std = (
self.noise_std[previous_timestep],
self.noise_std[current_timestep],
)
estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
)
factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (
(previous_std / current_std) * x
- (factor * previous_scale_factor) * current_data_estimation
- 0.5 * (factor * previous_scale_factor) * estimation_delta
)
return denoised_x
def __call__(
self,
x: Tensor,
noise: Tensor,
step: int,
) -> Tensor:
"""
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
This method works by estimating the denoised version of `x` and applying either a first-order or second-order
backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
(ODEs).
"""
current_timestep = self.timesteps[step]
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
self.estimated_data.append(estimated_denoised_data)
denoised_x = (
self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
if (self.initial_steps == 0)
else self.multistep_dpm_solver_second_order_update(x=x, step=step)
)
if self.initial_steps < 2:
self.initial_steps += 1
return denoised_x
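The numpy-versus-torch remark in _generate_timesteps above, made concrete: the two linspace results differ in the last decimals, so rounding lands on different integers.

import numpy as np
import torch

print(np.linspace(0, 999, 31)[15])             # 499.49999999999994 -> rounds to 499
print(torch.linspace(0, 999, 31)[15].item())   # 499.5              -> rounds to 500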

@ -0,0 +1,128 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import TypeVar
from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
T = TypeVar("T", bound="Scheduler")
class NoiseSchedule(str, Enum):
UNIFORM = "uniform"
QUADRATIC = "quadratic"
KARRAS = "karras"
class Scheduler(ABC):
"""
A base class for creating a diffusion model scheduler.
The Scheduler creates a sequence of noise and scaling factors used in the diffusion process,
which gradually transforms the original data distribution into a Gaussian one.
This process is described using several parameters such as initial and final diffusion rates,
and is encapsulated into a `__call__` method that applies a step of the diffusion process.
"""
timesteps: Tensor
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: DType = float32,
):
self.device: Device = Device(device)
self.dtype: DType = dtype
self.num_inference_steps = num_inference_steps
self.num_train_timesteps = num_train_timesteps
self.initial_diffusion_rate = initial_diffusion_rate
self.final_diffusion_rate = final_diffusion_rate
self.noise_schedule = noise_schedule
self.scale_factors = self.sample_noise_schedule()
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
self.timesteps = self._generate_timesteps()
@abstractmethod
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
"""
Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.
This method should be overridden by subclasses to implement the specific diffusion process.
"""
...
@abstractmethod
def _generate_timesteps(self) -> Tensor:
"""
Generates a tensor of timesteps.
This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
"""
...
@property
def steps(self) -> list[int]:
return list(range(self.num_inference_steps))
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
return (
linspace(
start=self.initial_diffusion_rate ** (1 / power),
end=self.final_diffusion_rate ** (1 / power),
steps=self.num_train_timesteps,
device=self.device,
dtype=self.dtype,
)
** power
)
def sample_noise_schedule(self) -> Tensor:
match self.noise_schedule:
case "uniform":
return 1 - self.sample_power_distribution(1)
case "quadratic":
return 1 - self.sample_power_distribution(2)
case "karras":
return 1 - self.sample_power_distribution(7)
case _:
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
def add_noise(
self,
x: Tensor,
noise: Tensor,
step: int,
) -> Tensor:
timestep = self.timesteps[step]
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
noise_stds = self.noise_std[timestep]
noised_x = cumulative_scale_factors * x + noise_stds * noise
return noised_x
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
timestep = self.timesteps[step]
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
noise_stds = self.noise_std[timestep]
# See equation (15) from https://arxiv.org/pdf/2006.11239.pdf. Useful to preview progress or for guidance like
# in https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance)
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
return denoised_x
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
if device is not None:
self.device = Device(device)
self.timesteps = self.timesteps.to(device)
if dtype is not None:
self.dtype = dtype
self.scale_factors = self.scale_factors.to(device, dtype=dtype)
self.cumulative_scale_factors = self.cumulative_scale_factors.to(device, dtype=dtype)
self.noise_std = self.noise_std.to(device, dtype=dtype)
self.signal_to_noise_ratios = self.signal_to_noise_ratios.to(device, dtype=dtype)
return self
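A sanity sketch of the add_noise/remove_noise helpers in the Scheduler base class above: at a given step they use the same cumulative scale factor and noise std, so remove_noise inverts add_noise up to float32 rounding.

import torch

from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM

scheduler = DDIM(num_inference_steps=30)
x = torch.randn(2, 4, 64, 64)
noise = torch.randn_like(x)
noised = scheduler.add_noise(x, noise, step=5)
recovered = scheduler.remove_noise(noised, noise, step=5)
print(torch.allclose(recovered, x, atol=1e-5))   # True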

@ -0,0 +1,101 @@
import math
from typing import TYPE_CHECKING, Any, Generic, TypeVar
import torch
from jaxtyping import Float
from torch import Size, Tensor
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.utils import gaussian_blur, interpolate
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
if TYPE_CHECKING:
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TSAGAdapter = TypeVar("TSAGAdapter", bound="SAGAdapter[Any]") # Self (see PEP 673)
class SelfAttentionMap(fl.Passthrough):
def __init__(self, num_heads: int, context_key: str) -> None:
self.num_heads = num_heads
self.context_key = context_key
super().__init__(
fl.Lambda(func=self.compute_attention_scores),
fl.SetContext(context="self_attention_map", key=context_key),
)
def split_to_multi_head(
self, x: Float[Tensor, "batch_size sequence_length embedding_dim"]
) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]:
assert (
len(x.shape) == 3
), f"Expected tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}"
assert (
x.shape[-1] % self.num_heads == 0
), f"Embedding dim (x.shape[-1]={x.shape[-1]}) must be divisible by num heads"
return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2)
def compute_attention_scores(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
query, key = self.split_to_multi_head(query), self.split_to_multi_head(key)
_, _, _, dim = query.shape
attention = query @ key.permute(0, 1, 3, 2)
attention = attention / math.sqrt(dim)
return torch.softmax(input=attention, dim=-1)
class SelfAttentionShape(fl.Passthrough):
def __init__(self, context_key: str) -> None:
self.context_key = context_key
super().__init__(
fl.SetContext(context="self_attention_map", key=context_key, callback=self.register_shape),
)
def register_shape(self, shapes: list[Size], x: Tensor) -> None:
assert x.ndim == 4, f"Expected 4D tensor, got {x.ndim}D with shape {x.shape}"
shapes.append(x.shape[-2:])
class SAGAdapter(Generic[T], fl.Chain, Adapter[T]):
def __init__(self, target: T, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
self.scale = scale
self.kernel_size = kernel_size
self.sigma = sigma
with self.setup_adapter(target):
super().__init__(target)
def inject(self: "TSAGAdapter", parent: fl.Chain | None = None) -> "TSAGAdapter":
return super().inject(parent)
def eject(self) -> None:
super().eject()
def compute_sag_mask(
self, latents: Float[Tensor, "batch_size channels height width"], classifier_free_guidance: bool = True
) -> Float[Tensor, "batch_size channels height width"]:
attn_map = self.use_context("self_attention_map")["middle_block_attn_map"]
if classifier_free_guidance:
unconditional_attn, _ = attn_map.chunk(2)
attn_map = unconditional_attn
attn_shape = self.use_context("self_attention_map")["middle_block_attn_shape"].pop()
assert len(attn_shape) == 2
b, c, h, w = latents.shape
attn_h, attn_w = attn_shape
attn_mask = attn_map.mean(dim=1, keepdim=False).sum(dim=1, keepdim=False) > 1.0
attn_mask = attn_mask.reshape(b, attn_h, attn_w).unsqueeze(1).repeat(1, c, 1, 1).type(attn_map.dtype)
return interpolate(attn_mask, Size((h, w)))
def compute_degraded_latents(
self, scheduler: Scheduler, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
) -> Tensor:
sag_mask = self.compute_sag_mask(latents=latents, classifier_free_guidance=classifier_free_guidance)
original_latents = scheduler.remove_noise(x=latents, noise=noise, step=step)
degraded_latents = gaussian_blur(original_latents, kernel_size=self.kernel_size, sigma=self.sigma)
degraded_latents = degraded_latents * sag_mask + original_latents * (1 - sag_mask)
return scheduler.add_noise(degraded_latents, noise=noise, step=step)
def init_context(self) -> Contexts:
return {"self_attention_map": {"middle_block_attn_map": None, "middle_block_attn_shape": []}}

@ -0,0 +1,17 @@
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
StableDiffusion_1,
StableDiffusion_1_Inpainting,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.t2i_adapter import SD1T2IAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
__all__ = [
"StableDiffusion_1",
"StableDiffusion_1_Inpainting",
"SD1UNet",
"SD1ControlnetAdapter",
"SD1IPAdapter",
"SD1T2IAdapter",
]

@ -0,0 +1,179 @@
from typing import Iterable, cast
from torch import Tensor, device as Device, dtype as DType
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.layers import Chain, Conv2d, Lambda, Passthrough, Residual, SiLU, Slicing, UseContext
from imaginairy.vendored.refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
DownBlocks,
MiddleBlock,
ResidualBlock,
SD1UNet,
TimestepEncoder,
)
class ConditionEncoder(Chain):
"""Encode an image to be used as a condition for Controlnet.
Input is a `batch 3 width height` tensor, output is a `batch 320 width//8 height//8` tensor.
"""
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.out_channels = (16, 32, 96, 256)
super().__init__(
Chain(
Conv2d(
in_channels=3,
out_channels=self.out_channels[0],
kernel_size=3,
stride=1,
padding=1,
device=device,
dtype=dtype,
),
SiLU(),
),
*(
Chain(
Conv2d(
in_channels=self.out_channels[i],
out_channels=self.out_channels[i],
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
SiLU(),
Conv2d(
in_channels=self.out_channels[i],
out_channels=self.out_channels[i + 1],
kernel_size=3,
stride=2,
padding=1,
device=device,
dtype=dtype,
),
SiLU(),
)
for i in range(len(self.out_channels) - 1)
),
Conv2d(
in_channels=self.out_channels[-1],
out_channels=320,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
)
class Controlnet(Passthrough):
def __init__(
self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
) -> None:
"""Controlnet is a Half-UNet that collects residuals from the UNet and uses them to condition the UNet.
Input is a `batch 3 width height` tensor, output is a `batch 1280 width//8 height//8` tensor with residuals
stored in the context.
It has to use the same context as the UNet: `unet` and `sampling`.
"""
self.name = name
self.scale = scale
super().__init__(
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
Slicing(dim=1, end=4), # support inpainting
DownBlocks(in_channels=4, device=device, dtype=dtype),
MiddleBlock(device=device, dtype=dtype),
)
# We run the condition encoder at each step. Caching the result
# is not worth it as subsequent runs take virtually no time (FG-374).
self.DownBlocks[0].append(
Residual(
UseContext("controlnet", f"condition_{name}"),
ConditionEncoder(device=device, dtype=dtype),
),
)
for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain
RangeAdapter2d(
target=chain.Conv2d_1,
channels=residual_block.out_channels,
embedding_dim=1280,
context_key=f"timestep_embedding_{name}",
device=device,
dtype=dtype,
).inject(chain)
for n, block in enumerate(cast(Iterable[Chain], self.DownBlocks)):
assert hasattr(block[0], "out_channels"), (
"The first block of every subchain in DownBlocks is expected to respond to `out_channels`,"
f" {block[0]} does not."
)
out_channels: int = block[0].out_channels
block.append(
Passthrough(
Conv2d(
in_channels=out_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype
),
Lambda(self._store_nth_residual(n)),
)
)
self.MiddleBlock.append(
Passthrough(
Conv2d(in_channels=1280, out_channels=1280, kernel_size=1, device=device, dtype=dtype),
Lambda(self._store_nth_residual(12)),
)
)
def _store_nth_residual(self, n: int):
def _store_residual(x: Tensor):
residuals = self.use_context("unet")["residuals"]
residuals[n] = residuals[n] + x * self.scale
return x
return _store_residual
class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
def __init__(
self, target: SD1UNet, name: str, scale: float = 1.0, weights: dict[str, Tensor] | None = None
) -> None:
self.name = name
controlnet = Controlnet(name=name, scale=scale, device=target.device, dtype=target.dtype)
if weights is not None:
controlnet.load_state_dict(weights)
self._controlnet: list[Controlnet] = [controlnet] # not registered by PyTorch
with self.setup_adapter(target):
super().__init__(target)
def inject(self: "SD1ControlnetAdapter", parent: Chain | None = None) -> "SD1ControlnetAdapter":
controlnet = self._controlnet[0]
target_controlnets = [x for x in self.target if isinstance(x, Controlnet)]
assert controlnet not in target_controlnets, f"{controlnet} is already injected"
for cn in target_controlnets:
assert cn.name != self.name, f"Controlnet named {self.name} is already injected"
self.target.insert(0, controlnet)
return super().inject(parent)
def eject(self) -> None:
self.target.remove(self._controlnet[0])
super().eject()
def init_context(self) -> Contexts:
return {"controlnet": {f"condition_{self.name}": None}}
def set_scale(self, scale: float) -> None:
self._controlnet[0].scale = scale
def set_controlnet_condition(self, condition: Tensor) -> None:
self.set_context("controlnet", {f"condition_{self.name}": condition})
def structural_copy(self: "SD1ControlnetAdapter") -> "SD1ControlnetAdapter":
raise RuntimeError("Controlnet cannot be copied, eject it first.")
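A hedged usage sketch for SD1ControlnetAdapter above: wrap an SD1UNet, inject the controlnet, and register the preprocessed condition image. The name, the omitted weights, and the random condition tensor are placeholders.

import torch

from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
    SD1ControlnetAdapter,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet

unet = SD1UNet(in_channels=4)
adapter = SD1ControlnetAdapter(unet, name="canny", scale=1.0).inject()   # pass weights=... in real use
condition = torch.rand(1, 3, 512, 512)                                   # preprocessed control image in [0, 1]
adapter.set_controlnet_condition(condition)
# The adapted unet now adds the controlnet residuals on every forward pass; adapter.eject() undoes it.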

@ -0,0 +1,53 @@
from torch import Tensor
from imaginairy.vendored.refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.image_prompt import ImageProjection, IPAdapter, PerceiverResampler
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
class SD1IPAdapter(IPAdapter[SD1UNet]):
def __init__(
self,
target: SD1UNet,
clip_image_encoder: CLIPImageEncoderH | None = None,
image_proj: ImageProjection | PerceiverResampler | None = None,
scale: float = 1.0,
fine_grained: bool = False,
weights: dict[str, Tensor] | None = None,
) -> None:
clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
if image_proj is None:
cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
image_proj = (
ImageProjection(
clip_image_embedding_dim=clip_image_encoder.output_dim,
clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
device=target.device,
dtype=target.dtype,
)
if not fine_grained
else PerceiverResampler(
latents_dim=cross_attn_2d.context_embedding_dim,
num_attention_layers=4,
num_attention_heads=12,
head_dim=64,
num_tokens=16,
input_dim=clip_image_encoder.embedding_dim, # = dim before final projection
output_dim=cross_attn_2d.context_embedding_dim,
device=target.device,
dtype=target.dtype,
)
)
elif fine_grained:
assert isinstance(image_proj, PerceiverResampler)
super().__init__(
target=target,
clip_image_encoder=clip_image_encoder,
image_proj=image_proj,
scale=scale,
fine_grained=fine_grained,
weights=weights,
)

@ -0,0 +1,173 @@
import numpy as np
import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpolate
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from imaginairy.vendored.refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.self_attention_guidance import SD1SAGAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
class SD1Autoencoder(LatentDiffusionAutoencoder):
encoder_scale: float = 0.18215
class StableDiffusion_1(LatentDiffusionModel):
unet: SD1UNet
clip_text_encoder: CLIPTextEncoderL
def __init__(
self,
unet: SD1UNet | None = None,
lda: SD1Autoencoder | None = None,
clip_text_encoder: CLIPTextEncoderL | None = None,
scheduler: Scheduler | None = None,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
unet = unet or SD1UNet(in_channels=4)
lda = lda or SD1Autoencoder()
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
scheduler = scheduler or DPMSolver(num_inference_steps=30)
super().__init__(
unet=unet,
lda=lda,
clip_text_encoder=clip_text_encoder,
scheduler=scheduler,
device=device,
dtype=dtype,
)
def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
conditional_embedding = self.clip_text_encoder(text)
if text == negative_text:
return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0)
negative_embedding = self.clip_text_encoder(negative_text or "")
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
self.unet.set_timestep(timestep=timestep)
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
if enable:
if sag := self._find_sag_adapter():
sag.scale = scale
else:
SD1SAGAdapter(target=self.unet, scale=scale).inject()
else:
if sag := self._find_sag_adapter():
sag.eject()
def has_self_attention_guidance(self) -> bool:
return self._find_sag_adapter() is not None
def _find_sag_adapter(self) -> SD1SAGAdapter | None:
for p in self.unet.get_parents():
if isinstance(p, SD1SAGAdapter):
return p
return None
def compute_self_attention_guidance(
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
) -> Tensor:
sag = self._find_sag_adapter()
assert sag is not None
degraded_latents = sag.compute_degraded_latents(
scheduler=self.scheduler,
latents=x,
noise=noise,
step=step,
classifier_free_guidance=True,
)
negative_embedding, _ = clip_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
degraded_noise = self.unet(degraded_latents)
return sag.scale * (noise - degraded_noise)
class StableDiffusion_1_Inpainting(StableDiffusion_1):
def __init__(
self,
unet: SD1UNet | None = None,
lda: SD1Autoencoder | None = None,
clip_text_encoder: CLIPTextEncoderL | None = None,
scheduler: Scheduler | None = None,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
self.mask_latents: Tensor | None = None
self.target_image_latents: Tensor | None = None
super().__init__(
unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler, device=device, dtype=dtype
)
def forward(
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **_: Tensor
) -> Tensor:
assert self.mask_latents is not None
assert self.target_image_latents is not None
x = torch.cat(tensors=(x, self.mask_latents, self.target_image_latents), dim=1)
return super().forward(
x=x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=condition_scale,
)
def set_inpainting_conditions(
self,
target_image: Image.Image,
mask: Image.Image,
latents_size: tuple[int, int] = (64, 64),
) -> tuple[Tensor, Tensor]:
target_image = target_image.convert(mode="RGB")
mask = mask.convert(mode="L")
mask_tensor = torch.tensor(data=np.array(object=mask).astype(dtype=np.float32) / 255.0).to(device=self.device)
mask_tensor = (mask_tensor > 0.5).unsqueeze(dim=0).unsqueeze(dim=0).to(dtype=self.dtype)
self.mask_latents = interpolate(x=mask_tensor, factor=torch.Size(latents_size))
init_image_tensor = image_to_tensor(image=target_image, device=self.device, dtype=self.dtype) * 2 - 1
masked_init_image = init_image_tensor * (1 - mask_tensor)
self.target_image_latents = self.lda.encode(x=masked_init_image)
return self.mask_latents, self.target_image_latents
def compute_self_attention_guidance(
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
) -> Tensor:
sag = self._find_sag_adapter()
assert sag is not None
assert self.mask_latents is not None
assert self.target_image_latents is not None
degraded_latents = sag.compute_degraded_latents(
scheduler=self.scheduler,
latents=x,
noise=noise,
step=step,
classifier_free_guidance=True,
)
negative_embedding, _ = clip_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
x = torch.cat(
tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
dim=1,
)
degraded_noise = self.unet(x)
return sag.scale * (noise - degraded_noise)
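A hedged end-to-end sketch for StableDiffusion_1 above: a classifier-free-guidance text embedding, a plain loop over the scheduler steps, then latent decoding. Weights are random here; real use loads converted SD 1.5 weights into unet, lda and clip_text_encoder first, and the condition_scale keyword is assumed to be accepted by the base forward as in the inpainting subclass.

import torch

from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.model import StableDiffusion_1

sd = StableDiffusion_1()                        # DPMSolver(num_inference_steps=30) by default
clip_text_embedding = sd.compute_clip_text_embedding(
    text="a watercolor painting of a fox", negative_text=""
)
x = torch.randn(1, 4, 64, 64)
for step in sd.steps:
    x = sd(x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5)
image = sd.lda.decode_latents(x)                # PIL image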

@ -0,0 +1,41 @@
from dataclasses import dataclass, field
from PIL import Image
from torch import Tensor
from imaginairy.vendored.refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
StableDiffusion_1,
StableDiffusion_1_Inpainting,
)
class SD1MultiDiffusion(MultiDiffusion[StableDiffusion_1, DiffusionTarget]):
def diffuse_target(self, x: Tensor, step: int, target: DiffusionTarget) -> Tensor:
return self.ldm(
x=x,
step=step,
clip_text_embedding=target.clip_text_embedding,
scale=target.condition_scale,
)
@dataclass
class InpaintingDiffusionTarget(DiffusionTarget):
target_image: Image.Image = field(default_factory=lambda: Image.new(mode="RGB", size=(512, 512), color=255))
mask: Image.Image = field(default_factory=lambda: Image.new(mode="L", size=(512, 512), color=255))
class SD1InpaintingMultiDiffusion(MultiDiffusion[StableDiffusion_1_Inpainting, InpaintingDiffusionTarget]):
def diffuse_target(self, x: Tensor, step: int, target: InpaintingDiffusionTarget) -> Tensor:
self.ldm.set_inpainting_conditions(
target_image=target.target_image,
mask=target.mask,
)
return self.ldm(
x=x,
step=step,
clip_text_embedding=target.clip_text_embedding,
scale=target.condition_scale,
)

@ -0,0 +1,41 @@
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.layers.attentions import ScaledDotProductAttention
from imaginairy.vendored.refiners.foundationals.latent_diffusion.self_attention_guidance import (
SAGAdapter,
SelfAttentionMap,
SelfAttentionShape,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import MiddleBlock, ResidualBlock, SD1UNet
class SD1SAGAdapter(SAGAdapter[SD1UNet]):
def __init__(self, target: SD1UNet, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
super().__init__(
target=target,
scale=scale,
kernel_size=kernel_size,
sigma=sigma,
)
def inject(self: "SD1SAGAdapter", parent: fl.Chain | None = None) -> "SD1SAGAdapter":
middle_block = self.target.ensure_find(MiddleBlock)
middle_block.insert_after_type(ResidualBlock, SelfAttentionShape(context_key="middle_block_attn_shape"))
# An alternative would be to replace the ScaledDotProductAttention with a version which records the attention
# scores to avoid computing these scores twice
self_attn = middle_block.ensure_find(fl.SelfAttention)
self_attn.insert_before_type(
ScaledDotProductAttention,
SelfAttentionMap(num_heads=self_attn.num_heads, context_key="middle_block_attn_map"),
)
return super().inject(parent)
def eject(self) -> None:
middle_block = self.target.ensure_find(MiddleBlock)
middle_block.remove(middle_block.ensure_find(SelfAttentionShape))
self_attn = middle_block.ensure_find(fl.SelfAttention)
self_attn.remove(self_attn.ensure_find(SelfAttentionMap))
super().eject()

@ -0,0 +1,37 @@
from torch import Tensor
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator, SD1UNet
from imaginairy.vendored.refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoder, T2IAdapter, T2IFeatures
class SD1T2IAdapter(T2IAdapter[SD1UNet]):
def __init__(
self,
target: SD1UNet,
name: str,
condition_encoder: ConditionEncoder | None = None,
scale: float = 1.0,
weights: dict[str, Tensor] | None = None,
) -> None:
self.residual_indices = (2, 5, 8, 11)
self._features = [T2IFeatures(name=name, index=i, scale=scale) for i in range(4)]
super().__init__(
target=target,
name=name,
condition_encoder=condition_encoder or ConditionEncoder(device=target.device, dtype=target.dtype),
weights=weights,
)
def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
for n, feat in zip(self.residual_indices, self._features, strict=True):
block = self.target.DownBlocks[n]
for t2i_layer in block.layers(layer_type=T2IFeatures):
assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
block.insert_before_type(ResidualAccumulator, feat)
return super().inject(parent)
def eject(self: "SD1T2IAdapter") -> None:
for n, feat in zip(self.residual_indices, self._features, strict=True):
self.target.DownBlocks[n].remove(feat)
super().eject()

@ -0,0 +1,288 @@
from typing import Iterable, cast
from torch import Tensor, device as Device, dtype as DType
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d, RangeEncoder
class TimestepEncoder(fl.Passthrough):
def __init__(
self,
context_key: str = "timestep_embedding",
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.UseContext("diffusion", "timestep"),
RangeEncoder(320, 1280, device=device, dtype=dtype),
fl.SetContext("range_adapter", context_key),
)
class ResidualBlock(fl.Sum):
def __init__(
self,
in_channels: int,
out_channels: int,
num_groups: int = 32,
eps: float = 1e-5,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
if in_channels % num_groups != 0 or out_channels % num_groups != 0:
raise ValueError("Number of input and output channels must be divisible by num_groups.")
self.in_channels = in_channels
self.out_channels = out_channels
self.num_groups = num_groups
self.eps = eps
shortcut = (
fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
if in_channels != out_channels
else fl.Identity()
)
super().__init__(
fl.Chain(
fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
),
shortcut,
)
class CLIPLCrossAttention(CrossAttentionBlock2d):
def __init__(
self,
channels: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
channels=channels,
context_embedding_dim=768,
context_key="clip_text_embedding",
num_attention_heads=8,
use_bias=False,
device=device,
dtype=dtype,
)
class DownBlocks(fl.Chain):
def __init__(
self,
in_channels: int,
device: Device | str | None = None,
dtype: DType | None = None,
):
self.in_channels = in_channels
super().__init__(
fl.Chain(
fl.Conv2d(
in_channels=in_channels, out_channels=320, kernel_size=3, padding=1, device=device, dtype=dtype
)
),
fl.Chain(
ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
),
fl.Chain(fl.Downsample(channels=320, scale_factor=2, padding=1, device=device, dtype=dtype)),
fl.Chain(
ResidualBlock(in_channels=320, out_channels=640, device=device, dtype=dtype),
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=640, out_channels=640, device=device, dtype=dtype),
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
),
fl.Chain(fl.Downsample(channels=640, scale_factor=2, padding=1, device=device, dtype=dtype)),
fl.Chain(
ResidualBlock(in_channels=640, out_channels=1280, device=device, dtype=dtype),
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
),
fl.Chain(fl.Downsample(channels=1280, scale_factor=2, padding=1, device=device, dtype=dtype)),
fl.Chain(
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
),
)
class UpBlocks(fl.Chain):
def __init__(
self,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.Chain(
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
fl.Upsample(channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=1920, out_channels=1280, device=device, dtype=dtype),
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
fl.Upsample(channels=1280, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=1920, out_channels=640, device=device, dtype=dtype),
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=1280, out_channels=640, device=device, dtype=dtype),
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=960, out_channels=640, device=device, dtype=dtype),
CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
fl.Upsample(channels=640, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=960, out_channels=320, device=device, dtype=dtype),
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
),
)
class MiddleBlock(fl.Chain):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
)
class ResidualAccumulator(fl.Passthrough):
def __init__(self, n: int) -> None:
self.n = n
super().__init__(
fl.Residual(
fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n])
),
fl.SetContext(context="unet", key="residuals", callback=self.update),
)
def update(self, residuals: list[Tensor | float], x: Tensor) -> None:
residuals[self.n] = x
class ResidualConcatenator(fl.Chain):
def __init__(self, n: int) -> None:
self.n = n
super().__init__(
fl.Concatenate(
fl.Identity(),
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]),
dim=1,
),
)
class SD1UNet(fl.Chain):
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.in_channels = in_channels
super().__init__(
TimestepEncoder(device=device, dtype=dtype),
DownBlocks(in_channels=in_channels, device=device, dtype=dtype),
fl.Sum(
fl.UseContext(context="unet", key="residuals").compose(lambda x: x[-1]),
MiddleBlock(device=device, dtype=dtype),
),
UpBlocks(),
fl.Chain(
fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(
in_channels=320,
out_channels=4,
kernel_size=3,
stride=1,
padding=1,
device=device,
dtype=dtype,
),
),
)
for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain
RangeAdapter2d(
target=chain.Conv2d_1,
channels=residual_block.out_channels,
embedding_dim=1280,
context_key="timestep_embedding",
device=device,
dtype=dtype,
).inject(chain)
for n, block in enumerate(cast(Iterable[fl.Chain], self.DownBlocks)):
block.append(ResidualAccumulator(n))
for n, block in enumerate(cast(Iterable[fl.Chain], self.UpBlocks)):
block.insert(0, ResidualConcatenator(-n - 2))
def init_context(self) -> Contexts:
return {
"unet": {"residuals": [0.0] * 13},
"diffusion": {"timestep": None},
"range_adapter": {"timestep_embedding": None},
"sampling": {"shapes": []},
}
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
self.set_context("cross_attention_block", {"clip_text_embedding": clip_text_embedding})
def set_timestep(self, timestep: Tensor) -> None:
self.set_context("diffusion", {"timestep": timestep})

@ -0,0 +1,13 @@
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
__all__ = [
"SDXLUNet",
"DoubleTextEncoder",
"StableDiffusion_XL",
"SDXLIPAdapter",
"SDXLT2IAdapter",
]

@ -0,0 +1,53 @@
from torch import Tensor
from imaginairy.vendored.refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.image_prompt import ImageProjection, IPAdapter, PerceiverResampler
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
class SDXLIPAdapter(IPAdapter[SDXLUNet]):
def __init__(
self,
target: SDXLUNet,
clip_image_encoder: CLIPImageEncoderH | None = None,
image_proj: ImageProjection | PerceiverResampler | None = None,
scale: float = 1.0,
fine_grained: bool = False,
weights: dict[str, Tensor] | None = None,
) -> None:
clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
if image_proj is None:
cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
image_proj = (
ImageProjection(
clip_image_embedding_dim=clip_image_encoder.output_dim,
clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
device=target.device,
dtype=target.dtype,
)
if not fine_grained
else PerceiverResampler(
latents_dim=1280, # not `cross_attn_2d.context_embedding_dim` in this case
num_attention_layers=4,
num_attention_heads=20,
head_dim=64,
num_tokens=16,
input_dim=clip_image_encoder.embedding_dim, # = dim before final projection
output_dim=cross_attn_2d.context_embedding_dim,
device=target.device,
dtype=target.dtype,
)
)
elif fine_grained:
assert isinstance(image_proj, PerceiverResampler)
super().__init__(
target=target,
clip_image_encoder=clip_image_encoder,
image_proj=image_proj,
scale=scale,
fine_grained=fine_grained,
weights=weights,
)

@ -0,0 +1,154 @@
import torch
from torch import Tensor, device as Device, dtype as DType
from imaginairy.vendored.refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.self_attention_guidance import SDXLSAGAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
class SDXLAutoencoder(LatentDiffusionAutoencoder):
encoder_scale: float = 0.13025
class StableDiffusion_XL(LatentDiffusionModel):
unet: SDXLUNet
clip_text_encoder: DoubleTextEncoder
def __init__(
self,
unet: SDXLUNet | None = None,
lda: SDXLAutoencoder | None = None,
clip_text_encoder: DoubleTextEncoder | None = None,
scheduler: Scheduler | None = None,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
unet = unet or SDXLUNet(in_channels=4)
lda = lda or SDXLAutoencoder()
clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
scheduler = scheduler or DDIM(num_inference_steps=30)
super().__init__(
unet=unet,
lda=lda,
clip_text_encoder=clip_text_encoder,
scheduler=scheduler,
device=device,
dtype=dtype,
)
def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]:
conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
if text == negative_text:
return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0), torch.cat(
tensors=(conditional_pooled_embedding, conditional_pooled_embedding), dim=0
)
# TODO: when negative_text is None, use zero tensor?
negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text or "")
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0), torch.cat(
tensors=(negative_pooled_embedding, conditional_pooled_embedding), dim=0
)
@property
def default_time_ids(self) -> Tensor:
# [original_height, original_width, crop_top, crop_left, target_height, target_width]
# See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning
time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device)
return time_ids.repeat(2, 1)
def set_unet_context(
self,
*,
timestep: Tensor,
clip_text_embedding: Tensor,
pooled_text_embedding: Tensor,
time_ids: Tensor,
**_: Tensor,
) -> None:
self.unet.set_timestep(timestep=timestep)
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
self.unet.set_pooled_text_embedding(pooled_text_embedding=pooled_text_embedding)
self.unet.set_time_ids(time_ids=time_ids)
def forward(
self,
x: Tensor,
step: int,
*,
clip_text_embedding: Tensor,
pooled_text_embedding: Tensor,
time_ids: Tensor,
condition_scale: float = 5.0,
**kwargs: Tensor,
) -> Tensor:
return super().forward(
x=x,
step=step,
clip_text_embedding=clip_text_embedding,
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
condition_scale=condition_scale,
**kwargs,
)
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
if enable:
if sag := self._find_sag_adapter():
sag.scale = scale
else:
SDXLSAGAdapter(target=self.unet, scale=scale).inject()
else:
if sag := self._find_sag_adapter():
sag.eject()
def has_self_attention_guidance(self) -> bool:
return self._find_sag_adapter() is not None
def _find_sag_adapter(self) -> SDXLSAGAdapter | None:
for p in self.unet.get_parents():
if isinstance(p, SDXLSAGAdapter):
return p
return None
def compute_self_attention_guidance(
self,
x: Tensor,
noise: Tensor,
step: int,
*,
clip_text_embedding: Tensor,
pooled_text_embedding: Tensor,
time_ids: Tensor,
**kwargs: Tensor,
) -> Tensor:
sag = self._find_sag_adapter()
assert sag is not None
degraded_latents = sag.compute_degraded_latents(
scheduler=self.scheduler,
latents=x,
noise=noise,
step=step,
classifier_free_guidance=True,
)
negative_embedding, _ = clip_text_embedding.chunk(2)
negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
time_ids, _ = time_ids.chunk(2)
self.set_unet_context(
timestep=timestep,
clip_text_embedding=negative_embedding,
pooled_text_embedding=negative_pooled_embedding,
time_ids=time_ids,
**kwargs,
)
degraded_noise = self.unet(degraded_latents)
return sag.scale * (noise - degraded_noise)
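
For orientation, a minimal denoising-loop sketch against this class (illustrative only, not part of the diff). It assumes randomly initialized weights and that the inherited `LatentDiffusionModel` call performs one denoising step per invocation:

import torch
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL

sd = StableDiffusion_XL()  # defaults: SDXLUNet, SDXLAutoencoder, DoubleTextEncoder, DDIM(num_inference_steps=30)
clip_text_embedding, pooled_text_embedding = sd.compute_clip_text_embedding(
    text="a photo of a cat", negative_text=""
)
time_ids = sd.default_time_ids   # 1024x1024 micro-conditioning, repeated for classifier-free guidance
x = torch.randn(1, 4, 128, 128)  # SDXL latents for a 1024x1024 image
for step in range(len(sd.scheduler.timesteps)):
    x = sd(
        x,
        step=step,
        clip_text_embedding=clip_text_embedding,
        pooled_text_embedding=pooled_text_embedding,
        time_ids=time_ids,
        condition_scale=5.0,
    )
# x would then be decoded back to pixels through sd.lda (an SDXLAutoencoder).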

@ -0,0 +1,21 @@
from torch import Tensor
from imaginairy.vendored.refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
class SDXLDiffusionTarget(DiffusionTarget):
pooled_text_embedding: Tensor
time_ids: Tensor
class SDXLMultiDiffusion(MultiDiffusion[StableDiffusion_XL, SDXLDiffusionTarget]):
def diffuse_target(self, x: Tensor, step: int, target: SDXLDiffusionTarget) -> Tensor:
return self.ldm(
x=x,
step=step,
clip_text_embedding=target.clip_text_embedding,
pooled_text_embedding=target.pooled_text_embedding,
time_ids=target.time_ids,
condition_scale=target.condition_scale,
)

@ -0,0 +1,41 @@
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.layers.attentions import ScaledDotProductAttention
from imaginairy.vendored.refiners.foundationals.latent_diffusion.self_attention_guidance import (
SAGAdapter,
SelfAttentionMap,
SelfAttentionShape,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import MiddleBlock, ResidualBlock, SDXLUNet
class SDXLSAGAdapter(SAGAdapter[SDXLUNet]):
def __init__(self, target: SDXLUNet, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
super().__init__(
target=target,
scale=scale,
kernel_size=kernel_size,
sigma=sigma,
)
def inject(self: "SDXLSAGAdapter", parent: fl.Chain | None = None) -> "SDXLSAGAdapter":
middle_block = self.target.ensure_find(MiddleBlock)
middle_block.insert_after_type(ResidualBlock, SelfAttentionShape(context_key="middle_block_attn_shape"))
# An alternative would be to replace the ScaledDotProductAttention with a version which records the attention
# scores to avoid computing these scores twice
self_attn = middle_block.ensure_find(fl.SelfAttention)
self_attn.insert_before_type(
ScaledDotProductAttention,
SelfAttentionMap(num_heads=self_attn.num_heads, context_key="middle_block_attn_map"),
)
return super().inject(parent)
def eject(self) -> None:
middle_block = self.target.ensure_find(MiddleBlock)
middle_block.remove(middle_block.ensure_find(SelfAttentionShape))
self_attn = middle_block.ensure_find(fl.SelfAttention)
self_attn.remove(self_attn.ensure_find(SelfAttentionMap))
super().eject()
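
In practice this adapter is toggled through the model-level helpers defined in the `StableDiffusion_XL` hunk above; a brief sketch, not part of the diff:

from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL

sd = StableDiffusion_XL()
sd.set_self_attention_guidance(enable=True, scale=0.75)  # injects an SDXLSAGAdapter into sd.unet
assert sd.has_self_attention_guidance()
sd.set_self_attention_guidance(enable=False)             # ejects it again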

@ -0,0 +1,48 @@
from torch import Tensor
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from imaginairy.vendored.refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoderXL, T2IAdapter, T2IFeatures
class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
def __init__(
self,
target: SDXLUNet,
name: str,
condition_encoder: ConditionEncoderXL | None = None,
scale: float = 1.0,
weights: dict[str, Tensor] | None = None,
) -> None:
self.residual_indices = (3, 5, 8) # the UNet's middle block is handled separately (see `inject` and `eject`)
self._features = [T2IFeatures(name=name, index=i, scale=scale) for i in range(4)]
super().__init__(
target=target,
name=name,
condition_encoder=condition_encoder or ConditionEncoderXL(device=target.device, dtype=target.dtype),
weights=weights,
)
def inject(self: "SDXLT2IAdapter", parent: fl.Chain | None = None) -> "SDXLT2IAdapter":
def sanity_check_t2i(block: fl.Chain) -> None:
for t2i_layer in block.layers(layer_type=T2IFeatures):
assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
# Note: `strict=False` because `residual_indices` is shorter than `_features` due to MiddleBlock (see below)
for n, feat in zip(self.residual_indices, self._features, strict=False):
block = self.target.DownBlocks[n]
sanity_check_t2i(block)
block.insert_before_type(ResidualAccumulator, feat)
# Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
sanity_check_t2i(self.target.MiddleBlock)
self.target.MiddleBlock.append(self._features[-1])
return super().inject(parent)
def eject(self: "SDXLT2IAdapter") -> None:
# See `inject` re: `strict=False`
for n, feat in zip(self.residual_indices, self._features, strict=False):
self.target.DownBlocks[n].remove(feat)
self.target.MiddleBlock.remove(self._features[-1])
super().eject()
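
A hypothetical sketch of wiring a T2I-Adapter into the UNet (not part of the diff); the condition tensor shape and the "canny" name are assumptions made for illustration:

import torch
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLT2IAdapter, SDXLUNet

unet = SDXLUNet(in_channels=4)
t2i = SDXLT2IAdapter(target=unet, name="canny").inject()
condition = torch.randn(1, 3, 1024, 1024)             # e.g. an edge map, normally loaded from an image
features = t2i.compute_condition_features(condition)  # runs the ConditionEncoderXL
t2i.set_condition_features(features)
t2i.set_scale(0.8)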

@ -0,0 +1,85 @@
from typing import cast
from jaxtyping import Float
from torch import Tensor, cat, device as Device, dtype as DType
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
from imaginairy.vendored.refiners.foundationals.clip.tokenizer import CLIPTokenizer
class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
def __init__(
self,
target: CLIPTextEncoderG,
projection: fl.Linear | None = None,
) -> None:
with self.setup_adapter(target=target):
tokenizer = target.ensure_find(CLIPTokenizer)
super().__init__(
tokenizer,
fl.SetContext(
context="text_encoder_pooling", key="end_of_text_index", callback=self.set_end_of_text_index
),
target[1:-2],
fl.Parallel(
fl.Identity(),
fl.Chain(
target[-2:],
projection
or fl.Linear(
in_features=1280, out_features=1280, bias=False, device=target.device, dtype=target.dtype
),
fl.Lambda(func=self.pool),
),
),
)
def init_context(self) -> Contexts:
return {"text_encoder_pooling": {"end_of_text_index": []}}
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 1280"], Float[Tensor, "1 1280"]]:
return super().__call__(text)
@property
def tokenizer(self) -> CLIPTokenizer:
return self.ensure_find(CLIPTokenizer)
def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()
end_of_text_index.append(cast(int, position))
def pool(self, x: Float[Tensor, "1 77 1280"]) -> Float[Tensor, "1 1280"]:
end_of_text_index = self.use_context(context_name="text_encoder_pooling").get("end_of_text_index", [])
assert len(end_of_text_index) == 1, "End of text index not found."
return x[:, end_of_text_index[0], :]
class DoubleTextEncoder(fl.Chain):
def __init__(
self,
text_encoder_l: CLIPTextEncoderL | None = None,
text_encoder_g: CLIPTextEncoderG | None = None,
projection: fl.Linear | None = None,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
text_encoder_l = text_encoder_l or CLIPTextEncoderL(device=device, dtype=dtype)
text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype)
super().__init__(
fl.Parallel(text_encoder_l[:-2], text_encoder_g),
fl.Lambda(func=self.concatenate_embeddings),
)
TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(parent=self.Parallel)
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
return super().__call__(text)
def concatenate_embeddings(
self, text_embedding_l: Tensor, text_embedding_with_pooling: tuple[Tensor, Tensor]
) -> tuple[Tensor, Tensor]:
text_embedding_g, pooled_text_embedding = text_embedding_with_pooling
text_embedding = cat(tensors=[text_embedding_l, text_embedding_g], dim=-1)
return text_embedding, pooled_text_embedding
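
A small usage sketch (not part of the diff), with output shapes taken from the annotations above; the encoder is randomly initialized here, whereas real use loads converted CLIP weights:

import torch
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl import DoubleTextEncoder

text_encoder = DoubleTextEncoder(device="cpu", dtype=torch.float32)
embedding, pooled = text_encoder("a photo of a cat")
# embedding: (1, 77, 2048) -- CLIP-L (768) and CLIP-G (1280) token embeddings concatenated
# pooled:    (1, 1280)     -- pooled projection taken from the CLIP-G branch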

@ -0,0 +1,285 @@
from typing import cast
from torch import Tensor, device as Device, dtype as DType
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.range_adapter import (
RangeAdapter2d,
RangeEncoder,
compute_sinusoidal_embedding,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
ResidualAccumulator,
ResidualBlock,
ResidualConcatenator,
)
class TextTimeEmbedding(fl.Chain):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.timestep_embedding_dim = 1280
self.time_ids_embedding_dim = 256
self.text_time_embedding_dim = 2816
super().__init__(
fl.Concatenate(
fl.UseContext(context="diffusion", key="pooled_text_embedding"),
fl.Chain(
fl.UseContext(context="diffusion", key="time_ids"),
fl.Unsqueeze(dim=-1),
fl.Lambda(func=self.compute_sinuosoidal_embedding),
fl.Reshape(-1),
),
dim=1,
),
fl.Converter(set_device=False, set_dtype=True),
fl.Linear(
in_features=self.text_time_embedding_dim,
out_features=self.timestep_embedding_dim,
device=device,
dtype=dtype,
),
fl.SiLU(),
fl.Linear(
in_features=self.timestep_embedding_dim,
out_features=self.timestep_embedding_dim,
device=device,
dtype=dtype,
),
)
def compute_sinuosoidal_embedding(self, x: Tensor) -> Tensor:
return compute_sinusoidal_embedding(x=x, embedding_dim=self.time_ids_embedding_dim)
class TimestepEncoder(fl.Passthrough):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.timestep_embedding_dim = 1280
super().__init__(
fl.Sum(
fl.Chain(
fl.UseContext(context="diffusion", key="timestep"),
RangeEncoder(
sinuosidal_embedding_dim=320,
embedding_dim=self.timestep_embedding_dim,
device=device,
dtype=dtype,
),
),
TextTimeEmbedding(device=device, dtype=dtype),
),
fl.SetContext(context="range_adapter", key="timestep_embedding"),
)
class SDXLCrossAttention(CrossAttentionBlock2d):
def __init__(
self,
channels: int,
num_attention_layers: int = 1,
num_attention_heads: int = 10,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
channels=channels,
context_embedding_dim=2048,
context_key="clip_text_embedding",
num_attention_layers=num_attention_layers,
num_attention_heads=num_attention_heads,
use_bias=False,
use_linear_projection=True,
device=device,
dtype=dtype,
)
class DownBlocks(fl.Chain):
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.in_channels = in_channels
in_block = fl.Chain(
fl.Conv2d(in_channels=in_channels, out_channels=320, kernel_size=3, padding=1, device=device, dtype=dtype)
)
first_blocks = [
fl.Chain(
ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
),
fl.Chain(
fl.Downsample(channels=320, scale_factor=2, padding=1, device=device, dtype=dtype),
),
]
second_blocks = [
fl.Chain(
ResidualBlock(in_channels=320, out_channels=640, device=device, dtype=dtype),
SDXLCrossAttention(
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
),
),
fl.Chain(
ResidualBlock(in_channels=640, out_channels=640, device=device, dtype=dtype),
SDXLCrossAttention(
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
),
),
fl.Chain(
fl.Downsample(channels=640, scale_factor=2, padding=1, device=device, dtype=dtype),
),
]
third_blocks = [
fl.Chain(
ResidualBlock(in_channels=640, out_channels=1280, device=device, dtype=dtype),
SDXLCrossAttention(
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
),
),
fl.Chain(
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
SDXLCrossAttention(
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
),
),
]
super().__init__(
in_block,
*first_blocks,
*second_blocks,
*third_blocks,
)
class UpBlocks(fl.Chain):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
first_blocks = [
fl.Chain(
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
SDXLCrossAttention(
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
),
),
fl.Chain(
ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
SDXLCrossAttention(
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
),
),
fl.Chain(
ResidualBlock(in_channels=1920, out_channels=1280, device=device, dtype=dtype),
SDXLCrossAttention(
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
),
fl.Upsample(channels=1280, device=device, dtype=dtype),
),
]
second_blocks = [
fl.Chain(
ResidualBlock(in_channels=1920, out_channels=640, device=device, dtype=dtype),
SDXLCrossAttention(
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
),
),
fl.Chain(
ResidualBlock(in_channels=1280, out_channels=640, device=device, dtype=dtype),
SDXLCrossAttention(
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
),
),
fl.Chain(
ResidualBlock(in_channels=960, out_channels=640, device=device, dtype=dtype),
SDXLCrossAttention(
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
),
fl.Upsample(channels=640, device=device, dtype=dtype),
),
]
third_blocks = [
fl.Chain(
ResidualBlock(in_channels=960, out_channels=320, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
),
fl.Chain(
ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
),
]
super().__init__(
*first_blocks,
*second_blocks,
*third_blocks,
)
class MiddleBlock(fl.Chain):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
SDXLCrossAttention(
channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
),
ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
)
class OutputBlock(fl.Chain):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(in_channels=320, out_channels=4, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
)
class SDXLUNet(fl.Chain):
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.in_channels = in_channels
super().__init__(
TimestepEncoder(device=device, dtype=dtype),
DownBlocks(in_channels=in_channels, device=device, dtype=dtype),
MiddleBlock(device=device, dtype=dtype),
fl.Residual(fl.UseContext(context="unet", key="residuals").compose(lambda x: x[-1])),
UpBlocks(device=device, dtype=dtype),
OutputBlock(device=device, dtype=dtype),
)
for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain
RangeAdapter2d(
target=chain.Conv2d_1,
channels=residual_block.out_channels,
embedding_dim=1280,
context_key="timestep_embedding",
device=device,
dtype=dtype,
).inject(chain)
for n, block in enumerate(iterable=cast(list[fl.Chain], self.DownBlocks)):
block.append(module=ResidualAccumulator(n=n))
for n, block in enumerate(iterable=cast(list[fl.Chain], self.UpBlocks)):
block.insert(index=0, module=ResidualConcatenator(n=-n - 2))
def init_context(self) -> Contexts:
return {
"unet": {"residuals": [0.0] * 10},
"diffusion": {"timestep": None, "time_ids": None, "pooled_text_embedding": None},
"range_adapter": {"timestep_embedding": None},
"sampling": {"shapes": []},
}
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
self.set_context(context="cross_attention_block", value={"clip_text_embedding": clip_text_embedding})
def set_timestep(self, timestep: Tensor) -> None:
self.set_context(context="diffusion", value={"timestep": timestep})
def set_time_ids(self, time_ids: Tensor) -> None:
self.set_context(context="diffusion", value={"time_ids": time_ids})
def set_pooled_text_embedding(self, pooled_text_embedding: Tensor) -> None:
self.set_context(context="diffusion", value={"pooled_text_embedding": pooled_text_embedding})
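
The UNet is driven entirely through context setters rather than extra forward arguments. A minimal sketch with dummy tensors (not part of the diff, weights randomly initialized):

import torch
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet

unet = SDXLUNet(in_channels=4)
unet.set_timestep(torch.tensor([999]))
unet.set_clip_text_embedding(torch.randn(1, 77, 2048))
unet.set_pooled_text_embedding(torch.randn(1, 1280))
unet.set_time_ids(torch.tensor([[1024, 1024, 0, 0, 1024, 1024]]))
noise_prediction = unet(torch.randn(1, 4, 128, 128))  # predicted noise, same shape as the input latents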

@ -0,0 +1,215 @@
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from torch import Tensor, device as Device, dtype as DType
from torch.nn import AvgPool2d as _AvgPool2d
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.layers.module import Module
if TYPE_CHECKING:
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TT2IAdapter = TypeVar("TT2IAdapter", bound="T2IAdapter[Any]") # Self (see PEP 673)
class Downsample2d(_AvgPool2d, Module):
def __init__(self, scale_factor: int) -> None:
_AvgPool2d.__init__(self, kernel_size=scale_factor, stride=scale_factor)
class ResidualBlock(fl.Residual):
def __init__(
self,
channels: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.Conv2d(
in_channels=channels, out_channels=channels, kernel_size=3, padding=1, device=device, dtype=dtype
),
fl.ReLU(),
fl.Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
)
class ResidualBlocks(fl.Chain):
def __init__(
self,
in_channels: int,
out_channels: int,
num_residual_blocks: int = 2,
downsample: bool = False,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
preproc = Downsample2d(scale_factor=2) if downsample else fl.Identity()
shortcut = (
fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
if in_channels != out_channels
else fl.Identity()
)
super().__init__(
preproc,
shortcut,
fl.Chain(
ResidualBlock(channels=out_channels, device=device, dtype=dtype) for _ in range(num_residual_blocks)
),
)
class StatefulResidualBlocks(fl.Chain):
def __init__(
self,
in_channels: int,
out_channels: int,
num_residual_blocks: int = 2,
downsample: bool = False,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
ResidualBlocks(
in_channels=in_channels,
out_channels=out_channels,
num_residual_blocks=num_residual_blocks,
downsample=downsample,
device=device,
dtype=dtype,
),
fl.SetContext(context="t2iadapter", key="features", callback=self.push),
)
def push(self, features: list[Tensor], x: Tensor) -> None:
features.append(x)
class ConditionEncoder(fl.Chain):
def __init__(
self,
in_channels: int = 3,
channels: tuple[int, int, int, int] = (320, 640, 1280, 1280),
num_residual_blocks: int = 2,
downscale_factor: int = 8,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.scale = scale
super().__init__(
fl.PixelUnshuffle(downscale_factor=downscale_factor),
fl.Conv2d(
in_channels=in_channels * downscale_factor**2,
out_channels=channels[0],
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
StatefulResidualBlocks(channels[0], channels[0], num_residual_blocks, device=device, dtype=dtype),
*(
StatefulResidualBlocks(
channels[i - 1], channels[i], num_residual_blocks, downsample=True, device=device, dtype=dtype
)
for i in range(1, len(channels))
),
fl.UseContext(context="t2iadapter", key="features"),
)
def init_context(self) -> Contexts:
return {"t2iadapter": {"features": []}}
class ConditionEncoderXL(ConditionEncoder, fl.Chain):
def __init__(
self,
in_channels: int = 3,
channels: tuple[int, int, int, int] = (320, 640, 1280, 1280),
num_residual_blocks: int = 2,
downscale_factor: int = 16,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.scale = scale
fl.Chain.__init__(
self,
fl.PixelUnshuffle(downscale_factor=downscale_factor),
fl.Conv2d(
in_channels=in_channels * downscale_factor**2,
out_channels=channels[0],
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
StatefulResidualBlocks(channels[0], channels[0], num_residual_blocks, device=device, dtype=dtype),
StatefulResidualBlocks(channels[0], channels[1], num_residual_blocks, device=device, dtype=dtype),
StatefulResidualBlocks(
channels[1], channels[2], num_residual_blocks, downsample=True, device=device, dtype=dtype
),
StatefulResidualBlocks(channels[2], channels[3], num_residual_blocks, device=device, dtype=dtype),
fl.UseContext(context="t2iadapter", key="features"),
)
class T2IFeatures(fl.Residual):
def __init__(self, name: str, index: int, scale: float = 1.0) -> None:
self.name = name
self.index = index
self.scale = scale
super().__init__(
fl.UseContext(context="t2iadapter", key=f"condition_features_{self.name}").compose(
func=lambda features: self.scale * features[self.index]
)
)
class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
_condition_encoder: list[ConditionEncoder] # prevent PyTorch module registration
_features: list[T2IFeatures] = []
def __init__(
self,
target: T,
name: str,
condition_encoder: ConditionEncoder,
weights: dict[str, Tensor] | None = None,
) -> None:
self.name = name
if weights is not None:
condition_encoder.load_state_dict(weights)
self._condition_encoder = [condition_encoder]
with self.setup_adapter(target):
super().__init__(target)
def inject(self: TT2IAdapter, parent: fl.Chain | None = None) -> TT2IAdapter:
return super().inject(parent)
def eject(self) -> None:
super().eject()
@property
def condition_encoder(self) -> ConditionEncoder:
return self._condition_encoder[0]
def compute_condition_features(self, condition: Tensor) -> tuple[Tensor, ...]:
return self.condition_encoder(condition)
def set_condition_features(self, features: tuple[Tensor, ...]) -> None:
self.set_context("t2iadapter", {f"condition_features_{self.name}": features})
def set_scale(self, scale: float) -> None:
for f in self._features:
f.scale = scale
def init_context(self) -> Contexts:
return {"t2iadapter": {f"condition_features_{self.name}": None}}
def structural_copy(self: "TT2IAdapter") -> "TT2IAdapter":
raise RuntimeError("T2I-Adapter cannot be copied, eject it first.")

@ -0,0 +1,368 @@
import torch
from torch import Tensor, device as Device, dtype as DType, nn
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.utils import pad
class PatchEncoder(fl.Chain):
def __init__(
self,
in_channels: int,
out_channels: int,
patch_size: int = 16,
use_bias: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_size = patch_size
self.use_bias = use_bias
super().__init__(
fl.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=(self.patch_size, self.patch_size),
stride=(self.patch_size, self.patch_size),
use_bias=self.use_bias,
device=device,
dtype=dtype,
),
fl.Permute(0, 2, 3, 1),
)
class PositionalEncoder(fl.Residual):
def __init__(
self,
embedding_dim: int,
image_embedding_size: tuple[int, int],
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.image_embedding_size = image_embedding_size
super().__init__(
fl.Parameter(
image_embedding_size[0],
image_embedding_size[1],
embedding_dim,
device=device,
dtype=dtype,
),
)
class RelativePositionAttention(fl.WeightedModule):
def __init__(
self,
embedding_dim: int,
num_heads: int,
spatial_size: tuple[int, int],
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.head_dim = embedding_dim // num_heads
self.spatial_size = spatial_size
self.horizontal_embedding = nn.Parameter(
data=torch.zeros(2 * spatial_size[0] - 1, self.head_dim, device=device, dtype=dtype)
)
self.vertical_embedding = nn.Parameter(
data=torch.zeros(2 * spatial_size[1] - 1, self.head_dim, device=device, dtype=dtype)
)
@property
def device(self) -> Device:
return self.horizontal_embedding.device
@property
def dtype(self) -> DType:
return self.horizontal_embedding.dtype
def forward(self, x: Tensor) -> Tensor:
batch, height, width, _ = x.shape
x = (
x.reshape(batch, width * height, 3, self.num_heads, -1)
.permute(2, 0, 3, 1, 4)
.reshape(3, batch * self.num_heads, width * height, -1)
)
query, key, value = x.unbind(dim=0)
horizontal_relative_embedding, vertical_relative_embedding = self.compute_relative_embedding(x=query)
attention = (query * self.head_dim**-0.5) @ key.transpose(dim0=-2, dim1=-1)
# Order of operations is important here
attention = (
(attention.reshape(-1, height, width, height, width) + vertical_relative_embedding)
+ horizontal_relative_embedding
).reshape(attention.shape)
attention = attention.softmax(dim=-1)
attention = attention @ value
attention = (
attention.reshape(batch, self.num_heads, height, width, -1)
.permute(0, 2, 3, 1, 4)
.reshape(batch, height, width, -1)
)
return attention
def compute_relative_coords(self, size: int) -> Tensor:
x, y = torch.meshgrid(torch.arange(end=size), torch.arange(end=size), indexing="ij")
return x - y + size - 1
def compute_relative_embedding(self, x: Tensor) -> tuple[Tensor, Tensor]:
width, height = self.spatial_size
horizontal_coords = self.compute_relative_coords(size=width)
vertical_coords = self.compute_relative_coords(size=height)
horizontal_positional_embedding = self.horizontal_embedding[horizontal_coords]
vertical_positional_embedding = self.vertical_embedding[vertical_coords]
x = x.reshape(x.shape[0], width, height, -1)
horizontal_relative_embedding = torch.einsum("bhwc,wkc->bhwk", x, horizontal_positional_embedding).unsqueeze(
dim=-2
)
vertical_relative_embedding = torch.einsum("bhwc,hkc->bhwk", x, vertical_positional_embedding).unsqueeze(dim=-1)
return horizontal_relative_embedding, vertical_relative_embedding
class FusedSelfAttention(fl.Chain):
def __init__(
self,
embedding_dim: int = 768,
spatial_size: tuple[int, int] = (64, 64),
num_heads: int = 1,
use_bias: bool = True,
is_causal: bool = False,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
assert (
embedding_dim % num_heads == 0
), f"Embedding dim (embedding_dim={embedding_dim}) must be divisible by num heads (num_heads={num_heads})"
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.use_bias = use_bias
self.is_causal = is_causal
super().__init__(
fl.Linear(
in_features=self.embedding_dim,
out_features=3 * self.embedding_dim,
bias=self.use_bias,
device=device,
dtype=dtype,
),
RelativePositionAttention(
embedding_dim=self.embedding_dim,
num_heads=self.num_heads,
spatial_size=spatial_size,
device=device,
dtype=dtype,
),
fl.Linear(
in_features=self.embedding_dim,
out_features=self.embedding_dim,
bias=True,
device=device,
dtype=dtype,
),
)
class FeedForward(fl.Chain):
def __init__(
self,
embedding_dim: int,
feedforward_dim: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.feedforward_dim = feedforward_dim
super().__init__(
fl.Linear(
in_features=self.embedding_dim,
out_features=self.feedforward_dim,
bias=True,
device=device,
dtype=dtype,
),
fl.GeLU(),
fl.Linear(
in_features=self.feedforward_dim,
out_features=self.embedding_dim,
bias=True,
device=device,
dtype=dtype,
),
)
class WindowPartition(fl.ContextModule):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
batch, height, width, channels = x.shape
context = self.use_context(context_name="window_partition")
context.update({"original_height": height, "original_width": width})
window_size = context["window_size"]
padding_height = (window_size - height % window_size) % window_size
padding_width = (window_size - width % window_size) % window_size
if padding_height > 0 or padding_width > 0:
x = pad(x=x, pad=(0, 0, 0, padding_width, 0, padding_height))
padded_height, padded_width = height + padding_height, width + padding_width
context.update({"padded_height": padded_height, "padded_width": padded_width})
x = x.view(batch, padded_height // window_size, window_size, padded_width // window_size, window_size, channels)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, channels)
return windows
class WindowMerge(fl.ContextModule):
def __init__(self) -> None:
super().__init__()
def forward(self, x: Tensor) -> Tensor:
context = self.use_context(context_name="window_partition")
window_size = context["window_size"]
padded_height, padded_width = context["padded_height"], context["padded_width"]
original_height, original_width = context["original_height"], context["original_width"]
batch_size = x.shape[0] // (padded_height * padded_width // window_size // window_size)
x = x.view(batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, padded_height, padded_width, -1)
if padded_height > original_height or padded_width > original_width:
x = x[:, :original_height, :original_width, :].contiguous()
return x
class TransformerLayer(fl.Chain):
def __init__(
self,
embedding_dim: int,
num_heads: int,
feedforward_dim: int,
image_embedding_size: tuple[int, int],
window_size: int | None = None,
layer_norm_eps: float = 1e-6,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.feedforward_dim = feedforward_dim
self.window_size = window_size
self.layer_norm_eps = layer_norm_eps
self.image_embedding_size = image_embedding_size
attention_spatial_size = (window_size, window_size) if window_size is not None else image_embedding_size
reshape_or_merge = (
WindowMerge()
if self.window_size is not None
else fl.Reshape(self.image_embedding_size[0], self.image_embedding_size[1], embedding_dim)
)
super().__init__(
fl.Residual(
fl.LayerNorm(normalized_shape=embedding_dim, eps=self.layer_norm_eps, device=device, dtype=dtype),
WindowPartition() if self.window_size is not None else fl.Identity(),
FusedSelfAttention(
embedding_dim=embedding_dim,
num_heads=num_heads,
spatial_size=attention_spatial_size,
device=device,
dtype=dtype,
),
reshape_or_merge,
),
fl.Residual(
fl.LayerNorm(normalized_shape=embedding_dim, eps=self.layer_norm_eps, device=device, dtype=dtype),
FeedForward(embedding_dim=embedding_dim, feedforward_dim=feedforward_dim, device=device, dtype=dtype),
),
)
def init_context(self) -> Contexts:
return {"window_partition": {"window_size": self.window_size}}
class Neck(fl.Chain):
def __init__(self, in_channels: int = 768, device: Device | str | None = None, dtype: DType | None = None) -> None:
self.in_channels = in_channels
super().__init__(
fl.Permute(0, 3, 1, 2),
fl.Conv2d(
in_channels=self.in_channels,
out_channels=256,
kernel_size=1,
use_bias=False,
device=device,
dtype=dtype,
),
fl.LayerNorm2d(channels=256, device=device, dtype=dtype),
fl.Conv2d(
in_channels=256,
out_channels=256,
kernel_size=3,
padding=1,
use_bias=False,
device=device,
dtype=dtype,
),
fl.LayerNorm2d(channels=256, device=device, dtype=dtype),
)
class Transformer(fl.Chain):
pass
class SAMViT(fl.Chain):
def __init__(
self,
embedding_dim: int,
num_layers: int,
num_heads: int,
global_attention_indices: tuple[int, ...] | None = None,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_layers = num_layers
self.num_heads = num_heads
self.image_size = (1024, 1024)
self.patch_size = 16
self.window_size = 14
self.image_embedding_size = (self.image_size[0] // self.patch_size, self.image_size[1] // self.patch_size)
self.feed_forward_dim = 4 * self.embedding_dim
self.global_attention_indices = global_attention_indices or tuple()
super().__init__(
PatchEncoder(
in_channels=3, out_channels=embedding_dim, patch_size=self.patch_size, device=device, dtype=dtype
),
PositionalEncoder(
embedding_dim=embedding_dim, image_embedding_size=self.image_embedding_size, device=device, dtype=dtype
),
Transformer(
TransformerLayer(
embedding_dim=embedding_dim,
num_heads=num_heads,
feedforward_dim=self.feed_forward_dim,
window_size=self.window_size if i not in self.global_attention_indices else None,
image_embedding_size=self.image_embedding_size,
device=device,
dtype=dtype,
)
for i in range(num_layers)
),
Neck(in_channels=embedding_dim, device=device, dtype=dtype),
)
class SAMViTH(SAMViT):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
embedding_dim=1280,
num_layers=32,
num_heads=16,
global_attention_indices=(7, 15, 23, 31),
device=device,
dtype=dtype,
)

@ -0,0 +1,264 @@
import torch
from torch import Tensor, device as Device, dtype as DType, nn
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.foundationals.segment_anything.transformer import (
SparseCrossDenseAttention,
TwoWayTranformerLayer,
)
class EmbeddingsAggregator(fl.ContextModule):
def __init__(self, num_output_mask: int = 3) -> None:
super().__init__()
self.num_mask_tokens = num_output_mask
def forward(self, iou_mask_tokens: Tensor) -> Tensor:
mask_decoder = self.ensure_parent
mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder")
image_embedding = mask_decoder_context["image_embedding"]
point_embedding = mask_decoder_context["point_embedding"]
mask_embedding = mask_decoder_context["mask_embedding"]
dense_positional_embedding = mask_decoder_context["dense_positional_embedding"]
sparse_embedding = torch.cat(tensors=(iou_mask_tokens, point_embedding), dim=1)
dense_embedding = (image_embedding + mask_embedding).flatten(start_dim=2).transpose(1, 2)
if dense_positional_embedding.shape != dense_embedding.shape:
dense_positional_embedding = dense_positional_embedding.flatten(start_dim=2).transpose(1, 2)
mask_decoder_context.update(
{
"dense_embedding": dense_embedding,
"dense_positional_embedding": dense_positional_embedding,
"sparse_embedding": sparse_embedding,
}
)
mask_decoder.set_context(context="mask_decoder", value=mask_decoder_context)
return sparse_embedding
class Transformer(fl.Chain):
pass
class Hypernetworks(fl.Concatenate):
def __init__(
self,
embedding_dim: int = 256,
num_layers: int = 3,
num_mask_tokens: int = 3,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.num_layers = num_layers
self.num_mask_tokens = num_mask_tokens
super().__init__(
*[
fl.Chain(
fl.Slicing(dim=1, start=i + 1, end=i + 2),
fl.MultiLinear(
input_dim=embedding_dim,
output_dim=embedding_dim // 8,
inner_dim=embedding_dim,
num_layers=num_layers,
device=device,
dtype=dtype,
),
)
for i in range(num_mask_tokens + 1)
],
dim=1,
)
class DenseEmbeddingUpscaling(fl.Chain):
def __init__(
self,
embedding_dim: int = 256,
dense_embedding_side_dim: int = 64,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.dense_embedding_side_dim = dense_embedding_side_dim
super().__init__(
fl.UseContext(context="mask_decoder", key="dense_embedding"),
fl.Transpose(dim0=1, dim1=2),
fl.Reshape(embedding_dim, dense_embedding_side_dim, dense_embedding_side_dim),
fl.ConvTranspose2d(
in_channels=embedding_dim,
out_channels=embedding_dim // 4,
kernel_size=2,
stride=2,
device=device,
dtype=dtype,
),
fl.LayerNorm2d(channels=embedding_dim // 4, device=device, dtype=dtype),
fl.GeLU(),
fl.ConvTranspose2d(
in_channels=embedding_dim // 4,
out_channels=embedding_dim // 8,
kernel_size=2,
stride=2,
device=device,
dtype=dtype,
),
fl.GeLU(),
fl.Flatten(start_dim=2),
)
class IOUMaskEncoder(fl.WeightedModule):
def __init__(
self,
embedding_dim: int = 256,
num_mask_tokens: int = 4,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.num_mask_tokens = num_mask_tokens
# aka prompt tokens + output token (for IoU scores prediction)
self.weight = nn.Parameter(data=torch.randn(num_mask_tokens + 1, embedding_dim, device=device, dtype=dtype))
def forward(self) -> Tensor:
return self.weight.unsqueeze(dim=0)
class MaskPrediction(fl.Chain):
def __init__(
self,
embedding_dim: int,
num_mask_tokens: int,
num_layers: int = 3,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_mask_tokens = num_mask_tokens
self.num_layers = num_layers
super().__init__(
fl.Matmul(
input=Hypernetworks(
embedding_dim=embedding_dim,
num_layers=num_layers,
num_mask_tokens=num_mask_tokens,
device=device,
dtype=dtype,
),
other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
),
fl.Slicing(dim=1, start=1),
fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
)
class IOUPrediction(fl.Chain):
def __init__(
self,
embedding_dim: int,
num_layers: int,
num_mask_tokens: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_layers = num_layers
super().__init__(
fl.Slicing(dim=1, start=0, end=1),
fl.Squeeze(dim=0),
fl.MultiLinear(
input_dim=embedding_dim,
output_dim=num_mask_tokens + 1,
inner_dim=embedding_dim,
num_layers=num_layers,
device=device,
dtype=dtype,
),
fl.Slicing(dim=-1, start=1),
)
class MaskDecoder(fl.Chain):
def __init__(
self,
embedding_dim: int = 256,
feed_forward_dim: int = 2048,
num_layers: int = 2,
num_output_mask: int = 3,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.num_mask_tokens = num_output_mask
self.feed_forward_dim = feed_forward_dim
self.num_layers = num_layers
super().__init__(
IOUMaskEncoder(
embedding_dim=embedding_dim, num_mask_tokens=num_output_mask + 1, device=device, dtype=dtype
),
EmbeddingsAggregator(num_output_mask=num_output_mask),
Transformer(
*(
TwoWayTranformerLayer(
embedding_dim=embedding_dim,
num_heads=8,
feed_forward_dim=feed_forward_dim,
use_residual_self_attention=i > 0,
device=device,
dtype=dtype,
)
for i in range(num_layers)
),
SparseCrossDenseAttention(embedding_dim=embedding_dim, device=device, dtype=dtype),
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
),
fl.Parallel(
MaskPrediction(
embedding_dim=embedding_dim, num_mask_tokens=num_output_mask, device=device, dtype=dtype
),
IOUPrediction(
embedding_dim=embedding_dim,
num_layers=3,
num_mask_tokens=num_output_mask,
device=device,
dtype=dtype,
),
),
)
def init_context(self) -> Contexts:
return {
"mask_decoder": {
"image_embedding": None,
"point_embedding": None,
"mask_embedding": None,
"dense_positional_embedding": None,
}
}
def set_image_embedding(self, image_embedding: Tensor) -> None:
mask_decoder_context = self.use_context(context_name="mask_decoder")
mask_decoder_context["image_embedding"] = image_embedding
def set_point_embedding(self, point_embedding: Tensor) -> None:
mask_decoder_context = self.use_context(context_name="mask_decoder")
mask_decoder_context["point_embedding"] = point_embedding
def set_mask_embedding(self, mask_embedding: Tensor) -> None:
mask_decoder_context = self.use_context(context_name="mask_decoder")
mask_decoder_context["mask_embedding"] = mask_embedding
def set_dense_positional_embedding(self, dense_positional_embedding: Tensor) -> None:
mask_decoder_context = self.use_context(context_name="mask_decoder")
mask_decoder_context["dense_positional_embedding"] = dense_positional_embedding

@ -0,0 +1,170 @@
from dataclasses import dataclass
from typing import Sequence
import numpy as np
import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad
from imaginairy.vendored.refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
from imaginairy.vendored.refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from imaginairy.vendored.refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
@dataclass
class ImageEmbedding:
features: Tensor
original_image_size: tuple[int, int] # (height, width)
class SegmentAnything(fl.Module):
mask_threshold: float = 0.0
def __init__(
self,
image_encoder: SAMViT,
point_encoder: PointEncoder,
mask_encoder: MaskEncoder,
mask_decoder: MaskDecoder,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
super().__init__()
self.device: Device = device if isinstance(device, Device) else Device(device=device)
self.dtype = dtype
self.image_encoder = image_encoder.to(device=self.device, dtype=self.dtype)
self.point_encoder = point_encoder.to(device=self.device, dtype=self.dtype)
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
@no_grad()
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
original_size = (image.height, image.width)
target_size = self.compute_target_size(original_size)
return ImageEmbedding(
features=self.image_encoder(self.preprocess_image(image=image, target_size=target_size)),
original_image_size=original_size,
)
@no_grad()
def predict(
self,
input: Image.Image | ImageEmbedding,
foreground_points: Sequence[tuple[float, float]] | None = None,
background_points: Sequence[tuple[float, float]] | None = None,
box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
masks: Sequence[Image.Image] | None = None,
binarize: bool = True,
) -> tuple[Tensor, Tensor, Tensor]:
if isinstance(input, ImageEmbedding):
original_size = input.original_image_size
target_size = self.compute_target_size(original_size)
image_embedding = input.features
else:
original_size = (input.height, input.width)
target_size = self.compute_target_size(original_size)
image_embedding = self.image_encoder(self.preprocess_image(image=input, target_size=target_size))
coordinates, type_mask = self.point_encoder.points_to_tensor(
foreground_points=foreground_points,
background_points=background_points,
box_points=box_points,
)
self.point_encoder.set_type_mask(type_mask=type_mask)
if masks is not None:
mask_tensor = torch.stack(
tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks]
)
mask_embedding = self.mask_encoder(mask_tensor)
else:
mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(
image_embedding_size=self.image_encoder.image_embedding_size
)
point_embedding = self.point_encoder(
self.normalize(coordinates, target_size=target_size, original_size=original_size)
)
dense_positional_embedding = self.point_encoder.get_dense_positional_embedding(
image_embedding_size=self.image_encoder.image_embedding_size
)
self.mask_decoder.set_image_embedding(image_embedding=image_embedding)
self.mask_decoder.set_mask_embedding(mask_embedding=mask_embedding)
self.mask_decoder.set_point_embedding(point_embedding=point_embedding)
self.mask_decoder.set_dense_positional_embedding(dense_positional_embedding=dense_positional_embedding)
low_res_masks, iou_predictions = self.mask_decoder()
high_res_masks = self.postprocess_masks(
masks=low_res_masks, target_size=target_size, original_size=original_size
)
if binarize:
high_res_masks = high_res_masks > self.mask_threshold
return high_res_masks, iou_predictions, low_res_masks
@property
def image_size(self) -> int:
w, h = self.image_encoder.image_size
assert w == h
return w
def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
oldh, oldw = size
scale = self.image_size * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return (newh, neww)
def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
h, w = target_size
padh = self.image_size - h
padw = self.image_size - w
image_tensor = torch.tensor(
np.array(image.resize((w, h), resample=Image.Resampling.BILINEAR)).astype(np.float32).transpose(2, 0, 1),
device=self.device,
dtype=self.dtype,
).unsqueeze(0)
return pad(
normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]), (0, padw, 0, padh)
)
def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size
coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size
return coordinates
def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
masks = interpolate(masks, factor=torch.Size((self.image_size, self.image_size)), mode="bilinear")
masks = masks[..., : target_size[0], : target_size[1]] # remove padding added at `preprocess_image` time
masks = interpolate(masks, factor=torch.Size(original_size), mode="bilinear")
return masks
class SegmentAnythingH(SegmentAnything):
def __init__(
self,
image_encoder: SAMViTH | None = None,
point_encoder: PointEncoder | None = None,
mask_encoder: MaskEncoder | None = None,
mask_decoder: MaskDecoder | None = None,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
image_encoder = image_encoder or SAMViTH()
point_encoder = point_encoder or PointEncoder()
mask_encoder = mask_encoder or MaskEncoder()
mask_decoder = mask_decoder or MaskDecoder()
super().__init__(
image_encoder=image_encoder,
point_encoder=point_encoder,
mask_encoder=mask_encoder,
mask_decoder=mask_decoder,
device=device,
dtype=dtype,
)
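
A minimal segmentation sketch (not part of the diff), assuming this hunk lands at `segment_anything/model.py`; the model is randomly initialized here, whereas real use loads converted SAM ViT-H weights:

from PIL import Image
from imaginairy.vendored.refiners.foundationals.segment_anything.model import SegmentAnythingH

sam = SegmentAnythingH()
image = Image.new("RGB", (640, 480))
embedding = sam.compute_image_embedding(image)  # reusable across prompts for the same image
high_res_masks, iou_predictions, low_res_masks = sam.predict(
    input=embedding,
    foreground_points=[(320.0, 240.0)],         # (x, y) in original image coordinates
)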

@ -0,0 +1,192 @@
from collections.abc import Sequence
from enum import Enum, auto
import torch
from jaxtyping import Float, Int
from torch import Tensor, device as Device, dtype as DType, nn
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.context import Contexts
class CoordinateEncoder(fl.Chain):
def __init__(
self,
num_positional_features: int = 64,
scale: float = 1,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.num_positional_features = num_positional_features
self.scale = scale
super().__init__(
fl.Multiply(scale=2, bias=-1),
fl.Linear(in_features=2, out_features=num_positional_features, bias=False, device=device, dtype=dtype),
fl.Multiply(scale=2 * torch.pi * self.scale),
fl.Concatenate(fl.Sin(), fl.Cos(), dim=-1),
)
class PointType(Enum):
BACKGROUND = auto()
FOREGROUND = auto()
BOX_TOP_LEFT = auto()
BOX_BOTTOM_RIGHT = auto()
NOT_A_POINT = auto()
class PointTypeEmbedding(fl.WeightedModule, fl.ContextModule):
def __init__(self, embedding_dim: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.weight = nn.Parameter(data=torch.randn(len(PointType), self.embedding_dim, device=device, dtype=dtype))
def forward(self, type_mask: Int[Tensor, "1 num_points"]) -> Float[Tensor, "1 num_points embedding_dim"]:
assert isinstance(type_mask, Tensor), "type_mask must be a Tensor."
embeddings = torch.zeros(*type_mask.shape, self.embedding_dim).to(device=type_mask.device)
for type_id in PointType:
mask = type_mask == type_id.value
embeddings[mask] = self.weight[type_id.value - 1]
return embeddings
class PointEncoder(fl.Chain):
def __init__(
self, embedding_dim: int = 256, scale: float = 1, device: Device | str | None = None, dtype: DType | None = None
) -> None:
assert embedding_dim % 2 == 0, "embedding_dim must be divisible by 2."
self.embedding_dim = embedding_dim
self.scale = scale
super().__init__(
CoordinateEncoder(num_positional_features=embedding_dim // 2, scale=scale, device=device, dtype=dtype),
fl.Lambda(func=self.pad),
fl.Residual(
fl.UseContext(context="point_encoder", key="type_mask"),
PointTypeEmbedding(embedding_dim=embedding_dim, device=device, dtype=dtype),
),
)
def pad(self, x: Tensor) -> Tensor:
type_mask: Tensor = self.use_context("point_encoder")["type_mask"]
if torch.any((type_mask == PointType.BOX_TOP_LEFT.value) | (type_mask == PointType.BOX_BOTTOM_RIGHT.value)):
# Some boxes have been passed: no need to pad in this case
return x
type_mask = torch.cat(
[type_mask, torch.full((type_mask.shape[0], 1), PointType.NOT_A_POINT.value, device=type_mask.device)],
dim=1,
)
self.set_context(context="point_encoder", value={"type_mask": type_mask})
return torch.cat([x, torch.zeros((x.shape[0], 1, x.shape[-1]), device=x.device)], dim=1)
def init_context(self) -> Contexts:
return {
"point_encoder": {
"type_mask": None,
}
}
def set_type_mask(self, type_mask: Int[Tensor, "1 num_points"]) -> None:
self.set_context(context="point_encoder", value={"type_mask": type_mask})
def get_dense_positional_embedding(
self, image_embedding_size: tuple[int, int]
) -> Float[Tensor, "num_positional_features height width"]:
coordinate_encoder = self.ensure_find(layer_type=CoordinateEncoder)
height, width = image_embedding_size
grid = torch.ones((height, width), device=self.device, dtype=torch.float32)
y_embedding = grid.cumsum(dim=0) - 0.5
x_embedding = grid.cumsum(dim=1) - 0.5
y_embedding = y_embedding / height
x_embedding = x_embedding / width
positional_embedding = (
coordinate_encoder(torch.stack(tensors=[x_embedding, y_embedding], dim=-1))
.permute(2, 0, 1)
.unsqueeze(dim=0)
)
return positional_embedding
def points_to_tensor(
self,
foreground_points: Sequence[tuple[float, float]] | None = None,
background_points: Sequence[tuple[float, float]] | None = None,
not_a_points: Sequence[tuple[float, float]] | None = None,
box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
) -> tuple[Float[Tensor, "1 num_points 2"], Int[Tensor, "1 num_points"]]:
foreground_points = foreground_points or []
background_points = background_points or []
not_a_points = not_a_points or []
box_points = box_points or []
top_left_points = [box[0] for box in box_points]
bottom_right_points = [box[1] for box in box_points]
coordinates: list[Tensor] = []
type_ids: list[Tensor] = []
# Must be in sync with PointType enum
for type_id, coords_seq in zip(
PointType, [background_points, foreground_points, top_left_points, bottom_right_points, not_a_points]
):
if len(coords_seq) > 0:
coords_tensor = torch.tensor(data=list(coords_seq), dtype=torch.float, device=self.device)
coordinates.append(coords_tensor)
point_ids = torch.tensor(data=[type_id.value] * len(coords_seq), dtype=torch.int, device=self.device)
type_ids.append(point_ids)
all_coordinates = torch.cat(tensors=coordinates, dim=0).unsqueeze(dim=0)
type_mask = torch.cat(tensors=type_ids, dim=0).unsqueeze(dim=0)
return all_coordinates, type_mask
class MaskEncoder(fl.Chain):
def __init__(
self,
embedding_dim: int = 256,
intermediate_channels: int = 16,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.intermediate_channels = intermediate_channels
super().__init__(
fl.Conv2d(
in_channels=1,
out_channels=self.intermediate_channels // 4,
kernel_size=2,
stride=2,
device=device,
dtype=dtype,
),
fl.LayerNorm2d(channels=self.intermediate_channels // 4, device=device, dtype=dtype),
fl.GeLU(),
fl.Conv2d(
in_channels=self.intermediate_channels // 4,
out_channels=self.intermediate_channels,
kernel_size=2,
stride=2,
device=device,
dtype=dtype,
),
fl.LayerNorm2d(channels=self.intermediate_channels, device=device, dtype=dtype),
fl.GeLU(),
fl.Conv2d(
in_channels=self.intermediate_channels,
out_channels=self.embedding_dim,
kernel_size=1,
device=device,
dtype=dtype,
),
)
self.register_parameter(
"no_mask_embedding", nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
)
def get_no_mask_dense_embedding(
self, image_embedding_size: tuple[int, int], batch_size: int = 1
) -> Float[Tensor, "batch embedding_dim image_embedding_height image_embedding_width"]:
return self.no_mask_embedding.reshape(1, -1, 1, 1).expand(
batch_size, -1, image_embedding_size[0], image_embedding_size[1]
)
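
For reference, how prompt points are packed before encoding, per `points_to_tensor` above (illustrative only, not part of the diff; the final division is a crude stand-in for the coordinate normalization done in `SegmentAnything.normalize`):

from imaginairy.vendored.refiners.foundationals.segment_anything.prompt_encoder import PointEncoder

point_encoder = PointEncoder()
coordinates, type_mask = point_encoder.points_to_tensor(
    foreground_points=[(320.0, 240.0)],
    box_points=[[(10.0, 10.0), (200.0, 150.0)]],
)
point_encoder.set_type_mask(type_mask=type_mask)
point_embedding = point_encoder(coordinates / 1024.0)  # normally normalized by SegmentAnything.predict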

@ -0,0 +1,158 @@
from torch import device as Device, dtype as DType
import imaginairy.vendored.refiners.fluxion.layers as fl
class CrossAttention(fl.Attention):
def __init__(
self,
embedding_dim: int,
cross_embedding_dim: int | None = None,
num_heads: int = 1,
inner_dim: int | None = None,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
embedding_dim=embedding_dim,
key_embedding_dim=cross_embedding_dim,
num_heads=num_heads,
inner_dim=inner_dim,
is_optimized=False,
device=device,
dtype=dtype,
)
self.cross_embedding_dim = cross_embedding_dim or embedding_dim
self.insert(index=0, module=fl.Parallel(fl.GetArg(index=0), fl.GetArg(index=1), fl.GetArg(index=1)))
class FeedForward(fl.Residual):
def __init__(
self, embedding_dim: int, feed_forward_dim: int, device: Device | str | None = None, dtype: DType | None = None
) -> None:
self.embedding_dim = embedding_dim
self.feed_forward_dim = feed_forward_dim
super().__init__(
fl.Linear(in_features=embedding_dim, out_features=feed_forward_dim, device=device, dtype=dtype),
fl.ReLU(),
fl.Linear(in_features=feed_forward_dim, out_features=embedding_dim, device=device, dtype=dtype),
)
class SparseSelfAttention(fl.Residual):
def __init__(
self,
embedding_dim: int,
inner_dim: int | None = None,
num_heads: int = 1,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
add_sparse_embedding = fl.Residual(fl.UseContext(context="mask_decoder", key="sparse_embedding"))
super().__init__(
fl.Parallel(add_sparse_embedding, add_sparse_embedding, fl.Identity()),
fl.Attention(
embedding_dim=embedding_dim,
inner_dim=inner_dim,
num_heads=num_heads,
is_optimized=False,
device=device,
dtype=dtype,
),
)
class SparseCrossDenseAttention(fl.Residual):
def __init__(
self, embedding_dim: int, num_heads: int = 8, device: Device | str | None = None, dtype: DType | None = None
) -> None:
self.embedding_dim = embedding_dim
self.num_heads = num_heads
super().__init__(
fl.Parallel(
fl.Residual(
fl.UseContext(context="mask_decoder", key="sparse_embedding"),
),
fl.Sum(
fl.UseContext(context="mask_decoder", key="dense_embedding"),
fl.UseContext(context="mask_decoder", key="dense_positional_embedding"),
),
fl.UseContext(context="mask_decoder", key="dense_embedding"),
),
fl.Attention(
embedding_dim=embedding_dim,
inner_dim=embedding_dim // 2,
num_heads=num_heads,
is_optimized=False,
device=device,
dtype=dtype,
),
)
class DenseCrossSparseAttention(fl.Chain):
def __init__(
self, embedding_dim: int, num_heads: int = 8, device: Device | str | None = None, dtype: DType | None = None
) -> None:
super().__init__(
fl.Parallel(
fl.Sum(
fl.UseContext(context="mask_decoder", key="dense_embedding"),
fl.UseContext(context="mask_decoder", key="dense_positional_embedding"),
),
fl.Residual(
fl.UseContext(context="mask_decoder", key="sparse_embedding"),
),
fl.Identity(),
),
fl.Attention(
embedding_dim=embedding_dim,
inner_dim=embedding_dim // 2,
num_heads=num_heads,
is_optimized=False,
device=device,
dtype=dtype,
),
)
class TwoWayTranformerLayer(fl.Chain):
def __init__(
self,
embedding_dim: int,
num_heads: int = 8,
feed_forward_dim: int = 2048,
use_residual_self_attention: bool = True,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.feed_forward_dim = feed_forward_dim
self_attention = (
SparseSelfAttention(embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype)
if use_residual_self_attention
else fl.SelfAttention(
embedding_dim=embedding_dim, num_heads=num_heads, is_optimized=False, device=device, dtype=dtype
)
)
super().__init__(
self_attention,
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
SparseCrossDenseAttention(embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype),
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
FeedForward(embedding_dim=embedding_dim, feed_forward_dim=feed_forward_dim, device=device, dtype=dtype),
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
fl.Passthrough(
fl.Sum(
fl.UseContext(context="mask_decoder", key="dense_embedding"),
DenseCrossSparseAttention(
embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype
),
),
fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
fl.SetContext(context="mask_decoder", key="dense_embedding"),
),
)

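A shape-level sketch (editorial, written in plain PyTorch rather than the fluxion API) of the two attention directions these layers wire up: sparse point/box tokens querying the dense image embedding, and the dense embedding querying the sparse tokens back. Multi-head projections, the inner_dim reduction, and the positional/sparse-embedding sums handled by fl.Sum / fl.Residual above are omitted, and the token counts and sizes are assumptions:

import torch
import torch.nn.functional as F

b, n_sparse, dim = 1, 6, 256             # batch, sparse tokens, embedding dim (assumed)
h, w = 64, 64                            # assumed image-embedding grid
sparse = torch.randn(b, n_sparse, dim)   # point/box prompt tokens
dense = torch.randn(b, h * w, dim)       # flattened image embedding

# sparse tokens attend to the dense embedding (SparseCrossDenseAttention direction)
sparse_out = F.scaled_dot_product_attention(sparse, dense, dense)
# dense embedding attends to the sparse tokens (DenseCrossSparseAttention direction)
dense_out = F.scaled_dot_product_attention(dense, sparse, sparse)
assert sparse_out.shape == (b, n_sparse, dim)
assert dense_out.shape == (b, h * w, dim)
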
@ -0,0 +1 @@
vendored from git@github.com:finegrain-ai/refiners.git @ 20c229903f53d05dc1c44659ec97603660ef964c

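With this marker file in place, downstream modules import the vendored copy rather than the refiners distribution, e.g.:

import imaginairy.vendored.refiners.fluxion.layers as fl

which is why the pinned refiners entries disappear from the compiled requirements below.
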
@ -94,7 +94,7 @@ importlib-metadata==7.0.1
iniconfig==2.0.0
# via pytest
jaxtyping==0.2.25
# via refiners
# via imaginAIry (setup.py)
jinja2==3.1.2
# via torch
kiwisolver==1.4.5
@ -133,7 +133,6 @@ numpy==1.24.4
# matplotlib
# numba
# opencv-python
# refiners
# scipy
# torchvision
# transformers
@ -160,7 +159,6 @@ pillow==10.2.0
# imageio
# imaginAIry (setup.py)
# matplotlib
# refiners
# torchvision
pluggy==1.3.0
# via pytest
@ -199,8 +197,6 @@ pyyaml==6.0.1
# responses
# timm
# transformers
refiners==0.2.0
# via imaginAIry (setup.py)
regex==2023.12.25
# via
# diffusers
@ -216,13 +212,12 @@ requests==2.31.0
# transformers
responses==0.24.1
# via -r requirements-dev.in
ruff==0.1.9
ruff==0.1.11
# via -r requirements-dev.in
safetensors==0.3.3
safetensors==0.4.1
# via
# diffusers
# imaginAIry (setup.py)
# refiners
# timm
# transformers
scipy==1.10.1
@ -262,7 +257,6 @@ torch==2.1.2
# imaginAIry (setup.py)
# kornia
# open-clip-torch
# refiners
# timm
# torchdiffeq
# torchvision

@ -95,9 +95,10 @@ setup(
# need to migration to 2.0
"pydantic>=2.3.0",
"requests>=2.28.1",
"refiners>=0.2.0",
# "refiners>=0.2.0",
"jaxtyping>=0.2.23", # refiners dependency
"einops>=0.3.0",
"safetensors>=0.2.1",
"safetensors>=0.4.0",
# scipy is a sub dependency but v1.11 doesn't support python 3.8. https://docs.scipy.org/doc/scipy/dev/toolchain.html#numpy
"scipy<1.11",
"timm>=0.4.12,!=0.9.0,!=0.9.1", # for vendored blip
