mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
feature: updates refiners vendored library (#458)
* feature: updates refiners vendored library. The vendored copy includes a small bugfix that will soon be replaced by a better fix from upstream refiners.

Co-authored-by: Bryce <github20210803@accounts.brycedrennan.com>
This commit is contained in:
parent
fbb16f6c62
commit
1bf53e47cf
Makefile (2 changed lines)
@@ -210,7 +210,7 @@ vendorize_normal_map:

vendorize_refiners:
	export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=20c229903f53d05dc1c44659ec97603660ef964c && \
	export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=ce3035923ba71bcb5044708d2f1c37fd1d6722e9 && \
	make download_repo REPO=$$REPO PKG=$$PKG COMMIT=$$COMMIT && \
	mkdir -p ./imaginairy/vendored/$$PKG && \
	rm -rf ./imaginairy/vendored/$$PKG/* && \
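The hunk above only moves the pinned refiners commit forward. Assuming the Makefile targets behave as written here, re-vendoring is then a single invocation:

    make vendorize_refiners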
@ -1,4 +1,5 @@
|
||||
from typing import Any, Generic, Iterable, TypeVar
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
from torch.nn import Parameter as TorchParameter
|
||||
@@ -6,125 +7,259 @@ from torch.nn.init import normal_, zeros_

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter

T = TypeVar("T", bound=fl.Chain)
TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]")  # Self (see PEP 673)
from imaginairy.vendored.refiners.fluxion.layers.chain import Chain


class Lora(fl.Chain):
class Lora(fl.Chain, ABC):
    def __init__(
        self,
        rank: int = 16,
        scale: float = 1.0,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.rank = rank
        self._scale = scale

        super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))

        normal_(tensor=self.down.weight, std=1 / self.rank)
        zeros_(tensor=self.up.weight)

    @abstractmethod
    def lora_layers(
        self, device: Device | str | None = None, dtype: DType | None = None
    ) -> tuple[fl.WeightedModule, fl.WeightedModule]:
        ...

    @property
    def down(self) -> fl.WeightedModule:
        down_layer = self[0]
        assert isinstance(down_layer, fl.WeightedModule)
        return down_layer

    @property
    def up(self) -> fl.WeightedModule:
        up_layer = self[1]
        assert isinstance(up_layer, fl.WeightedModule)
        return up_layer

    @property
    def scale(self) -> float:
        return self._scale

    @scale.setter
    def scale(self, value: float) -> None:
        self._scale = value
        self.ensure_find(fl.Multiply).scale = value

    @classmethod
    def from_weights(
        cls,
        down: Tensor,
        up: Tensor,
    ) -> "Lora":
        match (up.ndim, down.ndim):
            case (2, 2):
                return LinearLora.from_weights(up=up, down=down)
            case (4, 4):
                return Conv2dLora.from_weights(up=up, down=down)
            case _:
                raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")

    @classmethod
    def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]:
        """
        Create a dictionary of LoRA layers from a state dict.

        Expects the state dict to be a succession of down and up weights.
        """
        state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
        loras: dict[str, Lora] = {}
        for down_key, down_tensor, up_tensor in zip(
            list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
        ):
            key = ".".join(down_key.split(".")[:-2])
            loras[key] = cls.from_weights(down=down_tensor, up=up_tensor)
        return loras
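    # Illustration (not part of this commit): from_dict pairs consecutive
    # ".weight" entries as (down, up) and keys each Lora by the stripped prefix.
    # With a hypothetical two-entry state dict:
    #   Lora.from_dict({"attn.q.lora_down.weight": d, "attn.q.lora_up.weight": u})
    #   # -> {"attn.q": LinearLora(...)} when d and u are 2-D tensors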
    @abstractmethod
    def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any:
        ...

    def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
        assert down_weight.shape == self.down.weight.shape
        assert up_weight.shape == self.up.weight.shape
        self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
        self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))


class LinearLora(Lora):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 16,
        scale: float = 1.0,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.scale: float = 1.0

        super().__init__(
            fl.Linear(in_features=in_features, out_features=rank, bias=False, device=device, dtype=dtype),
            fl.Linear(in_features=rank, out_features=out_features, bias=False, device=device, dtype=dtype),
            fl.Lambda(func=self.scale_outputs),
        super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)

    @classmethod
    def from_weights(
        cls,
        down: Tensor,
        up: Tensor,
    ) -> "LinearLora":
        assert up.ndim == 2 and down.ndim == 2
        assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
        lora = cls(
            in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype
        )
        lora.load_weights(down_weight=down, up_weight=up)
        return lora

    def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
        for layer, parent in target.walk(fl.Linear):
            if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
                continue

            if exclude is not None and any(
                [any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
            ):
                continue

            if layer.in_features == self.in_features and layer.out_features == self.out_features:
                return LoraAdapter(target=layer, lora=self), parent

    def lora_layers(
        self, device: Device | str | None = None, dtype: DType | None = None
    ) -> tuple[fl.Linear, fl.Linear]:
        return (
            fl.Linear(
                in_features=self.in_features,
                out_features=self.rank,
                bias=False,
                device=device,
                dtype=dtype,
            ),
            fl.Linear(
                in_features=self.rank,
                out_features=self.out_features,
                bias=False,
                device=device,
                dtype=dtype,
            ),
        )

        normal_(tensor=self.Linear_1.weight, std=1 / self.rank)
        zeros_(tensor=self.Linear_2.weight)

    def scale_outputs(self, x: Tensor) -> Tensor:
        return x * self.scale

    def set_scale(self, scale: float) -> None:
        self.scale = scale

    def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
        self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
        self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))

    @property
    def up_weight(self) -> Tensor:
        return self.Linear_2.weight.data

    @property
    def down_weight(self) -> Tensor:
        return self.Linear_1.weight.data


class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
class Conv2dLora(Lora):
    def __init__(
        self,
        target: fl.Linear,
        in_channels: int,
        out_channels: int,
        rank: int = 16,
        scale: float = 1.0,
        kernel_size: tuple[int, int] = (1, 3),
        stride: tuple[int, int] = (1, 1),
        padding: tuple[int, int] = (0, 1),
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.in_features = target.in_features
        self.out_features = target.out_features
        self.rank = rank
        self.scale = scale
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)

    @classmethod
    def from_weights(
        cls,
        down: Tensor,
        up: Tensor,
    ) -> "Conv2dLora":
        assert up.ndim == 4 and down.ndim == 4
        assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
        down_kernel_size, up_kernel_size = down.shape[2], up.shape[2]
        down_padding = 1 if down_kernel_size == 3 else 0
        up_padding = 1 if up_kernel_size == 3 else 0
        lora = cls(
            in_channels=down.shape[1],
            out_channels=up.shape[0],
            rank=down.shape[0],
            kernel_size=(down_kernel_size, up_kernel_size),
            padding=(down_padding, up_padding),
            device=up.device,
            dtype=up.dtype,
        )
        lora.load_weights(down_weight=down, up_weight=up)
        return lora

    def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
        for layer, parent in target.walk(fl.Conv2d):
            if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
                continue

            if exclude is not None and any(
                [any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
            ):
                continue

            if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels:
                if layer.stride != (self.stride[0], self.stride[0]):
                    self.down.stride = layer.stride

                return LoraAdapter(
                    target=layer,
                    lora=self,
                ), parent

    def lora_layers(
        self, device: Device | str | None = None, dtype: DType | None = None
    ) -> tuple[fl.Conv2d, fl.Conv2d]:
        return (
            fl.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.rank,
                kernel_size=self.kernel_size[0],
                stride=self.stride[0],
                padding=self.padding[0],
                use_bias=False,
                device=device,
                dtype=dtype,
            ),
            fl.Conv2d(
                in_channels=self.rank,
                out_channels=self.out_channels,
                kernel_size=self.kernel_size[1],
                stride=self.stride[1],
                padding=self.padding[1],
                use_bias=False,
                device=device,
                dtype=dtype,
            ),
        )


class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
    def __init__(self, target: fl.WeightedModule, lora: Lora) -> None:
        with self.setup_adapter(target):
            super().__init__(
                target,
                Lora(
                    in_features=target.in_features,
                    out_features=target.out_features,
                    rank=rank,
                    device=target.device,
                    dtype=target.dtype,
                ),
            )
        self.Lora.set_scale(scale=scale)


class LoraAdapter(Generic[T], fl.Chain, Adapter[T]):
    def __init__(
        self,
        target: T,
        sub_targets: Iterable[tuple[fl.Linear, fl.Chain]],
        rank: int | None = None,
        scale: float = 1.0,
        weights: list[Tensor] | None = None,
    ) -> None:
        with self.setup_adapter(target):
            super().__init__(target)

        if weights is not None:
            assert len(weights) % 2 == 0
            weights_rank = weights[0].shape[1]
            if rank is None:
                rank = weights_rank
            else:
                assert rank == weights_rank

        assert rank is not None, "either pass a rank or weights"

        self.sub_targets = sub_targets
        self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = []

        for linear, parent in self.sub_targets:
            self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent))

        if weights is not None:
            assert len(self.sub_adapters) == (len(weights) // 2)
            for i, (adapter, _) in enumerate(self.sub_adapters):
                lora = adapter.Lora
                assert (
                    lora.rank == weights[i * 2].shape[1]
                ), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
                adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])

    def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter:
        for adapter, adapter_parent in self.sub_adapters:
            adapter.inject(adapter_parent)
        return super().inject(parent)

    def eject(self) -> None:
        for adapter, _ in self.sub_adapters:
            adapter.eject()
        super().eject()
            super().__init__(target, lora)

    @property
    def weights(self) -> list[Tensor]:
        return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]]
    def lora(self) -> Lora:
        return self.ensure_find(Lora)

    @property
    def scale(self) -> float:
        return self.lora.scale

    @scale.setter
    def scale(self, value: float) -> None:
        self.lora.scale = value
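Taken together, this replaces the target-list driven adapters with shape-matched auto-attachment. A minimal sketch of the intended flow (illustrative only; sd_unet and the checkpoint path are placeholders):

    from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora
    from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors

    loras = Lora.from_dict(load_from_safetensors("some_lora.safetensors"))  # key -> Lora
    for key, lora in loras.items():
        if attach := lora.auto_attach(sd_unet):  # find a Linear/Conv2d with matching shapes
            adapter, parent = attach
            adapter.inject(parent)  # wraps the matched layer in Sum(target, lora)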
@@ -164,6 +164,9 @@ class WeightedModule(Module):
    def dtype(self) -> DType:
        return self.weight.dtype

    def __str__(self) -> str:
        return f"{super().__str__().removesuffix(')')}, device={self.device}, dtype={str(self.dtype).removeprefix('torch.')})"


class TreeNode(TypedDict):
    value: str
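The added __str__ appends device and dtype to whatever the base Module repr produces, e.g. (illustrative; the base repr format is an assumption):

    str(some_linear)  # -> 'Linear(in_features=320, out_features=320, device=cpu, dtype=float32)'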
@ -146,6 +146,7 @@ def tensor_to_image(tensor: Tensor) -> Image.Image:
|
||||
assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}"
|
||||
num_channels = tensor.shape[1]
|
||||
tensor = tensor.clamp(0, 1).squeeze(0)
|
||||
tensor = tensor.to(torch.float32) # to avoid numpy error with bfloat16
|
||||
|
||||
match num_channels:
|
||||
case 1:
|
||||
@@ -187,20 +188,26 @@ def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata:


def summarize_tensor(tensor: torch.Tensor, /) -> str:
    return (
        "Tensor("
        + ", ".join(
    info_list = [
        f"shape=({', '.join(map(str, tensor.shape))})",
        f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
        f"device={tensor.device}",
    ]
    if not tensor.is_complex():
        info_list.extend(
            [
                f"shape=({', '.join(map(str, tensor.shape))})",
                f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
                f"device={tensor.device}",
                f"min={tensor.min():.2f}",  # type: ignore
                f"max={tensor.max():.2f}",  # type: ignore
                f"mean={tensor.mean():.2f}",
                f"std={tensor.std():.2f}",
                f"norm={norm(x=tensor):.2f}",
                f"grad={tensor.requires_grad}",
            ]
        )
        + ")"

    info_list.extend(
        [
            f"mean={tensor.float().mean():.2f}",
            f"std={tensor.float().std():.2f}",
            f"norm={norm(x=tensor.float()):.2f}",
            f"grad={tensor.requires_grad}",
        ]
    )

    return "Tensor(" + ", ".join(info_list) + ")"
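Illustrative output of the rewritten helper (values for a hypothetical all-zeros float tensor; exact formatting follows the f-strings above):

    summarize_tensor(torch.zeros(1, 3, 8, 8))
    # -> 'Tensor(shape=(1, 3, 8, 8), dtype=float32, device=cpu, min=0.00, max=0.00, mean=0.00, std=0.00, norm=0.00, grad=False)'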
@ -1,15 +1,12 @@
|
||||
import math
|
||||
from enum import IntEnum
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
|
||||
from jaxtyping import Float
|
||||
from PIL import Image
|
||||
from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like
|
||||
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like
|
||||
|
||||
import imaginairy.vendored.refiners.fluxion.layers as fl
|
||||
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
|
||||
from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora
|
||||
from imaginairy.vendored.refiners.fluxion.context import Contexts
|
||||
from imaginairy.vendored.refiners.fluxion.layers.attentions import ScaledDotProductAttention
|
||||
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, normalize
|
||||
@@ -236,120 +233,99 @@ class PerceiverResampler(fl.Chain):
        return {"perceiver_resampler": {"x": None}}


class _CrossAttnIndex(IntEnum):
    TXT_CROSS_ATTN = 0  # text cross-attention
    IMG_CROSS_ATTN = 1  # image cross-attention
class ImageCrossAttention(fl.Chain):
    def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
        self._scale = scale
        super().__init__(
            fl.Distribute(
                fl.Identity(),
                fl.Chain(
                    fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
                    fl.Linear(
                        in_features=text_cross_attention.key_embedding_dim,
                        out_features=text_cross_attention.inner_dim,
                        bias=text_cross_attention.use_bias,
                        device=text_cross_attention.device,
                        dtype=text_cross_attention.dtype,
                    ),
                ),
                fl.Chain(
                    fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
                    fl.Linear(
                        in_features=text_cross_attention.value_embedding_dim,
                        out_features=text_cross_attention.inner_dim,
                        bias=text_cross_attention.use_bias,
                        device=text_cross_attention.device,
                        dtype=text_cross_attention.dtype,
                    ),
                ),
            ),
            ScaledDotProductAttention(
                num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal
            ),
            fl.Multiply(self.scale),
        )

    @property
    def scale(self) -> float:
        return self._scale

class InjectionPoint(fl.Chain):
    pass
    @scale.setter
    def scale(self, value: float) -> None:
        self._scale = value
        self.ensure_find(fl.Multiply).scale = value


class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
    def __init__(
        self,
        target: fl.Attention,
        text_sequence_length: int = 77,
        image_sequence_length: int = 4,
        scale: float = 1.0,
    ) -> None:
        self.text_sequence_length = text_sequence_length
        self.image_sequence_length = image_sequence_length
        self.scale = scale

        self._scale = scale
        with self.setup_adapter(target):
            clone = target.structural_copy()
            scaled_dot_product = clone.ensure_find(ScaledDotProductAttention)
            image_cross_attention = ImageCrossAttention(
                text_cross_attention=clone,
                scale=self.scale,
            )
            clone.replace(
                old_module=scaled_dot_product,
                new_module=fl.Sum(
                    scaled_dot_product,
                    image_cross_attention,
                ),
            )
            super().__init__(
                fl.Distribute(
                    # Note: the same query is used for image cross-attention as for text cross-attention
                    InjectionPoint(),  # Wq
                    fl.Parallel(
                        fl.Chain(
                            fl.Slicing(dim=1, end=text_sequence_length),
                            InjectionPoint(),  # Wk
                        ),
                        fl.Chain(
                            fl.Slicing(dim=1, start=text_sequence_length),
                            fl.Linear(
                                in_features=self.target.key_embedding_dim,
                                out_features=self.target.inner_dim,
                                bias=self.target.use_bias,
                                device=target.device,
                                dtype=target.dtype,
                            ),  # Wk'
                        ),
                    ),
                    fl.Parallel(
                        fl.Chain(
                            fl.Slicing(dim=1, end=text_sequence_length),
                            InjectionPoint(),  # Wv
                        ),
                        fl.Chain(
                            fl.Slicing(dim=1, start=text_sequence_length),
                            fl.Linear(
                                in_features=self.target.key_embedding_dim,
                                out_features=self.target.inner_dim,
                                bias=self.target.use_bias,
                                device=target.device,
                                dtype=target.dtype,
                            ),  # Wv'
                        ),
                    ),
                ),
                fl.Sum(
                    fl.Chain(
                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
                    ),
                    fl.Chain(
                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
                        fl.Lambda(func=self.scale_outputs),
                    ),
                ),
                InjectionPoint(),  # proj
                clone,
            )

    def select_qkv(
        self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
    ) -> tuple[Tensor, Tensor, Tensor]:
        return (query, keys[index.value], values[index.value])
    @property
    def image_cross_attention(self) -> ImageCrossAttention:
        return self.ensure_find(ImageCrossAttention)

    def scale_outputs(self, x: Tensor) -> Tensor:
        return x * self.scale
    @property
    def image_key_projection(self) -> fl.Linear:
        return self.image_cross_attention.Distribute[1].Linear

    def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
        def f(m: fl.Module, _: fl.Chain) -> bool:
            if isinstance(m, Lora):  # do not adapt LoRAs
                raise StopIteration
            return isinstance(m, k)
    @property
    def image_value_projection(self) -> fl.Linear:
        return self.image_cross_attention.Distribute[2].Linear

        return f
    @property
    def scale(self) -> float:
        return self._scale

    def _target_linears(self) -> list[fl.Linear]:
        return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)]
    @scale.setter
    def scale(self, value: float) -> None:
        self._scale = value
        self.image_cross_attention.scale = value

    def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
        linears = self._target_linears()
        assert len(linears) == 4  # Wq, Wk, Wv and Proj

        injection_points = list(self.layers(InjectionPoint))
        assert len(injection_points) == 4

        for linear, ip in zip(linears, injection_points):
            ip.append(linear)
            assert len(ip) == 1

        return super().inject(parent)

    def eject(self) -> None:
        injection_points = list(self.layers(InjectionPoint))
        assert len(injection_points) == 4

        for ip in injection_points:
            ip.pop()
            assert len(ip) == 0

        super().eject()
    def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
        self.image_key_projection.weight = nn.Parameter(key_tensor)
        self.image_value_projection.weight = nn.Parameter(value_tensor)
        self.image_cross_attention.to(self.device, self.dtype)


class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
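The net effect of the rewrite, in one line (a sketch of the decoupled cross-attention described in the IP-Adapter paper, not the vendored API):

    # out = attention(q, k_text, v_text) + scale * attention(q, k_image, v_image)
    # k_image / v_image come from the two new Linears fed by the "clip_image_embedding"
    # context; scale is the trailing fl.Multiply inside ImageCrossAttention.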
@@ -377,7 +353,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
        self._image_proj = [image_proj]

        self.sub_adapters = [
            CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens)
            CrossAttentionAdapter(target=cross_attn, scale=scale)
            for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
        ]
@ -388,14 +364,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||
self.image_proj.load_state_dict(image_proj_state_dict)
|
||||
|
||||
for i, cross_attn in enumerate(self.sub_adapters):
|
||||
cross_attn_state_dict: dict[str, Tensor] = {}
|
||||
cross_attention_weights: list[Tensor] = []
|
||||
for k, v in weights.items():
|
||||
prefix = f"ip_adapter.{i:03d}."
|
||||
if not k.startswith(prefix):
|
||||
continue
|
||||
cross_attn_state_dict[k.removeprefix(prefix)] = v
|
||||
cross_attention_weights.append(v)
|
||||
|
||||
cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
|
||||
assert len(cross_attention_weights) == 2
|
||||
cross_attn.load_weights(*cross_attention_weights)
|
||||
|
||||
@property
|
||||
def clip_image_encoder(self) -> CLIPImageEncoderH:
|
||||
@@ -420,10 +397,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
            adapter.eject()
        super().eject()

    @property
    def scale(self) -> float:
        return self.sub_adapters[0].scale

    @scale.setter
    def scale(self, value: float) -> None:
        for cross_attn in self.sub_adapters:
            cross_attn.scale = value

    def set_scale(self, scale: float) -> None:
        for cross_attn in self.sub_adapters:
            cross_attn.scale = scale

    def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
        self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})

    # These should be concatenated to the CLIP text embedding before setting the UNet context
    def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
        image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
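With the new property, image-prompt strength can be adjusted after injection (a sketch; ip_adapter stands for any injected IPAdapter):

    ip_adapter.scale = 0.5  # the setter fans the value out to every CrossAttentionAdapter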
@@ -1,146 +1,140 @@
from enum import Enum
from pathlib import Path
from typing import Callable, Iterator
from warnings import warn

from torch import Tensor

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora, LoraAdapter
from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
    CLIPTextEncoderL,
    LatentDiffusionAutoencoder,
    SD1UNet,
    StableDiffusion_1,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet

MODELS = ["unet", "text_encoder", "lda"]
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import LatentDiffusionModel


class LoraTarget(str, Enum):
    Self = "self"
    Attention = "Attention"
    SelfAttention = "SelfAttention"
    CrossAttention = "CrossAttentionBlock2d"
    FeedForward = "FeedForward"
    TransformerLayer = "TransformerLayer"

    def get_class(self) -> type[fl.Chain]:
        match self:
            case LoraTarget.Self:
                return fl.Chain
            case LoraTarget.Attention:
                return fl.Attention
            case LoraTarget.SelfAttention:
                return fl.SelfAttention
            case LoraTarget.CrossAttention:
                return CrossAttentionBlock2d
            case LoraTarget.FeedForward:
                return FeedForward
            case LoraTarget.TransformerLayer:
                return TransformerLayer


def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
    def f(m: fl.Module, _: fl.Chain) -> bool:
        if isinstance(m, Lora):  # do not adapt other LoRAs
            raise StopIteration
        if isinstance(m, Controlnet):  # do not adapt Controlnet linears
            raise StopIteration
        return isinstance(m, k)

    return f


def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]:
    for m, p in module.walk(_predicate(fl.Linear)):
        assert isinstance(m, fl.Linear)
        yield (m, p)


def lora_targets(
    module: fl.Chain,
    target: LoraTarget | list[LoraTarget],
) -> Iterator[tuple[fl.Linear, fl.Chain]]:
    if isinstance(target, list):
        for t in target:
            yield from lora_targets(module, t)
        return

    if target == LoraTarget.Self:
        yield from _iter_linears(module)
        return

    for layer, _ in module.walk(_predicate(target.get_class())):
        assert isinstance(layer, fl.Chain)
        yield from _iter_linears(layer)


class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
    metadata: dict[str, str] | None
    tensors: dict[str, Tensor]

class SDLoraManager:
    def __init__(
        self,
        target: StableDiffusion_1,
        sub_targets: dict[str, list[LoraTarget]],
        target: LatentDiffusionModel,
    ) -> None:
        self.target = target

    @property
    def unet(self) -> fl.Chain:
        unet = self.target.unet
        assert isinstance(unet, fl.Chain)
        return unet

    @property
    def clip_text_encoder(self) -> fl.Chain:
        clip_text_encoder = self.target.clip_text_encoder
        assert isinstance(clip_text_encoder, fl.Chain)
        return clip_text_encoder

    def load(
        self,
        tensors: dict[str, Tensor],
        /,
        scale: float = 1.0,
        weights: dict[str, Tensor] | None = None,
    ):
        with self.setup_adapter(target):
            super().__init__(target)
    ) -> None:
        """Load the LoRA weights from a dictionary of tensors.

        self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = []

        for model_name in MODELS:
            if not (model_targets := sub_targets.get(model_name, [])):
                continue
            model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)

            lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
            self.sub_adapters.append(
                LoraAdapter[type(model)](
                    model,
                    sub_targets=lora_targets(model, model_targets),
                    scale=scale,
                    weights=lora_weights,
                )
            )

    @classmethod
    def from_safetensors(
        cls,
        target: StableDiffusion_1,
        checkpoint_path: Path | str,
        scale: float = 1.0,
    ):
        metadata = load_metadata_from_safetensors(checkpoint_path)
        assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
        tensors = load_from_safetensors(checkpoint_path, device=target.device)

        sub_targets: dict[str, list[LoraTarget]] = {}
        for model_name in MODELS:
            if not (v := metadata.get(f"{model_name}_targets", "")):
                continue
            sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")]

        return cls(
            target,
            sub_targets,
            scale=scale,
            weights=tensors,
        Expects the keys to be in the commonly found formats on CivitAI's hub.
        """
        assert len(self.lora_adapters) == 0, "Loras already loaded"
        loras = Lora.from_dict(
            {key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
        )
        loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}

    def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter":
        for adapter in self.sub_adapters:
            adapter.inject()
        return super().inject(parent)
        # if no key contains "unet" or "text", assume all keys are for the unet
        if not "unet" in loras and not "text" in loras:
            loras = {f"unet_{key}": loras[key] for key in loras.keys()}

    def eject(self) -> None:
        for adapter in self.sub_adapters:
            adapter.eject()
        super().eject()
        self.load_unet(loras)
        self.load_text_encoder(loras)

        self.scale = scale

    def load_text_encoder(self, loras: dict[str, Lora], /) -> None:
        text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
        SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)

    def load_unet(self, loras: dict[str, Lora], /) -> None:
        unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
        exclude: list[str] = []
        exclude = [
            self.unet_exclusions[exclusion]
            for exclusion in self.unet_exclusions
            if all([exclusion not in key for key in unet_loras.keys()])
        ]
        SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude)

    def unload(self) -> None:
        for lora_adapter in self.lora_adapters:
            lora_adapter.eject()

    @property
    def loras(self) -> list[Lora]:
        return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))

    @property
    def lora_adapters(self) -> list[LoraAdapter]:
        return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))

    @property
    def unet_exclusions(self) -> dict[str, str]:
        return {
            "time": "TimestepEncoder",
            "res": "ResidualBlock",
            "downsample": "DownsampleBlock",
            "upsample": "UpsampleBlock",
        }

    @property
    def scale(self) -> float:
        assert len(self.loras) > 0, "No loras found"
        assert all([lora.scale == self.loras[0].scale for lora in self.loras])
        return self.loras[0].scale

    @scale.setter
    def scale(self, value: float) -> None:
        for lora in self.loras:
            lora.scale = value

    @staticmethod
    def pad(input: str, /, padding_length: int = 2) -> str:
        new_split: list[str] = []
        for s in input.split("_"):
            if s.isdigit():
                new_split.append(s.zfill(padding_length))
            else:
                new_split.append(s)
        return "_".join(new_split)

    @staticmethod
    def sort_keys(key: str, /) -> tuple[str, int]:
        # out0 happens sometimes as an alias for out ; this dict might not be exhaustive
        key_char_order = {"q": 1, "k": 2, "v": 3, "out": 4, "out0": 4}

        for i, s in enumerate(key.split("_")):
            if s in key_char_order:
                prefix = SDLoraManager.pad("_".join(key.split("_")[:i]))
                return (prefix, key_char_order[s])

        return (SDLoraManager.pad(key), 5)
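    # Illustration (not part of this commit): sort_keys zero-pads numeric path
    # segments and orders q < k < v < out within a block, so for example
    #   sorted(["up_2_attn_out", "up_2_attn_q"], key=SDLoraManager.sort_keys)
    #   # -> ["up_2_attn_q", "up_2_attn_out"]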
    @staticmethod
    def auto_attach(
        loras: dict[str, Lora],
        target: fl.Chain,
        /,
        exclude: list[str] | None = None,
    ) -> None:
        failed_loras: dict[str, Lora] = {}
        for key, lora in loras.items():
            if attach := lora.auto_attach(target, exclude=exclude):
                adapter, parent = attach
                adapter.inject(parent)
            else:
                failed_loras[key] = lora

        if failed_loras:
            warn(f"failed to attach {len(failed_loras)}/{len(loras)} loras to {target.__class__.__name__}")

        # TODO: add a stronger sanity check to make sure loras are attached correctly
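End to end, SDLoraManager replaces the old SD1LoraAdapter.from_safetensors flow. A minimal sketch (the checkpoint path is a placeholder; sd is any LatentDiffusionModel):

    from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors

    manager = SDLoraManager(sd)
    manager.load(load_from_safetensors("pixel-art-lora.safetensors"), scale=1.0)
    manager.scale = 0.7  # rescale every attached LoRA later
    manager.unload()     # eject all LoraAdapters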
@@ -11,7 +11,6 @@ from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.sche

T = TypeVar("T", bound="fl.Module")


TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")

@@ -91,6 +90,8 @@ class LatentDiffusionModel(fl.Module, ABC):
        self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)

        latents = torch.cat(tensors=(x, x))  # for classifier-free guidance
        # scale latents for schedulers that need it
        latents = self.scheduler.scale_model_input(latents, step=step)
        unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)

        # classifier-free guidance
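        # (Context, not part of the diff: the guidance step in the unchanged lines
        # that follow combines the two predictions in the standard
        # classifier-free-guidance form,
        #   predicted_noise = uncond + condition_scale * (cond - uncond),
        # and scale_model_input above is a no-op for schedulers that do not
        # override it.)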
@ -1,11 +1,7 @@
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
|
||||
__all__ = [
|
||||
"Scheduler",
|
||||
"DPMSolver",
|
||||
"DDPM",
|
||||
"DDIM",
|
||||
]
|
||||
__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]
|
||||
|
@@ -1,4 +1,4 @@
from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor

from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler

@@ -34,7 +34,7 @@ class DDIM(Scheduler):
        timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
        return timesteps.flip(0)

    def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
    def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
        timestep, previous_timestep = (
            self.timesteps[step],
            (
@@ -52,6 +52,12 @@ class DDIM(Scheduler):
            ),
        )
        predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
        denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
        noise_factor = sqrt(1 - previous_scale_factor**2)

        # Do not add noise at the last step to avoid visual artifacts.
        if step == self.num_inference_steps - 1:
            noise_factor = 0

        denoised_x = previous_scale_factor * predicted_x + noise_factor * noise

        return denoised_x
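Why the DDIM change matters (context, not part of the diff): at the final step the previous cumulative scale factor is close to, but not exactly, 1, so the old update's sqrt(1 - previous_scale_factor**2) * noise term re-injected a trace of noise into the finished image. Zeroing noise_factor on the last step makes the scheduler return the noise-free estimate:

    # last step: denoised_x == previous_scale_factor * predicted_x  (no added noise)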
@ -1,4 +1,4 @@
|
||||
from torch import Tensor, arange, device as Device
|
||||
from torch import Generator, Tensor, arange, device as Device
|
||||
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
|
||||
@ -30,5 +30,5 @@ class DDPM(Scheduler):
|
||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
|
||||
return timesteps.flip(0)
|
||||
|
||||
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
raise NotImplementedError
|
||||
|
@ -1,15 +1,19 @@
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
from torch import Tensor, device as Device, dtype as Dtype, exp, float32, tensor
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
|
||||
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
|
||||
|
||||
|
||||
class DPMSolver(Scheduler):
|
||||
"""Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
|
||||
"""
|
||||
Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
|
||||
|
||||
We only support noise prediction for now.
|
||||
Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts
|
||||
when used with SDXL and few steps. This parameter is a way to mitigate that
|
||||
effect by using a first-order (Euler) update instead of a second-order update
|
||||
for the last step of the diffusion.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -18,6 +22,7 @@ class DPMSolver(Scheduler):
|
||||
num_train_timesteps: int = 1_000,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
last_step_first_order: bool = False,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
@ -32,7 +37,8 @@ class DPMSolver(Scheduler):
|
||||
dtype=dtype,
|
||||
)
|
||||
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
|
||||
self.initial_steps = 0
|
||||
self.last_step_first_order = last_step_first_order
|
||||
self._first_step_has_been_run = False
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
# We need to use numpy here because:
|
||||
@ -45,57 +51,48 @@ class DPMSolver(Scheduler):
|
||||
).flip(0)
|
||||
|
||||
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||
timestep, previous_timestep = (
|
||||
self.timesteps[step],
|
||||
self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0],
|
||||
)
|
||||
previous_ratio, current_ratio = (
|
||||
self.signal_to_noise_ratios[previous_timestep],
|
||||
self.signal_to_noise_ratios[timestep],
|
||||
)
|
||||
current_timestep = self.timesteps[step]
|
||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
||||
|
||||
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
|
||||
current_ratio = self.signal_to_noise_ratios[current_timestep]
|
||||
|
||||
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
|
||||
previous_noise_std, current_noise_std = (
|
||||
self.noise_std[previous_timestep],
|
||||
self.noise_std[timestep],
|
||||
)
|
||||
|
||||
previous_noise_std = self.noise_std[previous_timestep]
|
||||
current_noise_std = self.noise_std[current_timestep]
|
||||
|
||||
factor = exp(-(previous_ratio - current_ratio)) - 1.0
|
||||
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
|
||||
return denoised_x
|
||||
|
||||
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
|
||||
previous_timestep, current_timestep, next_timestep = (
|
||||
self.timesteps[step + 1] if step < len(self.timesteps) - 1 else tensor([0]),
|
||||
self.timesteps[step],
|
||||
self.timesteps[step - 1],
|
||||
)
|
||||
current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2]
|
||||
previous_ratio, current_ratio, next_ratio = (
|
||||
self.signal_to_noise_ratios[previous_timestep],
|
||||
self.signal_to_noise_ratios[current_timestep],
|
||||
self.signal_to_noise_ratios[next_timestep],
|
||||
)
|
||||
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
|
||||
current_timestep = self.timesteps[step]
|
||||
next_timestep = self.timesteps[step - 1]
|
||||
|
||||
current_data_estimation = self.estimated_data[-1]
|
||||
next_data_estimation = self.estimated_data[-2]
|
||||
|
||||
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
|
||||
current_ratio = self.signal_to_noise_ratios[current_timestep]
|
||||
next_ratio = self.signal_to_noise_ratios[next_timestep]
|
||||
|
||||
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
|
||||
previous_std, current_std = (
|
||||
self.noise_std[previous_timestep],
|
||||
self.noise_std[current_timestep],
|
||||
)
|
||||
previous_noise_std = self.noise_std[previous_timestep]
|
||||
current_noise_std = self.noise_std[current_timestep]
|
||||
estimation_delta = (current_data_estimation - next_data_estimation) / (
|
||||
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
|
||||
)
|
||||
factor = exp(-(previous_ratio - current_ratio)) - 1.0
|
||||
denoised_x = (
|
||||
(previous_std / current_std) * x
|
||||
(previous_noise_std / current_noise_std) * x
|
||||
- (factor * previous_scale_factor) * current_data_estimation
|
||||
- 0.5 * (factor * previous_scale_factor) * estimation_delta
|
||||
)
|
||||
return denoised_x
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: Tensor,
|
||||
noise: Tensor,
|
||||
step: int,
|
||||
) -> Tensor:
|
||||
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
"""
|
||||
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
|
||||
|
||||
@ -107,11 +104,9 @@ class DPMSolver(Scheduler):
|
||||
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
|
||||
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
|
||||
self.estimated_data.append(estimated_denoised_data)
|
||||
denoised_x = (
|
||||
self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
|
||||
if (self.initial_steps == 0)
|
||||
else self.multistep_dpm_solver_second_order_update(x=x, step=step)
|
||||
)
|
||||
if self.initial_steps < 2:
|
||||
self.initial_steps += 1
|
||||
return denoised_x
|
||||
|
||||
if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1) or not self._first_step_has_been_run:
|
||||
self._first_step_has_been_run = True
|
||||
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
|
||||
|
||||
return self.multistep_dpm_solver_second_order_update(x=x, step=step)
|
||||
|
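After the refactor, the update-order dispatch (including the vendored _first_step_has_been_run guard, which is the small bugfix mentioned in the commit message) reduces to:

    # first-order (Euler-style) update when:
    #   - this is the first step actually executed (step == 0, or the guard flag
    #     is still unset because sampling did not start at step 0), or
    #   - last_step_first_order is set and this is the final step;
    # otherwise: second-order multistep update from the last two data estimates.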
@@ -0,0 +1,84 @@
import numpy as np
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor

from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler


class EulerScheduler(Scheduler):
    def __init__(
        self,
        num_inference_steps: int,
        num_train_timesteps: int = 1_000,
        initial_diffusion_rate: float = 8.5e-4,
        final_diffusion_rate: float = 1.2e-2,
        noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
        device: Device | str = "cpu",
        dtype: Dtype = float32,
    ):
        if noise_schedule != NoiseSchedule.QUADRATIC:
            raise NotImplementedError
        super().__init__(
            num_inference_steps=num_inference_steps,
            num_train_timesteps=num_train_timesteps,
            initial_diffusion_rate=initial_diffusion_rate,
            final_diffusion_rate=final_diffusion_rate,
            noise_schedule=noise_schedule,
            device=device,
            dtype=dtype,
        )
        self.sigmas = self._generate_sigmas()

    @property
    def init_noise_sigma(self) -> Tensor:
        return self.sigmas.max()

    def _generate_timesteps(self) -> Tensor:
        # We need to use numpy here because:
        # numpy.linspace(0,999,31)[15] is 499.49999999999994
        # torch.linspace(0,999,31)[15] is 499.5
        # ...and we want the same result as the original codebase.
        timesteps = torch.tensor(
            np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
        ).flip(0)
        return timesteps

    def _generate_sigmas(self) -> Tensor:
        sigmas = self.noise_std / self.cumulative_scale_factors
        sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
        sigmas = torch.cat([sigmas, tensor([0.0])])
        return sigmas.to(device=self.device, dtype=self.dtype)

    def scale_model_input(self, x: Tensor, step: int) -> Tensor:
        sigma = self.sigmas[step]
        return x / ((sigma**2 + 1) ** 0.5)

    def __call__(
        self,
        x: Tensor,
        noise: Tensor,
        step: int,
        generator: Generator | None = None,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
    ) -> Tensor:
        sigma = self.sigmas[step]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0

        alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype)
        eps = alt_noise * s_noise
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5

        predicted_x = x - sigma_hat * noise

        # 1st order Euler
        derivative = (x - predicted_x) / sigma_hat
        dt = self.sigmas[step + 1] - sigma_hat
        denoised_x = x + derivative * dt

        return denoised_x
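A minimal usage sketch for the new scheduler (the unet call is a placeholder; in practice LatentDiffusionModel drives this loop and calls scale_model_input itself):

    scheduler = EulerScheduler(num_inference_steps=30)
    x = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
    for step in scheduler.steps:
        noise_pred = unet(scheduler.scale_model_input(x, step=step))  # placeholder UNet
        x = scheduler(x, noise=noise_pred, step=step)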
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import TypeVar

from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt

T = TypeVar("T", bound="Scheduler")

@@ -50,7 +50,7 @@ class Scheduler(ABC):
        self.timesteps = self._generate_timesteps()

    @abstractmethod
    def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
    def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
        """
        Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.

@@ -71,6 +71,12 @@ class Scheduler(ABC):
    def steps(self) -> list[int]:
        return list(range(self.num_inference_steps))

    def scale_model_input(self, x: Tensor, step: int) -> Tensor:
        """
        For compatibility with schedulers that need to scale the input according to the current timestep.
        """
        return x

    def sample_power_distribution(self, power: float = 2, /) -> Tensor:
        return (
            linspace(
@@ -89,10 +89,18 @@ class StableDiffusion_1(LatentDiffusionModel):
            classifier_free_guidance=True,
        )

        negative_embedding, _ = clip_text_embedding.chunk(2)
        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
        negative_embedding, _ = clip_text_embedding.chunk(2)
        self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
        degraded_noise = self.unet(degraded_latents)
        if "ip_adapter" in self.unet.provider.contexts:
            # this implementation is a bit hacky, it should be refactored in the future
            ip_adapter_context = self.unet.use_context("ip_adapter")
            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
            degraded_noise = self.unet(degraded_latents)
            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
        else:
            degraded_noise = self.unet(degraded_latents)

        return sag.scale * (noise - degraded_noise)

@@ -160,14 +168,23 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
            step=step,
            classifier_free_guidance=True,
        )

        negative_embedding, _ = clip_text_embedding.chunk(2)
        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
        self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
        x = torch.cat(
            tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
            dim=1,
        )
        degraded_noise = self.unet(x)

        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
        negative_embedding, _ = clip_text_embedding.chunk(2)
        self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)

        if "ip_adapter" in self.unet.provider.contexts:
            # this implementation is a bit hacky, it should be refactored in the future
            ip_adapter_context = self.unet.use_context("ip_adapter")
            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
            degraded_noise = self.unet(x)
            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
        else:
            degraded_noise = self.unet(x)

        return sag.scale * (noise - degraded_noise)
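Why the embedding is chunked (context, not part of the diff): the self-attention-guidance degraded pass runs the UNet on half the classifier-free-guidance batch, so an active IP adapter's "clip_image_embedding" context must be halved to match and then restored. Shape bookkeeping, with b as an illustrative batch size:

    # stored context: (2*b, tokens, dim) --chunk(2)--> (b, tokens, dim) for the degraded pass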
@@ -138,17 +138,25 @@ class StableDiffusion_XL(LatentDiffusionModel):
            classifier_free_guidance=True,
        )

        negative_embedding, _ = clip_text_embedding.chunk(2)
        negative_text_embedding, _ = clip_text_embedding.chunk(2)
        negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
        time_ids, _ = time_ids.chunk(2)

        self.set_unet_context(
            timestep=timestep,
            clip_text_embedding=negative_embedding,
            clip_text_embedding=negative_text_embedding,
            pooled_text_embedding=negative_pooled_embedding,
            time_ids=time_ids,
            **kwargs,
        )
        degraded_noise = self.unet(degraded_latents)
        if "ip_adapter" in self.unet.provider.contexts:
            # this implementation is a bit hacky, it should be refactored in the future
            ip_adapter_context = self.unet.use_context("ip_adapter")
            image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
            ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
            degraded_noise = self.unet(degraded_latents)
            ip_adapter_context["clip_image_embedding"] = image_embedding_copy
        else:
            degraded_noise = self.unet(degraded_latents)

        return sag.scale * (noise - degraded_noise)
@@ -3,11 +3,12 @@ from typing import Sequence

import numpy as np
import torch
from jaxtyping import Float
from PIL import Image
from torch import Tensor, device as Device, dtype as DType

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad
from imaginairy.vendored.refiners.fluxion.utils import interpolate, no_grad, normalize, pad
from imaginairy.vendored.refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
from imaginairy.vendored.refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from imaginairy.vendored.refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder

@@ -55,7 +56,7 @@ class SegmentAnything(fl.Module):
        foreground_points: Sequence[tuple[float, float]] | None = None,
        background_points: Sequence[tuple[float, float]] | None = None,
        box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
        masks: Sequence[Image.Image] | None = None,
        low_res_mask: Float[Tensor, "1 1 256 256"] | None = None,
        binarize: bool = True,
    ) -> tuple[Tensor, Tensor, Tensor]:
        if isinstance(input, ImageEmbedding):

@@ -74,15 +75,13 @@ class SegmentAnything(fl.Module):
        )
        self.point_encoder.set_type_mask(type_mask=type_mask)

        if masks is not None:
            mask_tensor = torch.stack(
                tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks]
            )
            mask_embedding = self.mask_encoder(mask_tensor)
        if low_res_mask is not None:
            mask_embedding = self.mask_encoder(low_res_mask)
        else:
            mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(
                image_embedding_size=self.image_encoder.image_embedding_size
            )

        point_embedding = self.point_encoder(
            self.normalize(coordinates, target_size=target_size, original_size=original_size)
        )
@@ -1 +1 @@
vendored from git@github.com:finegrain-ai/refiners.git @ 20c229903f53d05dc1c44659ec97603660ef964c
vendored from git@github.com:finegrain-ai/refiners.git @ ce3035923ba71bcb5044708d2f1c37fd1d6722e9