diff --git a/Makefile b/Makefile index e4ce240..b56d193 100644 --- a/Makefile +++ b/Makefile @@ -210,7 +210,7 @@ vendorize_normal_map: vendorize_refiners: - export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=20c229903f53d05dc1c44659ec97603660ef964c && \ + export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=ce3035923ba71bcb5044708d2f1c37fd1d6722e9 && \ make download_repo REPO=$$REPO PKG=$$PKG COMMIT=$$COMMIT && \ mkdir -p ./imaginairy/vendored/$$PKG && \ rm -rf ./imaginairy/vendored/$$PKG/* && \ diff --git a/imaginairy/vendored/refiners/fluxion/adapters/lora.py b/imaginairy/vendored/refiners/fluxion/adapters/lora.py index 3dd4c0b..4ee5d3c 100644 --- a/imaginairy/vendored/refiners/fluxion/adapters/lora.py +++ b/imaginairy/vendored/refiners/fluxion/adapters/lora.py @@ -1,4 +1,5 @@ -from typing import Any, Generic, Iterable, TypeVar +from abc import ABC, abstractmethod +from typing import Any from torch import Tensor, device as Device, dtype as DType from torch.nn import Parameter as TorchParameter @@ -6,125 +7,259 @@ from torch.nn.init import normal_, zeros_ import imaginairy.vendored.refiners.fluxion.layers as fl from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter +from imaginairy.vendored.refiners.fluxion.layers.chain import Chain -T = TypeVar("T", bound=fl.Chain) -TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PEP 673) - -class Lora(fl.Chain): +class Lora(fl.Chain, ABC): def __init__( self, - in_features: int, - out_features: int, rank: int = 16, + scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None, ) -> None: - self.in_features = in_features - self.out_features = out_features self.rank = rank - self.scale: float = 1.0 + self._scale = scale - super().__init__( - fl.Linear(in_features=in_features, out_features=rank, bias=False, device=device, dtype=dtype), - fl.Linear(in_features=rank, out_features=out_features, bias=False, device=device, dtype=dtype), - fl.Lambda(func=self.scale_outputs), - ) + super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale)) - normal_(tensor=self.Linear_1.weight, std=1 / self.rank) - zeros_(tensor=self.Linear_2.weight) + normal_(tensor=self.down.weight, std=1 / self.rank) + zeros_(tensor=self.up.weight) - def scale_outputs(self, x: Tensor) -> Tensor: - return x * self.scale + @abstractmethod + def lora_layers( + self, device: Device | str | None = None, dtype: DType | None = None + ) -> tuple[fl.WeightedModule, fl.WeightedModule]: + ... 
- def set_scale(self, scale: float) -> None: - self.scale = scale - - def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None: - self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype)) - self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype)) + @property + def down(self) -> fl.WeightedModule: + down_layer = self[0] + assert isinstance(down_layer, fl.WeightedModule) + return down_layer @property - def up_weight(self) -> Tensor: - return self.Linear_2.weight.data + def up(self) -> fl.WeightedModule: + up_layer = self[1] + assert isinstance(up_layer, fl.WeightedModule) + return up_layer @property - def down_weight(self) -> Tensor: - return self.Linear_1.weight.data + def scale(self) -> float: + return self._scale + + @scale.setter + def scale(self, value: float) -> None: + self._scale = value + self.ensure_find(fl.Multiply).scale = value + + @classmethod + def from_weights( + cls, + down: Tensor, + up: Tensor, + ) -> "Lora": + match (up.ndim, down.ndim): + case (2, 2): + return LinearLora.from_weights(up=up, down=down) + case (4, 4): + return Conv2dLora.from_weights(up=up, down=down) + case _: + raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}") + + @classmethod + def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]: + """ + Create a dictionary of LoRA layers from a state dict. + Expects the state dict to be a succession of down and up weights. + """ + state_dict = {k: v for k, v in state_dict.items() if ".weight" in k} + loras: dict[str, Lora] = {} + for down_key, down_tensor, up_tensor in zip( + list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2] + ): + key = ".".join(down_key.split(".")[:-2]) + loras[key] = cls.from_weights(down=down_tensor, up=up_tensor) + return loras -class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]): + @abstractmethod + def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any: + ... 
+ + def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None: + assert down_weight.shape == self.down.weight.shape + assert up_weight.shape == self.up.weight.shape + self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype)) + self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype)) + + +class LinearLora(Lora): def __init__( self, - target: fl.Linear, + in_features: int, + out_features: int, rank: int = 16, scale: float = 1.0, + device: Device | str | None = None, + dtype: DType | None = None, ) -> None: - self.in_features = target.in_features - self.out_features = target.out_features - self.rank = rank - self.scale = scale - with self.setup_adapter(target): - super().__init__( - target, - Lora( - in_features=target.in_features, - out_features=target.out_features, - rank=rank, - device=target.device, - dtype=target.dtype, - ), - ) - self.Lora.set_scale(scale=scale) - - -class LoraAdapter(Generic[T], fl.Chain, Adapter[T]): + self.in_features = in_features + self.out_features = out_features + + super().__init__(rank=rank, scale=scale, device=device, dtype=dtype) + + @classmethod + def from_weights( + cls, + down: Tensor, + up: Tensor, + ) -> "LinearLora": + assert up.ndim == 2 and down.ndim == 2 + assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}" + lora = cls( + in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype + ) + lora.load_weights(down_weight=down, up_weight=up) + return lora + + def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None": + for layer, parent in target.walk(fl.Linear): + if isinstance(parent, Lora) or isinstance(parent, LoraAdapter): + continue + + if exclude is not None and any( + [any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude] + ): + continue + + if layer.in_features == self.in_features and layer.out_features == self.out_features: + return LoraAdapter(target=layer, lora=self), parent + + def lora_layers( + self, device: Device | str | None = None, dtype: DType | None = None + ) -> tuple[fl.Linear, fl.Linear]: + return ( + fl.Linear( + in_features=self.in_features, + out_features=self.rank, + bias=False, + device=device, + dtype=dtype, + ), + fl.Linear( + in_features=self.rank, + out_features=self.out_features, + bias=False, + device=device, + dtype=dtype, + ), + ) + + +class Conv2dLora(Lora): def __init__( self, - target: T, - sub_targets: Iterable[tuple[fl.Linear, fl.Chain]], - rank: int | None = None, + in_channels: int, + out_channels: int, + rank: int = 16, scale: float = 1.0, - weights: list[Tensor] | None = None, + kernel_size: tuple[int, int] = (1, 3), + stride: tuple[int, int] = (1, 1), + padding: tuple[int, int] = (0, 1), + device: Device | str | None = None, + dtype: DType | None = None, ) -> None: + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + + super().__init__(rank=rank, scale=scale, device=device, dtype=dtype) + + @classmethod + def from_weights( + cls, + down: Tensor, + up: Tensor, + ) -> "Conv2dLora": + assert up.ndim == 4 and down.ndim == 4 + assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}" + down_kernel_size, up_kernel_size = down.shape[2], up.shape[2] + down_padding = 1 if down_kernel_size == 3 else 0 + up_padding = 
1 if up_kernel_size == 3 else 0 + lora = cls( + in_channels=down.shape[1], + out_channels=up.shape[0], + rank=down.shape[0], + kernel_size=(down_kernel_size, up_kernel_size), + padding=(down_padding, up_padding), + device=up.device, + dtype=up.dtype, + ) + lora.load_weights(down_weight=down, up_weight=up) + return lora + + def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None": + for layer, parent in target.walk(fl.Conv2d): + if isinstance(parent, Lora) or isinstance(parent, LoraAdapter): + continue + + if exclude is not None and any( + [any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude] + ): + continue + + if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels: + if layer.stride != (self.stride[0], self.stride[0]): + self.down.stride = layer.stride + + return LoraAdapter( + target=layer, + lora=self, + ), parent + + def lora_layers( + self, device: Device | str | None = None, dtype: DType | None = None + ) -> tuple[fl.Conv2d, fl.Conv2d]: + return ( + fl.Conv2d( + in_channels=self.in_channels, + out_channels=self.rank, + kernel_size=self.kernel_size[0], + stride=self.stride[0], + padding=self.padding[0], + use_bias=False, + device=device, + dtype=dtype, + ), + fl.Conv2d( + in_channels=self.rank, + out_channels=self.out_channels, + kernel_size=self.kernel_size[1], + stride=self.stride[1], + padding=self.padding[1], + use_bias=False, + device=device, + dtype=dtype, + ), + ) + + +class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]): + def __init__(self, target: fl.WeightedModule, lora: Lora) -> None: with self.setup_adapter(target): - super().__init__(target) - - if weights is not None: - assert len(weights) % 2 == 0 - weights_rank = weights[0].shape[1] - if rank is None: - rank = weights_rank - else: - assert rank == weights_rank - - assert rank is not None, "either pass a rank or weights" - - self.sub_targets = sub_targets - self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = [] - - for linear, parent in self.sub_targets: - self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent)) - - if weights is not None: - assert len(self.sub_adapters) == (len(weights) // 2) - for i, (adapter, _) in enumerate(self.sub_adapters): - lora = adapter.Lora - assert ( - lora.rank == weights[i * 2].shape[1] - ), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}" - adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1]) - - def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter: - for adapter, adapter_parent in self.sub_adapters: - adapter.inject(adapter_parent) - return super().inject(parent) - - def eject(self) -> None: - for adapter, _ in self.sub_adapters: - adapter.eject() - super().eject() + super().__init__(target, lora) @property - def weights(self) -> list[Tensor]: - return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]] + def lora(self) -> Lora: + return self.ensure_find(Lora) + + @property + def scale(self) -> float: + return self.lora.scale + + @scale.setter + def scale(self, value: float) -> None: + self.lora.scale = value diff --git a/imaginairy/vendored/refiners/fluxion/layers/module.py b/imaginairy/vendored/refiners/fluxion/layers/module.py index 390a581..b7ac475 100644 --- a/imaginairy/vendored/refiners/fluxion/layers/module.py +++ b/imaginairy/vendored/refiners/fluxion/layers/module.py 
@@ -164,6 +164,9 @@ class WeightedModule(Module): def dtype(self) -> DType: return self.weight.dtype + def __str__(self) -> str: + return f"{super().__str__().removesuffix(')')}, device={self.device}, dtype={str(self.dtype).removeprefix('torch.')})" + class TreeNode(TypedDict): value: str diff --git a/imaginairy/vendored/refiners/fluxion/utils.py b/imaginairy/vendored/refiners/fluxion/utils.py index deb0d46..09d9c06 100644 --- a/imaginairy/vendored/refiners/fluxion/utils.py +++ b/imaginairy/vendored/refiners/fluxion/utils.py @@ -146,6 +146,7 @@ def tensor_to_image(tensor: Tensor) -> Image.Image: assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}" num_channels = tensor.shape[1] tensor = tensor.clamp(0, 1).squeeze(0) + tensor = tensor.to(torch.float32) # to avoid numpy error with bfloat16 match num_channels: case 1: @@ -187,20 +188,26 @@ def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: def summarize_tensor(tensor: torch.Tensor, /) -> str: - return ( - "Tensor(" - + ", ".join( + info_list = [ + f"shape=({', '.join(map(str, tensor.shape))})", + f"dtype={str(object=tensor.dtype).removeprefix('torch.')}", + f"device={tensor.device}", + ] + if not tensor.is_complex(): + info_list.extend( [ - f"shape=({', '.join(map(str, tensor.shape))})", - f"dtype={str(object=tensor.dtype).removeprefix('torch.')}", - f"device={tensor.device}", f"min={tensor.min():.2f}", # type: ignore f"max={tensor.max():.2f}", # type: ignore - f"mean={tensor.mean():.2f}", - f"std={tensor.std():.2f}", - f"norm={norm(x=tensor):.2f}", - f"grad={tensor.requires_grad}", ] ) - + ")" + + info_list.extend( + [ + f"mean={tensor.float().mean():.2f}", + f"std={tensor.float().std():.2f}", + f"norm={norm(x=tensor.float()):.2f}", + f"grad={tensor.requires_grad}", + ] ) + + return "Tensor(" + ", ".join(info_list) + ")" diff --git a/imaginairy/vendored/refiners/foundationals/latent_diffusion/image_prompt.py b/imaginairy/vendored/refiners/foundationals/latent_diffusion/image_prompt.py index fe3253b..a8b7af5 100644 --- a/imaginairy/vendored/refiners/foundationals/latent_diffusion/image_prompt.py +++ b/imaginairy/vendored/refiners/foundationals/latent_diffusion/image_prompt.py @@ -1,15 +1,12 @@ import math -from enum import IntEnum -from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from jaxtyping import Float from PIL import Image -from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like +from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like import imaginairy.vendored.refiners.fluxion.layers as fl from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter -from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora from imaginairy.vendored.refiners.fluxion.context import Contexts from imaginairy.vendored.refiners.fluxion.layers.attentions import ScaledDotProductAttention from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, normalize @@ -236,120 +233,99 @@ class PerceiverResampler(fl.Chain): return {"perceiver_resampler": {"x": None}} -class _CrossAttnIndex(IntEnum): - TXT_CROSS_ATTN = 0 # text cross-attention - IMG_CROSS_ATTN = 1 # image cross-attention +class ImageCrossAttention(fl.Chain): + def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None: + self._scale = scale + super().__init__( + fl.Distribute( + fl.Identity(), + fl.Chain( + 
fl.UseContext(context="ip_adapter", key="clip_image_embedding"), + fl.Linear( + in_features=text_cross_attention.key_embedding_dim, + out_features=text_cross_attention.inner_dim, + bias=text_cross_attention.use_bias, + device=text_cross_attention.device, + dtype=text_cross_attention.dtype, + ), + ), + fl.Chain( + fl.UseContext(context="ip_adapter", key="clip_image_embedding"), + fl.Linear( + in_features=text_cross_attention.value_embedding_dim, + out_features=text_cross_attention.inner_dim, + bias=text_cross_attention.use_bias, + device=text_cross_attention.device, + dtype=text_cross_attention.dtype, + ), + ), + ), + ScaledDotProductAttention( + num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal + ), + fl.Multiply(self.scale), + ) + @property + def scale(self) -> float: + return self._scale -class InjectionPoint(fl.Chain): - pass + @scale.setter + def scale(self, value: float) -> None: + self._scale = value + self.ensure_find(fl.Multiply).scale = value class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]): def __init__( self, target: fl.Attention, - text_sequence_length: int = 77, - image_sequence_length: int = 4, scale: float = 1.0, ) -> None: - self.text_sequence_length = text_sequence_length - self.image_sequence_length = image_sequence_length - self.scale = scale - + self._scale = scale with self.setup_adapter(target): - super().__init__( - fl.Distribute( - # Note: the same query is used for image cross-attention as for text cross-attention - InjectionPoint(), # Wq - fl.Parallel( - fl.Chain( - fl.Slicing(dim=1, end=text_sequence_length), - InjectionPoint(), # Wk - ), - fl.Chain( - fl.Slicing(dim=1, start=text_sequence_length), - fl.Linear( - in_features=self.target.key_embedding_dim, - out_features=self.target.inner_dim, - bias=self.target.use_bias, - device=target.device, - dtype=target.dtype, - ), # Wk' - ), - ), - fl.Parallel( - fl.Chain( - fl.Slicing(dim=1, end=text_sequence_length), - InjectionPoint(), # Wv - ), - fl.Chain( - fl.Slicing(dim=1, start=text_sequence_length), - fl.Linear( - in_features=self.target.key_embedding_dim, - out_features=self.target.inner_dim, - bias=self.target.use_bias, - device=target.device, - dtype=target.dtype, - ), # Wv' - ), - ), - ), - fl.Sum( - fl.Chain( - fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)), - ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal), - ), - fl.Chain( - fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)), - ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal), - fl.Lambda(func=self.scale_outputs), - ), + clone = target.structural_copy() + scaled_dot_product = clone.ensure_find(ScaledDotProductAttention) + image_cross_attention = ImageCrossAttention( + text_cross_attention=clone, + scale=self.scale, + ) + clone.replace( + old_module=scaled_dot_product, + new_module=fl.Sum( + scaled_dot_product, + image_cross_attention, ), - InjectionPoint(), # proj + ) + super().__init__( + clone, ) - def select_qkv( - self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex - ) -> tuple[Tensor, Tensor, Tensor]: - return (query, keys[index.value], values[index.value]) - - def scale_outputs(self, x: Tensor) -> Tensor: - return x * self.scale - - def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]: - def f(m: fl.Module, _: fl.Chain) -> bool: - if isinstance(m, Lora): # do not adapt LoRAs - raise StopIteration - 
return isinstance(m, k) - - return f - - def _target_linears(self) -> list[fl.Linear]: - return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)] - - def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter": - linears = self._target_linears() - assert len(linears) == 4 # Wq, Wk, Wv and Proj - - injection_points = list(self.layers(InjectionPoint)) - assert len(injection_points) == 4 + @property + def image_cross_attention(self) -> ImageCrossAttention: + return self.ensure_find(ImageCrossAttention) - for linear, ip in zip(linears, injection_points): - ip.append(linear) - assert len(ip) == 1 + @property + def image_key_projection(self) -> fl.Linear: + return self.image_cross_attention.Distribute[1].Linear - return super().inject(parent) + @property + def image_value_projection(self) -> fl.Linear: + return self.image_cross_attention.Distribute[2].Linear - def eject(self) -> None: - injection_points = list(self.layers(InjectionPoint)) - assert len(injection_points) == 4 + @property + def scale(self) -> float: + return self._scale - for ip in injection_points: - ip.pop() - assert len(ip) == 0 + @scale.setter + def scale(self, value: float) -> None: + self._scale = value + self.image_cross_attention.scale = value - super().eject() + def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None: + self.image_key_projection.weight = nn.Parameter(key_tensor) + self.image_value_projection.weight = nn.Parameter(value_tensor) + self.image_cross_attention.to(self.device, self.dtype) class IPAdapter(Generic[T], fl.Chain, Adapter[T]): @@ -377,7 +353,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): self._image_proj = [image_proj] self.sub_adapters = [ - CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens) + CrossAttentionAdapter(target=cross_attn, scale=scale) for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention)) ] @@ -388,14 +364,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): self.image_proj.load_state_dict(image_proj_state_dict) for i, cross_attn in enumerate(self.sub_adapters): - cross_attn_state_dict: dict[str, Tensor] = {} + cross_attention_weights: list[Tensor] = [] for k, v in weights.items(): prefix = f"ip_adapter.{i:03d}." 
if not k.startswith(prefix): continue - cross_attn_state_dict[k.removeprefix(prefix)] = v + cross_attention_weights.append(v) - cross_attn.load_state_dict(state_dict=cross_attn_state_dict) + assert len(cross_attention_weights) == 2 + cross_attn.load_weights(*cross_attention_weights) @property def clip_image_encoder(self) -> CLIPImageEncoderH: @@ -420,10 +397,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): adapter.eject() super().eject() + @property + def scale(self) -> float: + return self.sub_adapters[0].scale + + @scale.setter + def scale(self, value: float) -> None: + for cross_attn in self.sub_adapters: + cross_attn.scale = value + def set_scale(self, scale: float) -> None: for cross_attn in self.sub_adapters: cross_attn.scale = scale + def set_clip_image_embedding(self, image_embedding: Tensor) -> None: + self.set_context("ip_adapter", {"clip_image_embedding": image_embedding}) + # These should be concatenated to the CLIP text embedding before setting the UNet context def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor: image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder diff --git a/imaginairy/vendored/refiners/foundationals/latent_diffusion/lora.py b/imaginairy/vendored/refiners/foundationals/latent_diffusion/lora.py index ed48270..592c244 100644 --- a/imaginairy/vendored/refiners/foundationals/latent_diffusion/lora.py +++ b/imaginairy/vendored/refiners/foundationals/latent_diffusion/lora.py @@ -1,146 +1,140 @@ -from enum import Enum -from pathlib import Path -from typing import Callable, Iterator +from warnings import warn from torch import Tensor import imaginairy.vendored.refiners.fluxion.layers as fl -from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora, LoraAdapter -from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors -from imaginairy.vendored.refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer -from imaginairy.vendored.refiners.foundationals.latent_diffusion import ( - CLIPTextEncoderL, - LatentDiffusionAutoencoder, - SD1UNet, - StableDiffusion_1, -) -from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d -from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet - -MODELS = ["unet", "text_encoder", "lda"] - - -class LoraTarget(str, Enum): - Self = "self" - Attention = "Attention" - SelfAttention = "SelfAttention" - CrossAttention = "CrossAttentionBlock2d" - FeedForward = "FeedForward" - TransformerLayer = "TransformerLayer" - - def get_class(self) -> type[fl.Chain]: - match self: - case LoraTarget.Self: - return fl.Chain - case LoraTarget.Attention: - return fl.Attention - case LoraTarget.SelfAttention: - return fl.SelfAttention - case LoraTarget.CrossAttention: - return CrossAttentionBlock2d - case LoraTarget.FeedForward: - return FeedForward - case LoraTarget.TransformerLayer: - return TransformerLayer - - -def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]: - def f(m: fl.Module, _: fl.Chain) -> bool: - if isinstance(m, Lora): # do not adapt other LoRAs - raise StopIteration - if isinstance(m, Controlnet): # do not adapt Controlnet linears - raise StopIteration - return isinstance(m, k) - - return f - - -def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]: - for m, p in 
module.walk(_predicate(fl.Linear)): - assert isinstance(m, fl.Linear) - yield (m, p) - - -def lora_targets( - module: fl.Chain, - target: LoraTarget | list[LoraTarget], -) -> Iterator[tuple[fl.Linear, fl.Chain]]: - if isinstance(target, list): - for t in target: - yield from lora_targets(module, t) - return - - if target == LoraTarget.Self: - yield from _iter_linears(module) - return - - for layer, _ in module.walk(_predicate(target.get_class())): - assert isinstance(layer, fl.Chain) - yield from _iter_linears(layer) - - -class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]): - metadata: dict[str, str] | None - tensors: dict[str, Tensor] +from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import LatentDiffusionModel + +class SDLoraManager: def __init__( self, - target: StableDiffusion_1, - sub_targets: dict[str, list[LoraTarget]], - scale: float = 1.0, - weights: dict[str, Tensor] | None = None, - ): - with self.setup_adapter(target): - super().__init__(target) - - self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = [] - - for model_name in MODELS: - if not (model_targets := sub_targets.get(model_name, [])): - continue - model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name) - - lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None - self.sub_adapters.append( - LoraAdapter[type(model)]( - model, - sub_targets=lora_targets(model, model_targets), - scale=scale, - weights=lora_weights, - ) - ) - - @classmethod - def from_safetensors( - cls, - target: StableDiffusion_1, - checkpoint_path: Path | str, + target: LatentDiffusionModel, + ) -> None: + self.target = target + + @property + def unet(self) -> fl.Chain: + unet = self.target.unet + assert isinstance(unet, fl.Chain) + return unet + + @property + def clip_text_encoder(self) -> fl.Chain: + clip_text_encoder = self.target.clip_text_encoder + assert isinstance(clip_text_encoder, fl.Chain) + return clip_text_encoder + + def load( + self, + tensors: dict[str, Tensor], + /, scale: float = 1.0, - ): - metadata = load_metadata_from_safetensors(checkpoint_path) - assert metadata is not None, "Invalid safetensors checkpoint: missing metadata" - tensors = load_from_safetensors(checkpoint_path, device=target.device) - - sub_targets: dict[str, list[LoraTarget]] = {} - for model_name in MODELS: - if not (v := metadata.get(f"{model_name}_targets", "")): - continue - sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")] - - return cls( - target, - sub_targets, - scale=scale, - weights=tensors, + ) -> None: + """Load the LoRA weights from a dictionary of tensors. + + Expects the keys to be in the commonly found formats on CivitAI's hub. 
+ """ + assert len(self.lora_adapters) == 0, "Loras already loaded" + loras = Lora.from_dict( + {key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()} ) - - def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter": - for adapter in self.sub_adapters: - adapter.inject() - return super().inject(parent) - - def eject(self) -> None: - for adapter in self.sub_adapters: - adapter.eject() - super().eject() + loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)} + + # if no key contains "unet" or "text", assume all keys are for the unet + if not "unet" in loras and not "text" in loras: + loras = {f"unet_{key}": loras[key] for key in loras.keys()} + + self.load_unet(loras) + self.load_text_encoder(loras) + + self.scale = scale + + def load_text_encoder(self, loras: dict[str, Lora], /) -> None: + text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key} + SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder) + + def load_unet(self, loras: dict[str, Lora], /) -> None: + unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key} + exclude: list[str] = [] + exclude = [ + self.unet_exclusions[exclusion] + for exclusion in self.unet_exclusions + if all([exclusion not in key for key in unet_loras.keys()]) + ] + SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude) + + def unload(self) -> None: + for lora_adapter in self.lora_adapters: + lora_adapter.eject() + + @property + def loras(self) -> list[Lora]: + return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora)) + + @property + def lora_adapters(self) -> list[LoraAdapter]: + return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter)) + + @property + def unet_exclusions(self) -> dict[str, str]: + return { + "time": "TimestepEncoder", + "res": "ResidualBlock", + "downsample": "DownsampleBlock", + "upsample": "UpsampleBlock", + } + + @property + def scale(self) -> float: + assert len(self.loras) > 0, "No loras found" + assert all([lora.scale == self.loras[0].scale for lora in self.loras]) + return self.loras[0].scale + + @scale.setter + def scale(self, value: float) -> None: + for lora in self.loras: + lora.scale = value + + @staticmethod + def pad(input: str, /, padding_length: int = 2) -> str: + new_split: list[str] = [] + for s in input.split("_"): + if s.isdigit(): + new_split.append(s.zfill(padding_length)) + else: + new_split.append(s) + return "_".join(new_split) + + @staticmethod + def sort_keys(key: str, /) -> tuple[str, int]: + # out0 happens sometimes as an alias for out ; this dict might not be exhaustive + key_char_order = {"q": 1, "k": 2, "v": 3, "out": 4, "out0": 4} + + for i, s in enumerate(key.split("_")): + if s in key_char_order: + prefix = SDLoraManager.pad("_".join(key.split("_")[:i])) + return (prefix, key_char_order[s]) + + return (SDLoraManager.pad(key), 5) + + @staticmethod + def auto_attach( + loras: dict[str, Lora], + target: fl.Chain, + /, + exclude: list[str] | None = None, + ) -> None: + failed_loras: dict[str, Lora] = {} + for key, lora in loras.items(): + if attach := lora.auto_attach(target, exclude=exclude): + adapter, parent = attach + adapter.inject(parent) + else: + failed_loras[key] = lora + + if failed_loras: + warn(f"failed to attach {len(failed_loras)}/{len(loras)} loras to {target.__class__.__name__}") + + # TODO: add a stronger sanity check to make sure loras are attached correctly diff 
--git a/imaginairy/vendored/refiners/foundationals/latent_diffusion/model.py b/imaginairy/vendored/refiners/foundationals/latent_diffusion/model.py index 8627dc3..1b80848 100644 --- a/imaginairy/vendored/refiners/foundationals/latent_diffusion/model.py +++ b/imaginairy/vendored/refiners/foundationals/latent_diffusion/model.py @@ -11,7 +11,6 @@ from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.sche T = TypeVar("T", bound="fl.Module") - TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel") @@ -91,6 +90,8 @@ class LatentDiffusionModel(fl.Module, ABC): self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs) latents = torch.cat(tensors=(x, x)) # for classifier-free guidance + # scale latents for schedulers that need it + latents = self.scheduler.scale_model_input(latents, step=step) unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2) # classifier-free guidance diff --git a/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/__init__.py b/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/__init__.py index 8c88eb9..b7ed728 100644 --- a/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/__init__.py +++ b/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/__init__.py @@ -1,11 +1,7 @@ from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver +from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler -__all__ = [ - "Scheduler", - "DPMSolver", - "DDPM", - "DDIM", -] +__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"] diff --git a/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/ddim.py b/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/ddim.py index 7ddc5c5..34c5e6b 100644 --- a/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/ddim.py +++ b/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/ddim.py @@ -1,4 +1,4 @@ -from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor +from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler @@ -34,7 +34,7 @@ class DDIM(Scheduler): timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1 return timesteps.flip(0) - def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor: + def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: timestep, previous_timestep = ( self.timesteps[step], ( @@ -52,6 +52,12 @@ class DDIM(Scheduler): ), ) predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor - denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise + noise_factor = sqrt(1 - previous_scale_factor**2) + + # Do not add noise at the last step to avoid visual artifacts. 
+ if step == self.num_inference_steps - 1: + noise_factor = 0 + + denoised_x = previous_scale_factor * predicted_x + noise_factor * noise return denoised_x diff --git a/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/ddpm.py b/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/ddpm.py index 7ae26d0..40873b3 100644 --- a/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/ddpm.py +++ b/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/ddpm.py @@ -1,4 +1,4 @@ -from torch import Tensor, arange, device as Device +from torch import Generator, Tensor, arange, device as Device from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler @@ -30,5 +30,5 @@ class DDPM(Scheduler): timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio return timesteps.flip(0) - def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor: + def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: raise NotImplementedError diff --git a/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py b/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py index 8ddf510..5df3df1 100644 --- a/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py +++ b/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py @@ -1,15 +1,19 @@ from collections import deque import numpy as np -from torch import Tensor, device as Device, dtype as Dtype, exp, float32, tensor +from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler class DPMSolver(Scheduler): - """Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095 + """ + Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095 - We only support noise prediction for now. + Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts + when used with SDXL and few steps. This parameter is a way to mitigate that + effect by using a first-order (Euler) update instead of a second-order update + for the last step of the diffusion. 
""" def __init__( @@ -18,6 +22,7 @@ class DPMSolver(Scheduler): num_train_timesteps: int = 1_000, initial_diffusion_rate: float = 8.5e-4, final_diffusion_rate: float = 1.2e-2, + last_step_first_order: bool = False, noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, device: Device | str = "cpu", dtype: Dtype = float32, @@ -32,7 +37,8 @@ class DPMSolver(Scheduler): dtype=dtype, ) self.estimated_data = deque([tensor([])] * 2, maxlen=2) - self.initial_steps = 0 + self.last_step_first_order = last_step_first_order + self._first_step_has_been_run = False def _generate_timesteps(self) -> Tensor: # We need to use numpy here because: @@ -45,57 +51,48 @@ class DPMSolver(Scheduler): ).flip(0) def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor: - timestep, previous_timestep = ( - self.timesteps[step], - self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0], - ) - previous_ratio, current_ratio = ( - self.signal_to_noise_ratios[previous_timestep], - self.signal_to_noise_ratios[timestep], - ) + current_timestep = self.timesteps[step] + previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0]) + + previous_ratio = self.signal_to_noise_ratios[previous_timestep] + current_ratio = self.signal_to_noise_ratios[current_timestep] + previous_scale_factor = self.cumulative_scale_factors[previous_timestep] - previous_noise_std, current_noise_std = ( - self.noise_std[previous_timestep], - self.noise_std[timestep], - ) + + previous_noise_std = self.noise_std[previous_timestep] + current_noise_std = self.noise_std[current_timestep] + factor = exp(-(previous_ratio - current_ratio)) - 1.0 denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise return denoised_x def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor: - previous_timestep, current_timestep, next_timestep = ( - self.timesteps[step + 1] if step < len(self.timesteps) - 1 else tensor([0]), - self.timesteps[step], - self.timesteps[step - 1], - ) - current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2] - previous_ratio, current_ratio, next_ratio = ( - self.signal_to_noise_ratios[previous_timestep], - self.signal_to_noise_ratios[current_timestep], - self.signal_to_noise_ratios[next_timestep], - ) + previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0]) + current_timestep = self.timesteps[step] + next_timestep = self.timesteps[step - 1] + + current_data_estimation = self.estimated_data[-1] + next_data_estimation = self.estimated_data[-2] + + previous_ratio = self.signal_to_noise_ratios[previous_timestep] + current_ratio = self.signal_to_noise_ratios[current_timestep] + next_ratio = self.signal_to_noise_ratios[next_timestep] + previous_scale_factor = self.cumulative_scale_factors[previous_timestep] - previous_std, current_std = ( - self.noise_std[previous_timestep], - self.noise_std[current_timestep], - ) + previous_noise_std = self.noise_std[previous_timestep] + current_noise_std = self.noise_std[current_timestep] estimation_delta = (current_data_estimation - next_data_estimation) / ( (current_ratio - next_ratio) / (previous_ratio - current_ratio) ) factor = exp(-(previous_ratio - current_ratio)) - 1.0 denoised_x = ( - (previous_std / current_std) * x + (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * current_data_estimation - 0.5 * (factor * previous_scale_factor) * 
estimation_delta ) return denoised_x - def __call__( - self, - x: Tensor, - noise: Tensor, - step: int, - ) -> Tensor: + def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: """ Represents one step of the backward diffusion process that iteratively denoises the input data `x`. @@ -107,11 +104,9 @@ class DPMSolver(Scheduler): scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep] estimated_denoised_data = (x - noise_ratio * noise) / scale_factor self.estimated_data.append(estimated_denoised_data) - denoised_x = ( - self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step) - if (self.initial_steps == 0) - else self.multistep_dpm_solver_second_order_update(x=x, step=step) - ) - if self.initial_steps < 2: - self.initial_steps += 1 - return denoised_x + + if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1) or not self._first_step_has_been_run: + self._first_step_has_been_run = True + return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step) + + return self.multistep_dpm_solver_second_order_update(x=x, step=step) diff --git a/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/euler.py b/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/euler.py new file mode 100644 index 0000000..3cc22d0 --- /dev/null +++ b/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/euler.py @@ -0,0 +1,84 @@ +import numpy as np +import torch +from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor + +from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler + + +class EulerScheduler(Scheduler): + def __init__( + self, + num_inference_steps: int, + num_train_timesteps: int = 1_000, + initial_diffusion_rate: float = 8.5e-4, + final_diffusion_rate: float = 1.2e-2, + noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, + device: Device | str = "cpu", + dtype: Dtype = float32, + ): + if noise_schedule != NoiseSchedule.QUADRATIC: + raise NotImplementedError + super().__init__( + num_inference_steps=num_inference_steps, + num_train_timesteps=num_train_timesteps, + initial_diffusion_rate=initial_diffusion_rate, + final_diffusion_rate=final_diffusion_rate, + noise_schedule=noise_schedule, + device=device, + dtype=dtype, + ) + self.sigmas = self._generate_sigmas() + + @property + def init_noise_sigma(self) -> Tensor: + return self.sigmas.max() + + def _generate_timesteps(self) -> Tensor: + # We need to use numpy here because: + # numpy.linspace(0,999,31)[15] is 499.49999999999994 + # torch.linspace(0,999,31)[15] is 499.5 + # ...and we want the same result as the original codebase. 
+ timesteps = torch.tensor( + np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device + ).flip(0) + return timesteps + + def _generate_sigmas(self) -> Tensor: + sigmas = self.noise_std / self.cumulative_scale_factors + sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy())) + sigmas = torch.cat([sigmas, tensor([0.0])]) + return sigmas.to(device=self.device, dtype=self.dtype) + + def scale_model_input(self, x: Tensor, step: int) -> Tensor: + sigma = self.sigmas[step] + return x / ((sigma**2 + 1) ** 0.5) + + def __call__( + self, + x: Tensor, + noise: Tensor, + step: int, + generator: Generator | None = None, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + ) -> Tensor: + sigma = self.sigmas[step] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0 + + alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype) + eps = alt_noise * s_noise + sigma_hat = sigma * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + predicted_x = x - sigma_hat * noise + + # 1st order Euler + derivative = (x - predicted_x) / sigma_hat + dt = self.sigmas[step + 1] - sigma_hat + denoised_x = x + derivative * dt + + return denoised_x diff --git a/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/scheduler.py b/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/scheduler.py index abf106c..f64a4cc 100644 --- a/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/scheduler.py +++ b/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/scheduler.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from enum import Enum from typing import TypeVar -from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt +from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt T = TypeVar("T", bound="Scheduler") @@ -50,7 +50,7 @@ class Scheduler(ABC): self.timesteps = self._generate_timesteps() @abstractmethod - def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor: + def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: """ Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`. @@ -71,6 +71,12 @@ class Scheduler(ABC): def steps(self) -> list[int]: return list(range(self.num_inference_steps)) + def scale_model_input(self, x: Tensor, step: int) -> Tensor: + """ + For compatibility with schedulers that need to scale the input according to the current timestep. 
+ """ + return x + def sample_power_distribution(self, power: float = 2, /) -> Tensor: return ( linspace( diff --git a/imaginairy/vendored/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/imaginairy/vendored/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py index 6d44140..bd5b8c0 100644 --- a/imaginairy/vendored/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py +++ b/imaginairy/vendored/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py @@ -89,10 +89,18 @@ class StableDiffusion_1(LatentDiffusionModel): classifier_free_guidance=True, ) - negative_embedding, _ = clip_text_embedding.chunk(2) timestep = self.scheduler.timesteps[step].unsqueeze(dim=0) + negative_embedding, _ = clip_text_embedding.chunk(2) self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs) - degraded_noise = self.unet(degraded_latents) + if "ip_adapter" in self.unet.provider.contexts: + # this implementation is a bit hacky, it should be refactored in the future + ip_adapter_context = self.unet.use_context("ip_adapter") + image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone() + ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2) + degraded_noise = self.unet(degraded_latents) + ip_adapter_context["clip_image_embedding"] = image_embedding_copy + else: + degraded_noise = self.unet(degraded_latents) return sag.scale * (noise - degraded_noise) @@ -160,14 +168,23 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1): step=step, classifier_free_guidance=True, ) - - negative_embedding, _ = clip_text_embedding.chunk(2) - timestep = self.scheduler.timesteps[step].unsqueeze(dim=0) - self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs) x = torch.cat( tensors=(degraded_latents, self.mask_latents, self.target_image_latents), dim=1, ) - degraded_noise = self.unet(x) + + timestep = self.scheduler.timesteps[step].unsqueeze(dim=0) + negative_embedding, _ = clip_text_embedding.chunk(2) + self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs) + + if "ip_adapter" in self.unet.provider.contexts: + # this implementation is a bit hacky, it should be refactored in the future + ip_adapter_context = self.unet.use_context("ip_adapter") + image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone() + ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2) + degraded_noise = self.unet(x) + ip_adapter_context["clip_image_embedding"] = image_embedding_copy + else: + degraded_noise = self.unet(x) return sag.scale * (noise - degraded_noise) diff --git a/imaginairy/vendored/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py b/imaginairy/vendored/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py index 7267cb4..f5b36e3 100644 --- a/imaginairy/vendored/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py +++ b/imaginairy/vendored/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py @@ -138,17 +138,25 @@ class StableDiffusion_XL(LatentDiffusionModel): classifier_free_guidance=True, ) - negative_embedding, _ = clip_text_embedding.chunk(2) + negative_text_embedding, _ = clip_text_embedding.chunk(2) negative_pooled_embedding, _ = pooled_text_embedding.chunk(2) timestep = self.scheduler.timesteps[step].unsqueeze(dim=0) time_ids, _ = time_ids.chunk(2) + self.set_unet_context( timestep=timestep, - 
clip_text_embedding=negative_embedding, + clip_text_embedding=negative_text_embedding, pooled_text_embedding=negative_pooled_embedding, time_ids=time_ids, - **kwargs, ) - degraded_noise = self.unet(degraded_latents) + if "ip_adapter" in self.unet.provider.contexts: + # this implementation is a bit hacky, it should be refactored in the future + ip_adapter_context = self.unet.use_context("ip_adapter") + image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone() + ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2) + degraded_noise = self.unet(degraded_latents) + ip_adapter_context["clip_image_embedding"] = image_embedding_copy + else: + degraded_noise = self.unet(degraded_latents) return sag.scale * (noise - degraded_noise) diff --git a/imaginairy/vendored/refiners/foundationals/segment_anything/model.py b/imaginairy/vendored/refiners/foundationals/segment_anything/model.py index d9dfcc8..9c27236 100644 --- a/imaginairy/vendored/refiners/foundationals/segment_anything/model.py +++ b/imaginairy/vendored/refiners/foundationals/segment_anything/model.py @@ -3,11 +3,12 @@ from typing import Sequence import numpy as np import torch +from jaxtyping import Float from PIL import Image from torch import Tensor, device as Device, dtype as DType import imaginairy.vendored.refiners.fluxion.layers as fl -from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad +from imaginairy.vendored.refiners.fluxion.utils import interpolate, no_grad, normalize, pad from imaginairy.vendored.refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH from imaginairy.vendored.refiners.foundationals.segment_anything.mask_decoder import MaskDecoder from imaginairy.vendored.refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder @@ -55,7 +56,7 @@ class SegmentAnything(fl.Module): foreground_points: Sequence[tuple[float, float]] | None = None, background_points: Sequence[tuple[float, float]] | None = None, box_points: Sequence[Sequence[tuple[float, float]]] | None = None, - masks: Sequence[Image.Image] | None = None, + low_res_mask: Float[Tensor, "1 1 256 256"] | None = None, binarize: bool = True, ) -> tuple[Tensor, Tensor, Tensor]: if isinstance(input, ImageEmbedding): @@ -74,15 +75,13 @@ class SegmentAnything(fl.Module): ) self.point_encoder.set_type_mask(type_mask=type_mask) - if masks is not None: - mask_tensor = torch.stack( - tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks] - ) - mask_embedding = self.mask_encoder(mask_tensor) + if low_res_mask is not None: + mask_embedding = self.mask_encoder(low_res_mask) else: mask_embedding = self.mask_encoder.get_no_mask_dense_embedding( image_embedding_size=self.image_encoder.image_embedding_size ) + point_embedding = self.point_encoder( self.normalize(coordinates, target_size=target_size, original_size=original_size) ) diff --git a/imaginairy/vendored/refiners/readme.txt b/imaginairy/vendored/refiners/readme.txt index a1bb296..d6f973e 100644 --- a/imaginairy/vendored/refiners/readme.txt +++ b/imaginairy/vendored/refiners/readme.txt @@ -1 +1 @@ -vendored from git@github.com:finegrain-ai/refiners.git @ 20c229903f53d05dc1c44659ec97603660ef964c +vendored from git@github.com:finegrain-ai/refiners.git @ ce3035923ba71bcb5044708d2f1c37fd1d6722e9
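For context, the patch above replaces the old SD1LoraAdapter/SingleLoraAdapter machinery with SDLoraManager plus auto-attaching LinearLora/Conv2dLora layers. A minimal usage sketch of that new API, assuming a default StableDiffusion_1 build and a placeholder CivitAI-style LoRA checkpoint path (neither appears in this patch):

    from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors
    from imaginairy.vendored.refiners.foundationals.latent_diffusion import StableDiffusion_1
    from imaginairy.vendored.refiners.foundationals.latent_diffusion.lora import SDLoraManager

    sd = StableDiffusion_1()  # any LatentDiffusionModel; default CPU/float32 build assumed here
    manager = SDLoraManager(sd)

    # "pixel-art.safetensors" is a placeholder path, not part of this patch.
    tensors = load_from_safetensors("pixel-art.safetensors")
    manager.load(tensors, scale=1.0)  # builds Lora layers via Lora.from_dict and auto-attaches LoraAdapters
    manager.scale = 0.6               # rescales every attached LoRA at once (the fl.Multiply inside each Lora)
    manager.unload()                  # ejects all LoraAdapter layers, restoring the original model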
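Likewise, the new Scheduler.scale_model_input hook added above is a no-op in the base class and is overridden by EulerScheduler to divide by sqrt(sigma^2 + 1); a small sketch of that behavior, with an arbitrary latent shape and step count chosen purely for illustration:

    import torch

    from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers import DDIM, EulerScheduler

    ddim = DDIM(num_inference_steps=30)
    euler = EulerScheduler(num_inference_steps=30)
    x = torch.randn(1, 4, 64, 64)

    assert torch.equal(ddim.scale_model_input(x, step=0), x)  # base Scheduler returns the input unchanged
    scaled = euler.scale_model_input(x, step=0)               # Euler divides by sqrt(sigma_0**2 + 1)
    print(float(euler.sigmas[0]), scaled.std().item())        # sigma_0 is the largest sigma (init_noise_sigma)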