feature: updates refiners vendored library (#458)

* feature: updates refiners vendored library

Includes a small bugfix that will soon be replaced by a better fix from upstream refiners.

Co-authored-by: Bryce <github20210803@accounts.brycedrennan.com>
This commit is contained in:
jaydrennan 2024-01-19 09:45:23 -07:00 committed by GitHub
parent fbb16f6c62
commit 1bf53e47cf
17 changed files with 666 additions and 426 deletions

View File

@@ -210,7 +210,7 @@ vendorize_normal_map:
 vendorize_refiners:
-    export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=20c229903f53d05dc1c44659ec97603660ef964c && \
+    export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=ce3035923ba71bcb5044708d2f1c37fd1d6722e9 && \
     make download_repo REPO=$$REPO PKG=$$PKG COMMIT=$$COMMIT && \
     mkdir -p ./imaginairy/vendored/$$PKG && \
     rm -rf ./imaginairy/vendored/$$PKG/* && \

View File

@@ -1,4 +1,5 @@
-from typing import Any, Generic, Iterable, TypeVar
+from abc import ABC, abstractmethod
+from typing import Any

 from torch import Tensor, device as Device, dtype as DType
 from torch.nn import Parameter as TorchParameter
@ -6,125 +7,259 @@ from torch.nn.init import normal_, zeros_
import imaginairy.vendored.refiners.fluxion.layers as fl import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.layers.chain import Chain
T = TypeVar("T", bound=fl.Chain)
TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PEP 673)
class Lora(fl.Chain): class Lora(fl.Chain, ABC):
def __init__(
self,
rank: int = 16,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.rank = rank
self._scale = scale
super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))
normal_(tensor=self.down.weight, std=1 / self.rank)
zeros_(tensor=self.up.weight)
@abstractmethod
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.WeightedModule, fl.WeightedModule]:
...
@property
def down(self) -> fl.WeightedModule:
down_layer = self[0]
assert isinstance(down_layer, fl.WeightedModule)
return down_layer
@property
def up(self) -> fl.WeightedModule:
up_layer = self[1]
assert isinstance(up_layer, fl.WeightedModule)
return up_layer
@property
def scale(self) -> float:
return self._scale
@scale.setter
def scale(self, value: float) -> None:
self._scale = value
self.ensure_find(fl.Multiply).scale = value
@classmethod
def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "Lora":
match (up.ndim, down.ndim):
case (2, 2):
return LinearLora.from_weights(up=up, down=down)
case (4, 4):
return Conv2dLora.from_weights(up=up, down=down)
case _:
raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
@classmethod
def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]:
"""
Create a dictionary of LoRA layers from a state dict.
Expects the state dict to be a succession of down and up weights.
"""
state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
loras: dict[str, Lora] = {}
for down_key, down_tensor, up_tensor in zip(
list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
):
key = ".".join(down_key.split(".")[:-2])
loras[key] = cls.from_weights(down=down_tensor, up=up_tensor)
return loras
@abstractmethod
def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any:
...
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
assert down_weight.shape == self.down.weight.shape
assert up_weight.shape == self.up.weight.shape
self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
class LinearLora(Lora):
def __init__( def __init__(
self, self,
in_features: int, in_features: int,
out_features: int, out_features: int,
rank: int = 16, rank: int = 16,
scale: float = 1.0,
device: Device | str | None = None, device: Device | str | None = None,
dtype: DType | None = None, dtype: DType | None = None,
) -> None: ) -> None:
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
self.rank = rank
self.scale: float = 1.0
super().__init__( super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
fl.Linear(in_features=in_features, out_features=rank, bias=False, device=device, dtype=dtype),
fl.Linear(in_features=rank, out_features=out_features, bias=False, device=device, dtype=dtype), @classmethod
fl.Lambda(func=self.scale_outputs), def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "LinearLora":
assert up.ndim == 2 and down.ndim == 2
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
lora = cls(
in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype
)
lora.load_weights(down_weight=down, up_weight=up)
return lora
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
for layer, parent in target.walk(fl.Linear):
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
continue
if exclude is not None and any(
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
):
continue
if layer.in_features == self.in_features and layer.out_features == self.out_features:
return LoraAdapter(target=layer, lora=self), parent
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.Linear, fl.Linear]:
return (
fl.Linear(
in_features=self.in_features,
out_features=self.rank,
bias=False,
device=device,
dtype=dtype,
),
fl.Linear(
in_features=self.rank,
out_features=self.out_features,
bias=False,
device=device,
dtype=dtype,
),
) )
normal_(tensor=self.Linear_1.weight, std=1 / self.rank)
zeros_(tensor=self.Linear_2.weight)
def scale_outputs(self, x: Tensor) -> Tensor: class Conv2dLora(Lora):
return x * self.scale
def set_scale(self, scale: float) -> None:
self.scale = scale
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
@property
def up_weight(self) -> Tensor:
return self.Linear_2.weight.data
@property
def down_weight(self) -> Tensor:
return self.Linear_1.weight.data
class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
def __init__( def __init__(
self, self,
target: fl.Linear, in_channels: int,
out_channels: int,
rank: int = 16, rank: int = 16,
scale: float = 1.0, scale: float = 1.0,
kernel_size: tuple[int, int] = (1, 3),
stride: tuple[int, int] = (1, 1),
padding: tuple[int, int] = (0, 1),
device: Device | str | None = None,
dtype: DType | None = None,
) -> None: ) -> None:
self.in_features = target.in_features self.in_channels = in_channels
self.out_features = target.out_features self.out_channels = out_channels
self.rank = rank self.kernel_size = kernel_size
self.scale = scale self.stride = stride
self.padding = padding
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
@classmethod
def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "Conv2dLora":
assert up.ndim == 4 and down.ndim == 4
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
down_kernel_size, up_kernel_size = down.shape[2], up.shape[2]
down_padding = 1 if down_kernel_size == 3 else 0
up_padding = 1 if up_kernel_size == 3 else 0
lora = cls(
in_channels=down.shape[1],
out_channels=up.shape[0],
rank=down.shape[0],
kernel_size=(down_kernel_size, up_kernel_size),
padding=(down_padding, up_padding),
device=up.device,
dtype=up.dtype,
)
lora.load_weights(down_weight=down, up_weight=up)
return lora
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
for layer, parent in target.walk(fl.Conv2d):
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
continue
if exclude is not None and any(
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
):
continue
if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels:
if layer.stride != (self.stride[0], self.stride[0]):
self.down.stride = layer.stride
return LoraAdapter(
target=layer,
lora=self,
), parent
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.Conv2d, fl.Conv2d]:
return (
fl.Conv2d(
in_channels=self.in_channels,
out_channels=self.rank,
kernel_size=self.kernel_size[0],
stride=self.stride[0],
padding=self.padding[0],
use_bias=False,
device=device,
dtype=dtype,
),
fl.Conv2d(
in_channels=self.rank,
out_channels=self.out_channels,
kernel_size=self.kernel_size[1],
stride=self.stride[1],
padding=self.padding[1],
use_bias=False,
device=device,
dtype=dtype,
),
)
class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
def __init__(self, target: fl.WeightedModule, lora: Lora) -> None:
with self.setup_adapter(target): with self.setup_adapter(target):
super().__init__( super().__init__(target, lora)
target,
Lora(
in_features=target.in_features,
out_features=target.out_features,
rank=rank,
device=target.device,
dtype=target.dtype,
),
)
self.Lora.set_scale(scale=scale)
class LoraAdapter(Generic[T], fl.Chain, Adapter[T]):
def __init__(
self,
target: T,
sub_targets: Iterable[tuple[fl.Linear, fl.Chain]],
rank: int | None = None,
scale: float = 1.0,
weights: list[Tensor] | None = None,
) -> None:
with self.setup_adapter(target):
super().__init__(target)
if weights is not None:
assert len(weights) % 2 == 0
weights_rank = weights[0].shape[1]
if rank is None:
rank = weights_rank
else:
assert rank == weights_rank
assert rank is not None, "either pass a rank or weights"
self.sub_targets = sub_targets
self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = []
for linear, parent in self.sub_targets:
self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent))
if weights is not None:
assert len(self.sub_adapters) == (len(weights) // 2)
for i, (adapter, _) in enumerate(self.sub_adapters):
lora = adapter.Lora
assert (
lora.rank == weights[i * 2].shape[1]
), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter:
for adapter, adapter_parent in self.sub_adapters:
adapter.inject(adapter_parent)
return super().inject(parent)
def eject(self) -> None:
for adapter, _ in self.sub_adapters:
adapter.eject()
super().eject()
@property @property
def weights(self) -> list[Tensor]: def lora(self) -> Lora:
return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]] return self.ensure_find(Lora)
@property
def scale(self) -> float:
return self.lora.scale
@scale.setter
def scale(self, value: float) -> None:
self.lora.scale = value
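
The per-model LoRA machinery above replaces the old SingleLoraAdapter/LoraAdapter pair with an abstract Lora base class plus LinearLora and Conv2dLora, each able to rebuild itself from raw (down, up) weights and to find its own attachment point via auto_attach. A minimal sketch of the new flow; the tensor shapes are made up and nothing here comes from a real checkpoint:

import torch

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.lora import LinearLora

# A toy target chain containing a single Linear layer to adapt.
target = fl.Chain(fl.Linear(in_features=320, out_features=320))

# Rank-4 LoRA built directly from (down, up) weight tensors.
down = torch.randn(4, 320)  # (rank, in_features)
up = torch.zeros(320, 4)    # (out_features, rank)
lora = LinearLora.from_weights(down=down, up=up)

# auto_attach walks the target, skips layers already wrapped by a LoRA, and
# returns the LoraAdapter to inject together with the parent chain it belongs to.
if (attached := lora.auto_attach(target)) is not None:
    adapter, parent = attached
    adapter.inject(parent)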

View File

@@ -164,6 +164,9 @@ class WeightedModule(Module):
     def dtype(self) -> DType:
         return self.weight.dtype

+    def __str__(self) -> str:
+        return f"{super().__str__().removesuffix(')')}, device={self.device}, dtype={str(self.dtype).removeprefix('torch.')})"
+

 class TreeNode(TypedDict):
     value: str

View File

@@ -146,6 +146,7 @@ def tensor_to_image(tensor: Tensor) -> Image.Image:
     assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}"
     num_channels = tensor.shape[1]
     tensor = tensor.clamp(0, 1).squeeze(0)
+    tensor = tensor.to(torch.float32)  # to avoid numpy error with bfloat16

     match num_channels:
         case 1:
@@ -187,20 +188,26 @@ def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata:
 def summarize_tensor(tensor: torch.Tensor, /) -> str:
-    return (
-        "Tensor("
-        + ", ".join(
-            [
-                f"shape=({', '.join(map(str, tensor.shape))})",
-                f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
-                f"device={tensor.device}",
-                f"min={tensor.min():.2f}",  # type: ignore
-                f"max={tensor.max():.2f}",  # type: ignore
-                f"mean={tensor.mean():.2f}",
-                f"std={tensor.std():.2f}",
-                f"norm={norm(x=tensor):.2f}",
-                f"grad={tensor.requires_grad}",
-            ]
-        )
-        + ")"
-    )
+    info_list = [
+        f"shape=({', '.join(map(str, tensor.shape))})",
+        f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
+        f"device={tensor.device}",
+    ]
+    if not tensor.is_complex():
+        info_list.extend(
+            [
+                f"min={tensor.min():.2f}",  # type: ignore
+                f"max={tensor.max():.2f}",  # type: ignore
+            ]
+        )
+    info_list.extend(
+        [
+            f"mean={tensor.float().mean():.2f}",
+            f"std={tensor.float().std():.2f}",
+            f"norm={norm(x=tensor.float()):.2f}",
+            f"grad={tensor.requires_grad}",
+        ]
+    )
+    return "Tensor(" + ", ".join(info_list) + ")"

View File

@@ -1,15 +1,12 @@
 import math
-from enum import IntEnum
-from functools import partial
-from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TypeVar

 from jaxtyping import Float
 from PIL import Image
-from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like
+from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like

 import imaginairy.vendored.refiners.fluxion.layers as fl
 from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
-from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora
 from imaginairy.vendored.refiners.fluxion.context import Contexts
 from imaginairy.vendored.refiners.fluxion.layers.attentions import ScaledDotProductAttention
 from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, normalize
@ -236,120 +233,99 @@ class PerceiverResampler(fl.Chain):
return {"perceiver_resampler": {"x": None}} return {"perceiver_resampler": {"x": None}}
class _CrossAttnIndex(IntEnum): class ImageCrossAttention(fl.Chain):
TXT_CROSS_ATTN = 0 # text cross-attention def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
IMG_CROSS_ATTN = 1 # image cross-attention self._scale = scale
super().__init__(
fl.Distribute(
fl.Identity(),
fl.Chain(
fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
fl.Linear(
in_features=text_cross_attention.key_embedding_dim,
out_features=text_cross_attention.inner_dim,
bias=text_cross_attention.use_bias,
device=text_cross_attention.device,
dtype=text_cross_attention.dtype,
),
),
fl.Chain(
fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
fl.Linear(
in_features=text_cross_attention.value_embedding_dim,
out_features=text_cross_attention.inner_dim,
bias=text_cross_attention.use_bias,
device=text_cross_attention.device,
dtype=text_cross_attention.dtype,
),
),
),
ScaledDotProductAttention(
num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal
),
fl.Multiply(self.scale),
)
@property
def scale(self) -> float:
return self._scale
class InjectionPoint(fl.Chain): @scale.setter
pass def scale(self, value: float) -> None:
self._scale = value
self.ensure_find(fl.Multiply).scale = value
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]): class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
def __init__( def __init__(
self, self,
target: fl.Attention, target: fl.Attention,
text_sequence_length: int = 77,
image_sequence_length: int = 4,
scale: float = 1.0, scale: float = 1.0,
) -> None: ) -> None:
self.text_sequence_length = text_sequence_length self._scale = scale
self.image_sequence_length = image_sequence_length
self.scale = scale
with self.setup_adapter(target): with self.setup_adapter(target):
clone = target.structural_copy()
scaled_dot_product = clone.ensure_find(ScaledDotProductAttention)
image_cross_attention = ImageCrossAttention(
text_cross_attention=clone,
scale=self.scale,
)
clone.replace(
old_module=scaled_dot_product,
new_module=fl.Sum(
scaled_dot_product,
image_cross_attention,
),
)
super().__init__( super().__init__(
fl.Distribute( clone,
# Note: the same query is used for image cross-attention as for text cross-attention
InjectionPoint(), # Wq
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wk
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
bias=self.target.use_bias,
device=target.device,
dtype=target.dtype,
), # Wk'
),
),
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wv
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
bias=self.target.use_bias,
device=target.device,
dtype=target.dtype,
), # Wv'
),
),
),
fl.Sum(
fl.Chain(
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
),
fl.Chain(
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
fl.Lambda(func=self.scale_outputs),
),
),
InjectionPoint(), # proj
) )
def select_qkv( @property
self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex def image_cross_attention(self) -> ImageCrossAttention:
) -> tuple[Tensor, Tensor, Tensor]: return self.ensure_find(ImageCrossAttention)
return (query, keys[index.value], values[index.value])
def scale_outputs(self, x: Tensor) -> Tensor: @property
return x * self.scale def image_key_projection(self) -> fl.Linear:
return self.image_cross_attention.Distribute[1].Linear
def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]: @property
def f(m: fl.Module, _: fl.Chain) -> bool: def image_value_projection(self) -> fl.Linear:
if isinstance(m, Lora): # do not adapt LoRAs return self.image_cross_attention.Distribute[2].Linear
raise StopIteration
return isinstance(m, k)
return f @property
def scale(self) -> float:
return self._scale
def _target_linears(self) -> list[fl.Linear]: @scale.setter
return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)] def scale(self, value: float) -> None:
self._scale = value
self.image_cross_attention.scale = value
def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter": def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
linears = self._target_linears() self.image_key_projection.weight = nn.Parameter(key_tensor)
assert len(linears) == 4 # Wq, Wk, Wv and Proj self.image_value_projection.weight = nn.Parameter(value_tensor)
self.image_cross_attention.to(self.device, self.dtype)
injection_points = list(self.layers(InjectionPoint))
assert len(injection_points) == 4
for linear, ip in zip(linears, injection_points):
ip.append(linear)
assert len(ip) == 1
return super().inject(parent)
def eject(self) -> None:
injection_points = list(self.layers(InjectionPoint))
assert len(injection_points) == 4
for ip in injection_points:
ip.pop()
assert len(ip) == 0
super().eject()
class IPAdapter(Generic[T], fl.Chain, Adapter[T]): class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
@ -377,7 +353,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
self._image_proj = [image_proj] self._image_proj = [image_proj]
self.sub_adapters = [ self.sub_adapters = [
CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens) CrossAttentionAdapter(target=cross_attn, scale=scale)
for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention)) for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
] ]
@ -388,14 +364,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
self.image_proj.load_state_dict(image_proj_state_dict) self.image_proj.load_state_dict(image_proj_state_dict)
for i, cross_attn in enumerate(self.sub_adapters): for i, cross_attn in enumerate(self.sub_adapters):
cross_attn_state_dict: dict[str, Tensor] = {} cross_attention_weights: list[Tensor] = []
for k, v in weights.items(): for k, v in weights.items():
prefix = f"ip_adapter.{i:03d}." prefix = f"ip_adapter.{i:03d}."
if not k.startswith(prefix): if not k.startswith(prefix):
continue continue
cross_attn_state_dict[k.removeprefix(prefix)] = v cross_attention_weights.append(v)
cross_attn.load_state_dict(state_dict=cross_attn_state_dict) assert len(cross_attention_weights) == 2
cross_attn.load_weights(*cross_attention_weights)
@property @property
def clip_image_encoder(self) -> CLIPImageEncoderH: def clip_image_encoder(self) -> CLIPImageEncoderH:
@ -420,10 +397,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
adapter.eject() adapter.eject()
super().eject() super().eject()
@property
def scale(self) -> float:
return self.sub_adapters[0].scale
@scale.setter
def scale(self, value: float) -> None:
for cross_attn in self.sub_adapters:
cross_attn.scale = value
def set_scale(self, scale: float) -> None: def set_scale(self, scale: float) -> None:
for cross_attn in self.sub_adapters: for cross_attn in self.sub_adapters:
cross_attn.scale = scale cross_attn.scale = scale
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
# These should be concatenated to the CLIP text embedding before setting the UNet context # These should be concatenated to the CLIP text embedding before setting the UNet context
def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor: def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
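
The InjectionPoint and slicing machinery is gone: image conditioning now lives in a dedicated ImageCrossAttention branch that reads the embedding from an "ip_adapter" context, and scale fans out through properties instead of a Lambda. A hedged usage sketch, assuming ip_adapter is an already-injected IPAdapter instance and image_prompt a preprocessed image tensor (both are placeholders, not real objects):

clip_image_embedding = ip_adapter.compute_clip_image_embedding(image_prompt)
ip_adapter.set_clip_image_embedding(clip_image_embedding)  # stored in the "ip_adapter" context

# The new scale setter propagates to every CrossAttentionAdapter, which in turn
# updates the fl.Multiply inside its ImageCrossAttention branch.
ip_adapter.scale = 0.6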

View File

@ -1,146 +1,140 @@
from enum import Enum from warnings import warn
from pathlib import Path
from typing import Callable, Iterator
from torch import Tensor from torch import Tensor
import imaginairy.vendored.refiners.fluxion.layers as fl import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora, LoraAdapter from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora, LoraAdapter
from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
CLIPTextEncoderL,
LatentDiffusionAutoencoder,
SD1UNet,
StableDiffusion_1,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
MODELS = ["unet", "text_encoder", "lda"]
class LoraTarget(str, Enum): class SDLoraManager:
Self = "self"
Attention = "Attention"
SelfAttention = "SelfAttention"
CrossAttention = "CrossAttentionBlock2d"
FeedForward = "FeedForward"
TransformerLayer = "TransformerLayer"
def get_class(self) -> type[fl.Chain]:
match self:
case LoraTarget.Self:
return fl.Chain
case LoraTarget.Attention:
return fl.Attention
case LoraTarget.SelfAttention:
return fl.SelfAttention
case LoraTarget.CrossAttention:
return CrossAttentionBlock2d
case LoraTarget.FeedForward:
return FeedForward
case LoraTarget.TransformerLayer:
return TransformerLayer
def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
def f(m: fl.Module, _: fl.Chain) -> bool:
if isinstance(m, Lora): # do not adapt other LoRAs
raise StopIteration
if isinstance(m, Controlnet): # do not adapt Controlnet linears
raise StopIteration
return isinstance(m, k)
return f
def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]:
for m, p in module.walk(_predicate(fl.Linear)):
assert isinstance(m, fl.Linear)
yield (m, p)
def lora_targets(
module: fl.Chain,
target: LoraTarget | list[LoraTarget],
) -> Iterator[tuple[fl.Linear, fl.Chain]]:
if isinstance(target, list):
for t in target:
yield from lora_targets(module, t)
return
if target == LoraTarget.Self:
yield from _iter_linears(module)
return
for layer, _ in module.walk(_predicate(target.get_class())):
assert isinstance(layer, fl.Chain)
yield from _iter_linears(layer)
class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
metadata: dict[str, str] | None
tensors: dict[str, Tensor]
def __init__( def __init__(
self, self,
target: StableDiffusion_1, target: LatentDiffusionModel,
sub_targets: dict[str, list[LoraTarget]], ) -> None:
self.target = target
@property
def unet(self) -> fl.Chain:
unet = self.target.unet
assert isinstance(unet, fl.Chain)
return unet
@property
def clip_text_encoder(self) -> fl.Chain:
clip_text_encoder = self.target.clip_text_encoder
assert isinstance(clip_text_encoder, fl.Chain)
return clip_text_encoder
def load(
self,
tensors: dict[str, Tensor],
/,
scale: float = 1.0, scale: float = 1.0,
weights: dict[str, Tensor] | None = None, ) -> None:
): """Load the LoRA weights from a dictionary of tensors.
with self.setup_adapter(target):
super().__init__(target)
self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = [] Expects the keys to be in the commonly found formats on CivitAI's hub.
"""
for model_name in MODELS: assert len(self.lora_adapters) == 0, "Loras already loaded"
if not (model_targets := sub_targets.get(model_name, [])): loras = Lora.from_dict(
continue {key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)
lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
self.sub_adapters.append(
LoraAdapter[type(model)](
model,
sub_targets=lora_targets(model, model_targets),
scale=scale,
weights=lora_weights,
)
)
@classmethod
def from_safetensors(
cls,
target: StableDiffusion_1,
checkpoint_path: Path | str,
scale: float = 1.0,
):
metadata = load_metadata_from_safetensors(checkpoint_path)
assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
tensors = load_from_safetensors(checkpoint_path, device=target.device)
sub_targets: dict[str, list[LoraTarget]] = {}
for model_name in MODELS:
if not (v := metadata.get(f"{model_name}_targets", "")):
continue
sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")]
return cls(
target,
sub_targets,
scale=scale,
weights=tensors,
) )
loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}
def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter": # if no key contains "unet" or "text", assume all keys are for the unet
for adapter in self.sub_adapters: if not "unet" in loras and not "text" in loras:
adapter.inject() loras = {f"unet_{key}": loras[key] for key in loras.keys()}
return super().inject(parent)
def eject(self) -> None: self.load_unet(loras)
for adapter in self.sub_adapters: self.load_text_encoder(loras)
adapter.eject()
super().eject() self.scale = scale
def load_text_encoder(self, loras: dict[str, Lora], /) -> None:
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
def load_unet(self, loras: dict[str, Lora], /) -> None:
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
exclude: list[str] = []
exclude = [
self.unet_exclusions[exclusion]
for exclusion in self.unet_exclusions
if all([exclusion not in key for key in unet_loras.keys()])
]
SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude)
def unload(self) -> None:
for lora_adapter in self.lora_adapters:
lora_adapter.eject()
@property
def loras(self) -> list[Lora]:
return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
@property
def lora_adapters(self) -> list[LoraAdapter]:
return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
@property
def unet_exclusions(self) -> dict[str, str]:
return {
"time": "TimestepEncoder",
"res": "ResidualBlock",
"downsample": "DownsampleBlock",
"upsample": "UpsampleBlock",
}
@property
def scale(self) -> float:
assert len(self.loras) > 0, "No loras found"
assert all([lora.scale == self.loras[0].scale for lora in self.loras])
return self.loras[0].scale
@scale.setter
def scale(self, value: float) -> None:
for lora in self.loras:
lora.scale = value
@staticmethod
def pad(input: str, /, padding_length: int = 2) -> str:
new_split: list[str] = []
for s in input.split("_"):
if s.isdigit():
new_split.append(s.zfill(padding_length))
else:
new_split.append(s)
return "_".join(new_split)
@staticmethod
def sort_keys(key: str, /) -> tuple[str, int]:
# out0 happens sometimes as an alias for out ; this dict might not be exhaustive
key_char_order = {"q": 1, "k": 2, "v": 3, "out": 4, "out0": 4}
for i, s in enumerate(key.split("_")):
if s in key_char_order:
prefix = SDLoraManager.pad("_".join(key.split("_")[:i]))
return (prefix, key_char_order[s])
return (SDLoraManager.pad(key), 5)
@staticmethod
def auto_attach(
loras: dict[str, Lora],
target: fl.Chain,
/,
exclude: list[str] | None = None,
) -> None:
failed_loras: dict[str, Lora] = {}
for key, lora in loras.items():
if attach := lora.auto_attach(target, exclude=exclude):
adapter, parent = attach
adapter.inject(parent)
else:
failed_loras[key] = lora
if failed_loras:
warn(f"failed to attach {len(failed_loras)}/{len(loras)} loras to {target.__class__.__name__}")
# TODO: add a stronger sanity check to make sure loras are attached correctly
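
SD1LoraAdapter and its metadata-driven target lists are replaced by SDLoraManager, which takes CivitAI-style key names and attaches LoRAs by shape matching. A minimal usage sketch, assuming sd15 is an already-built StableDiffusion_1 instance, the checkpoint path is a placeholder, and the SDLoraManager import path mirrors upstream refiners:

from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors
from imaginairy.vendored.refiners.foundationals.latent_diffusion.lora import SDLoraManager  # import path assumed

manager = SDLoraManager(sd15)
tensors = load_from_safetensors("pixel-art-lora.safetensors")  # CivitAI-style keys expected
manager.load(tensors, scale=0.8)

# ... run inference ...

manager.scale = 0.4  # adjusts every attached Lora in place
manager.unload()     # ejects all LoraAdapters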

View File

@@ -11,7 +11,6 @@ from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.sche
 T = TypeVar("T", bound="fl.Module")
 TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")

@@ -91,6 +90,8 @@ class LatentDiffusionModel(fl.Module, ABC):
         self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)

         latents = torch.cat(tensors=(x, x))  # for classifier-free guidance
+        # scale latents for schedulers that need it
+        latents = self.scheduler.scale_model_input(latents, step=step)
         unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)

         # classifier-free guidance

View File

@@ -1,11 +1,7 @@
 from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
 from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
 from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
+from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler
 from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler

-__all__ = [
-    "Scheduler",
-    "DPMSolver",
-    "DDPM",
-    "DDIM",
-]
+__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]

View File

@@ -1,4 +1,4 @@
-from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
+from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor

 from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler

@@ -34,7 +34,7 @@ class DDIM(Scheduler):
         timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
         return timesteps.flip(0)

-    def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
+    def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
         timestep, previous_timestep = (
             self.timesteps[step],
             (

@@ -52,6 +52,12 @@ class DDIM(Scheduler):
             ),
         )
         predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
-        denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
+        noise_factor = sqrt(1 - previous_scale_factor**2)
+
+        # Do not add noise at the last step to avoid visual artifacts.
+        if step == self.num_inference_steps - 1:
+            noise_factor = 0
+
+        denoised_x = previous_scale_factor * predicted_x + noise_factor * noise
         return denoised_x
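
In effect, the final DDIM step now returns only the rescaled predicted sample, without re-adding noise. A toy numeric sketch of the changed update (all values are made up):

from torch import sqrt, tensor

previous_scale_factor = tensor(0.9995)
predicted_x, noise = tensor(0.5), tensor(0.1)

for is_last_step in (False, True):
    noise_factor = 0 if is_last_step else sqrt(1 - previous_scale_factor**2)
    denoised_x = previous_scale_factor * predicted_x + noise_factor * noise
    print(is_last_step, denoised_x)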

View File

@@ -1,4 +1,4 @@
-from torch import Tensor, arange, device as Device
+from torch import Generator, Tensor, arange, device as Device

 from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler

@@ -30,5 +30,5 @@ class DDPM(Scheduler):
         timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
         return timesteps.flip(0)

-    def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
+    def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
         raise NotImplementedError

View File

@ -1,15 +1,19 @@
from collections import deque from collections import deque
import numpy as np import numpy as np
from torch import Tensor, device as Device, dtype as Dtype, exp, float32, tensor from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
class DPMSolver(Scheduler): class DPMSolver(Scheduler):
"""Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095 """
Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
We only support noise prediction for now. Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts
when used with SDXL and few steps. This parameter is a way to mitigate that
effect by using a first-order (Euler) update instead of a second-order update
for the last step of the diffusion.
""" """
def __init__( def __init__(
@ -18,6 +22,7 @@ class DPMSolver(Scheduler):
num_train_timesteps: int = 1_000, num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4, initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2, final_diffusion_rate: float = 1.2e-2,
last_step_first_order: bool = False,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: Dtype = float32, dtype: Dtype = float32,
@ -32,7 +37,8 @@ class DPMSolver(Scheduler):
dtype=dtype, dtype=dtype,
) )
self.estimated_data = deque([tensor([])] * 2, maxlen=2) self.estimated_data = deque([tensor([])] * 2, maxlen=2)
self.initial_steps = 0 self.last_step_first_order = last_step_first_order
self._first_step_has_been_run = False
def _generate_timesteps(self) -> Tensor: def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because: # We need to use numpy here because:
@ -45,57 +51,48 @@ class DPMSolver(Scheduler):
).flip(0) ).flip(0)
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor: def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
timestep, previous_timestep = ( current_timestep = self.timesteps[step]
self.timesteps[step], previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0],
) previous_ratio = self.signal_to_noise_ratios[previous_timestep]
previous_ratio, current_ratio = ( current_ratio = self.signal_to_noise_ratios[current_timestep]
self.signal_to_noise_ratios[previous_timestep],
self.signal_to_noise_ratios[timestep],
)
previous_scale_factor = self.cumulative_scale_factors[previous_timestep] previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_noise_std, current_noise_std = (
self.noise_std[previous_timestep], previous_noise_std = self.noise_std[previous_timestep]
self.noise_std[timestep], current_noise_std = self.noise_std[current_timestep]
)
factor = exp(-(previous_ratio - current_ratio)) - 1.0 factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
return denoised_x return denoised_x
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor: def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
previous_timestep, current_timestep, next_timestep = ( previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
self.timesteps[step + 1] if step < len(self.timesteps) - 1 else tensor([0]), current_timestep = self.timesteps[step]
self.timesteps[step], next_timestep = self.timesteps[step - 1]
self.timesteps[step - 1],
) current_data_estimation = self.estimated_data[-1]
current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2] next_data_estimation = self.estimated_data[-2]
previous_ratio, current_ratio, next_ratio = (
self.signal_to_noise_ratios[previous_timestep], previous_ratio = self.signal_to_noise_ratios[previous_timestep]
self.signal_to_noise_ratios[current_timestep], current_ratio = self.signal_to_noise_ratios[current_timestep]
self.signal_to_noise_ratios[next_timestep], next_ratio = self.signal_to_noise_ratios[next_timestep]
)
previous_scale_factor = self.cumulative_scale_factors[previous_timestep] previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_std, current_std = ( previous_noise_std = self.noise_std[previous_timestep]
self.noise_std[previous_timestep], current_noise_std = self.noise_std[current_timestep]
self.noise_std[current_timestep],
)
estimation_delta = (current_data_estimation - next_data_estimation) / ( estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio) (current_ratio - next_ratio) / (previous_ratio - current_ratio)
) )
factor = exp(-(previous_ratio - current_ratio)) - 1.0 factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = ( denoised_x = (
(previous_std / current_std) * x (previous_noise_std / current_noise_std) * x
- (factor * previous_scale_factor) * current_data_estimation - (factor * previous_scale_factor) * current_data_estimation
- 0.5 * (factor * previous_scale_factor) * estimation_delta - 0.5 * (factor * previous_scale_factor) * estimation_delta
) )
return denoised_x return denoised_x
def __call__( def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
self,
x: Tensor,
noise: Tensor,
step: int,
) -> Tensor:
""" """
Represents one step of the backward diffusion process that iteratively denoises the input data `x`. Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
@ -107,11 +104,9 @@ class DPMSolver(Scheduler):
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep] scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
self.estimated_data.append(estimated_denoised_data) self.estimated_data.append(estimated_denoised_data)
denoised_x = (
self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step) if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1) or not self._first_step_has_been_run:
if (self.initial_steps == 0) self._first_step_has_been_run = True
else self.multistep_dpm_solver_second_order_update(x=x, step=step) return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
)
if self.initial_steps < 2: return self.multistep_dpm_solver_second_order_update(x=x, step=step)
self.initial_steps += 1
return denoised_x
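
The initial_steps counter is replaced by an explicit rule: the very first call uses the first-order update, and when last_step_first_order is set the final step does too (a workaround for the artifacts DPM-Solver++ produces with SDXL at low step counts). A standalone restatement of that selection logic as a sketch; the helper name is ours, not part of the library:

def use_first_order_update(step: int, num_inference_steps: int, last_step_first_order: bool, first_step_done: bool) -> bool:
    # Mirrors the condition in DPMSolver.__call__ after this change.
    return (
        step == 0
        or (last_step_first_order and step == num_inference_steps - 1)
        or not first_step_done
    )

assert use_first_order_update(0, 30, False, False)
assert not use_first_order_update(29, 30, False, True)
assert use_first_order_update(29, 30, True, True)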

View File

@ -0,0 +1,84 @@
import numpy as np
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
class EulerScheduler(Scheduler):
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: Dtype = float32,
):
if noise_schedule != NoiseSchedule.QUADRATIC:
raise NotImplementedError
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
device=device,
dtype=dtype,
)
self.sigmas = self._generate_sigmas()
@property
def init_noise_sigma(self) -> Tensor:
return self.sigmas.max()
def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because:
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# ...and we want the same result as the original codebase.
timesteps = torch.tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
).flip(0)
return timesteps
def _generate_sigmas(self) -> Tensor:
sigmas = self.noise_std / self.cumulative_scale_factors
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
sigmas = torch.cat([sigmas, tensor([0.0])])
return sigmas.to(device=self.device, dtype=self.dtype)
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
sigma = self.sigmas[step]
return x / ((sigma**2 + 1) ** 0.5)
def __call__(
self,
x: Tensor,
noise: Tensor,
step: int,
generator: Generator | None = None,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
) -> Tensor:
sigma = self.sigmas[step]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype)
eps = alt_noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
predicted_x = x - sigma_hat * noise
# 1st order Euler
derivative = (x - predicted_x) / sigma_hat
dt = self.sigmas[step + 1] - sigma_hat
denoised_x = x + derivative * dt
return denoised_x
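
A hedged usage sketch of the new EulerScheduler: latents start at init_noise_sigma scale, inputs are passed through scale_model_input before the denoiser, and each step is advanced with __call__. The random noise below is a stand-in for an actual UNet prediction:

import torch

from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers import EulerScheduler

scheduler = EulerScheduler(num_inference_steps=30)
x = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

for step in scheduler.steps:
    model_input = scheduler.scale_model_input(x, step=step)
    noise_prediction = torch.randn_like(x)  # placeholder for unet(model_input)
    x = scheduler(x, noise=noise_prediction, step=step)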

View File

@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from enum import Enum
 from typing import TypeVar

-from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
+from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt

 T = TypeVar("T", bound="Scheduler")

@@ -50,7 +50,7 @@ class Scheduler(ABC):
         self.timesteps = self._generate_timesteps()

     @abstractmethod
-    def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
+    def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
         """
         Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.

@@ -71,6 +71,12 @@ class Scheduler(ABC):
     def steps(self) -> list[int]:
         return list(range(self.num_inference_steps))

+    def scale_model_input(self, x: Tensor, step: int) -> Tensor:
+        """
+        For compatibility with schedulers that need to scale the input according to the current timestep.
+        """
+        return x
+
     def sample_power_distribution(self, power: float = 2, /) -> Tensor:
         return (
             linspace(

View File

@@ -89,10 +89,18 @@ class StableDiffusion_1(LatentDiffusionModel):
             classifier_free_guidance=True,
         )

-        negative_embedding, _ = clip_text_embedding.chunk(2)
         timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        negative_embedding, _ = clip_text_embedding.chunk(2)
         self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
-        degraded_noise = self.unet(degraded_latents)
+
+        if "ip_adapter" in self.unet.provider.contexts:
+            # this implementation is a bit hacky, it should be refactored in the future
+            ip_adapter_context = self.unet.use_context("ip_adapter")
+            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
+            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
+            degraded_noise = self.unet(degraded_latents)
+            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
+        else:
+            degraded_noise = self.unet(degraded_latents)

         return sag.scale * (noise - degraded_noise)
@@ -160,14 +168,23 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
             step=step,
             classifier_free_guidance=True,
         )

-        negative_embedding, _ = clip_text_embedding.chunk(2)
-        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
-        self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
         x = torch.cat(
             tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
             dim=1,
         )
-        degraded_noise = self.unet(x)
+
+        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        negative_embedding, _ = clip_text_embedding.chunk(2)
+        self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
+        if "ip_adapter" in self.unet.provider.contexts:
+            # this implementation is a bit hacky, it should be refactored in the future
+            ip_adapter_context = self.unet.use_context("ip_adapter")
+            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
+            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
+            degraded_noise = self.unet(x)
+            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
+        else:
+            degraded_noise = self.unet(x)

         return sag.scale * (noise - degraded_noise)
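
Self-attention guidance in StableDiffusion_1, its inpainting variant, and StableDiffusion_XL (below) now all special-case an injected IP-adapter: the stored CLIP image embedding is duplicated for classifier-free guidance, but the degraded pass runs on a single batch, so only the first half is kept for that call and the full embedding is restored afterwards. A tiny illustration of the shapes involved (made-up dimensions):

import torch

clip_image_embedding = torch.randn(2, 4, 768)     # CFG batch: first half negative, second half positive
sag_half, _ = clip_image_embedding.chunk(2)       # (1, 4, 768) used for the degraded pass
full_copy = clip_image_embedding.clone()          # kept so the context can be restored
assert sag_half.shape[0] == 1 and full_copy.shape[0] == 2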

View File

@@ -138,17 +138,25 @@ class StableDiffusion_XL(LatentDiffusionModel):
             classifier_free_guidance=True,
         )

-        negative_embedding, _ = clip_text_embedding.chunk(2)
+        negative_text_embedding, _ = clip_text_embedding.chunk(2)
         negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
         timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
         time_ids, _ = time_ids.chunk(2)

         self.set_unet_context(
             timestep=timestep,
-            clip_text_embedding=negative_embedding,
+            clip_text_embedding=negative_text_embedding,
             pooled_text_embedding=negative_pooled_embedding,
             time_ids=time_ids,
+            **kwargs,
         )
-        degraded_noise = self.unet(degraded_latents)
+
+        if "ip_adapter" in self.unet.provider.contexts:
+            # this implementation is a bit hacky, it should be refactored in the future
+            ip_adapter_context = self.unet.use_context("ip_adapter")
+            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
+            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
+            degraded_noise = self.unet(degraded_latents)
+            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
+        else:
+            degraded_noise = self.unet(degraded_latents)

         return sag.scale * (noise - degraded_noise)

View File

@@ -3,11 +3,12 @@ from typing import Sequence
 import numpy as np
 import torch
+from jaxtyping import Float
 from PIL import Image
 from torch import Tensor, device as Device, dtype as DType

 import imaginairy.vendored.refiners.fluxion.layers as fl
-from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad
+from imaginairy.vendored.refiners.fluxion.utils import interpolate, no_grad, normalize, pad
 from imaginairy.vendored.refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
 from imaginairy.vendored.refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
 from imaginairy.vendored.refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
@@ -55,7 +56,7 @@ class SegmentAnything(fl.Module):
         foreground_points: Sequence[tuple[float, float]] | None = None,
         background_points: Sequence[tuple[float, float]] | None = None,
         box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
-        masks: Sequence[Image.Image] | None = None,
+        low_res_mask: Float[Tensor, "1 1 256 256"] | None = None,
         binarize: bool = True,
     ) -> tuple[Tensor, Tensor, Tensor]:
         if isinstance(input, ImageEmbedding):

@@ -74,15 +75,13 @@ class SegmentAnything(fl.Module):
         )
         self.point_encoder.set_type_mask(type_mask=type_mask)

-        if masks is not None:
-            mask_tensor = torch.stack(
-                tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks]
-            )
-            mask_embedding = self.mask_encoder(mask_tensor)
+        if low_res_mask is not None:
+            mask_embedding = self.mask_encoder(low_res_mask)
         else:
             mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(
                 image_embedding_size=self.image_encoder.image_embedding_size
             )

         point_embedding = self.point_encoder(
             self.normalize(coordinates, target_size=target_size, original_size=original_size)
         )
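
The mask prompt now takes the decoder's own low-resolution mask tensor instead of PIL images, which makes iterative refinement straightforward. A hedged sketch, assuming sam is a SegmentAnything instance and image a PIL image; the method name, return ordering, and the slicing to a single mask follow upstream refiners and are assumptions here:

# First pass: point prompt only.
masks, scores, low_res_masks = sam.predict(
    image,
    foreground_points=[(320.0, 240.0)],
)

# Second pass: feed a previous low-resolution mask back in as a dense prompt.
refined_masks, _, _ = sam.predict(
    image,
    foreground_points=[(320.0, 240.0)],
    low_res_mask=low_res_masks[:, :1, :, :],  # expects Float[Tensor, "1 1 256 256"]
)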

View File

@@ -1 +1 @@
-vendored from git@github.com:finegrain-ai/refiners.git @ 20c229903f53d05dc1c44659ec97603660ef964c
+vendored from git@github.com:finegrain-ai/refiners.git @ ce3035923ba71bcb5044708d2f1c37fd1d6722e9