feature: updates refiners vendored library (#458)

* feature: updates refiners vendored library

Includes a small bugfix that will soon be replaced by a proper fix from upstream refiners.

Co-authored-by: Bryce <github20210803@accounts.brycedrennan.com>
jaydrennan committed 4 months ago via GitHub
parent fbb16f6c62
commit 1bf53e47cf

@@ -210,7 +210,7 @@ vendorize_normal_map:
vendorize_refiners:
export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=20c229903f53d05dc1c44659ec97603660ef964c && \
export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=ce3035923ba71bcb5044708d2f1c37fd1d6722e9 && \
make download_repo REPO=$$REPO PKG=$$PKG COMMIT=$$COMMIT && \
mkdir -p ./imaginairy/vendored/$$PKG && \
rm -rf ./imaginairy/vendored/$$PKG/* && \

@@ -1,4 +1,5 @@
from typing import Any, Generic, Iterable, TypeVar
from abc import ABC, abstractmethod
from typing import Any
from torch import Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter
@@ -6,125 +7,259 @@ from torch.nn.init import normal_, zeros_
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.layers.chain import Chain
T = TypeVar("T", bound=fl.Chain)
TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PEP 673)
class Lora(fl.Chain):
class Lora(fl.Chain, ABC):
def __init__(
self,
in_features: int,
out_features: int,
rank: int = 16,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_features = in_features
self.out_features = out_features
self.rank = rank
self.scale: float = 1.0
self._scale = scale
super().__init__(
fl.Linear(in_features=in_features, out_features=rank, bias=False, device=device, dtype=dtype),
fl.Linear(in_features=rank, out_features=out_features, bias=False, device=device, dtype=dtype),
fl.Lambda(func=self.scale_outputs),
)
super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))
normal_(tensor=self.Linear_1.weight, std=1 / self.rank)
zeros_(tensor=self.Linear_2.weight)
normal_(tensor=self.down.weight, std=1 / self.rank)
zeros_(tensor=self.up.weight)
def scale_outputs(self, x: Tensor) -> Tensor:
return x * self.scale
@abstractmethod
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.WeightedModule, fl.WeightedModule]:
...
def set_scale(self, scale: float) -> None:
self.scale = scale
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
@property
def down(self) -> fl.WeightedModule:
down_layer = self[0]
assert isinstance(down_layer, fl.WeightedModule)
return down_layer
@property
def up_weight(self) -> Tensor:
return self.Linear_2.weight.data
def up(self) -> fl.WeightedModule:
up_layer = self[1]
assert isinstance(up_layer, fl.WeightedModule)
return up_layer
@property
def down_weight(self) -> Tensor:
return self.Linear_1.weight.data
def scale(self) -> float:
return self._scale
@scale.setter
def scale(self, value: float) -> None:
self._scale = value
self.ensure_find(fl.Multiply).scale = value
@classmethod
def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "Lora":
match (up.ndim, down.ndim):
case (2, 2):
return LinearLora.from_weights(up=up, down=down)
case (4, 4):
return Conv2dLora.from_weights(up=up, down=down)
case _:
raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
@classmethod
def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]:
"""
Create a dictionary of LoRA layers from a state dict.
Expects the state dict to be a succession of down and up weights.
"""
state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
loras: dict[str, Lora] = {}
for down_key, down_tensor, up_tensor in zip(
list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
):
key = ".".join(down_key.split(".")[:-2])
loras[key] = cls.from_weights(down=down_tensor, up=up_tensor)
return loras
class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
@abstractmethod
def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any:
...
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
assert down_weight.shape == self.down.weight.shape
assert up_weight.shape == self.up.weight.shape
self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
class LinearLora(Lora):
def __init__(
self,
target: fl.Linear,
in_features: int,
out_features: int,
rank: int = 16,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_features = target.in_features
self.out_features = target.out_features
self.rank = rank
self.scale = scale
with self.setup_adapter(target):
super().__init__(
target,
Lora(
in_features=target.in_features,
out_features=target.out_features,
rank=rank,
device=target.device,
dtype=target.dtype,
),
)
self.Lora.set_scale(scale=scale)
class LoraAdapter(Generic[T], fl.Chain, Adapter[T]):
self.in_features = in_features
self.out_features = out_features
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
@classmethod
def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "LinearLora":
assert up.ndim == 2 and down.ndim == 2
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
lora = cls(
in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype
)
lora.load_weights(down_weight=down, up_weight=up)
return lora
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
for layer, parent in target.walk(fl.Linear):
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
continue
if exclude is not None and any(
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
):
continue
if layer.in_features == self.in_features and layer.out_features == self.out_features:
return LoraAdapter(target=layer, lora=self), parent
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.Linear, fl.Linear]:
return (
fl.Linear(
in_features=self.in_features,
out_features=self.rank,
bias=False,
device=device,
dtype=dtype,
),
fl.Linear(
in_features=self.rank,
out_features=self.out_features,
bias=False,
device=device,
dtype=dtype,
),
)
class Conv2dLora(Lora):
def __init__(
self,
target: T,
sub_targets: Iterable[tuple[fl.Linear, fl.Chain]],
rank: int | None = None,
in_channels: int,
out_channels: int,
rank: int = 16,
scale: float = 1.0,
weights: list[Tensor] | None = None,
kernel_size: tuple[int, int] = (1, 3),
stride: tuple[int, int] = (1, 1),
padding: tuple[int, int] = (0, 1),
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
@classmethod
def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "Conv2dLora":
assert up.ndim == 4 and down.ndim == 4
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
down_kernel_size, up_kernel_size = down.shape[2], up.shape[2]
down_padding = 1 if down_kernel_size == 3 else 0
up_padding = 1 if up_kernel_size == 3 else 0
lora = cls(
in_channels=down.shape[1],
out_channels=up.shape[0],
rank=down.shape[0],
kernel_size=(down_kernel_size, up_kernel_size),
padding=(down_padding, up_padding),
device=up.device,
dtype=up.dtype,
)
lora.load_weights(down_weight=down, up_weight=up)
return lora
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
for layer, parent in target.walk(fl.Conv2d):
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
continue
if exclude is not None and any(
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
):
continue
if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels:
if layer.stride != (self.stride[0], self.stride[0]):
self.down.stride = layer.stride
return LoraAdapter(
target=layer,
lora=self,
), parent
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.Conv2d, fl.Conv2d]:
return (
fl.Conv2d(
in_channels=self.in_channels,
out_channels=self.rank,
kernel_size=self.kernel_size[0],
stride=self.stride[0],
padding=self.padding[0],
use_bias=False,
device=device,
dtype=dtype,
),
fl.Conv2d(
in_channels=self.rank,
out_channels=self.out_channels,
kernel_size=self.kernel_size[1],
stride=self.stride[1],
padding=self.padding[1],
use_bias=False,
device=device,
dtype=dtype,
),
)
class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
def __init__(self, target: fl.WeightedModule, lora: Lora) -> None:
with self.setup_adapter(target):
super().__init__(target)
if weights is not None:
assert len(weights) % 2 == 0
weights_rank = weights[0].shape[1]
if rank is None:
rank = weights_rank
else:
assert rank == weights_rank
assert rank is not None, "either pass a rank or weights"
self.sub_targets = sub_targets
self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = []
for linear, parent in self.sub_targets:
self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent))
if weights is not None:
assert len(self.sub_adapters) == (len(weights) // 2)
for i, (adapter, _) in enumerate(self.sub_adapters):
lora = adapter.Lora
assert (
lora.rank == weights[i * 2].shape[1]
), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter:
for adapter, adapter_parent in self.sub_adapters:
adapter.inject(adapter_parent)
return super().inject(parent)
def eject(self) -> None:
for adapter, _ in self.sub_adapters:
adapter.eject()
super().eject()
super().__init__(target, lora)
@property
def weights(self) -> list[Tensor]:
return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]]
def lora(self) -> Lora:
return self.ensure_find(Lora)
@property
def scale(self) -> float:
return self.lora.scale
@scale.setter
def scale(self, value: float) -> None:
self.lora.scale = value
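
A minimal usage sketch of the reworked LoRA API (the 320-wide layers and the bare fl.Chain are hypothetical stand-ins for a real model):

import torch
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora

# hypothetical rank-4 LoRA weights for a 320 -> 320 linear layer
down = torch.randn(4, 320)  # (rank, in_features)
up = torch.zeros(320, 4)    # (out_features, rank)

lora = Lora.from_weights(down=down, up=up)  # 2D weights dispatch to LinearLora
net = fl.Chain(fl.Linear(in_features=320, out_features=320))
if (attach := lora.auto_attach(net)) is not None:
    adapter, parent = attach
    adapter.inject(parent)  # wraps the matched Linear in fl.Sum(target, lora)
    adapter.scale = 0.5     # proxies to Lora.scale, which updates the fl.Multiply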

@@ -164,6 +164,9 @@ class WeightedModule(Module):
def dtype(self) -> DType:
return self.weight.dtype
def __str__(self) -> str:
return f"{super().__str__().removesuffix(')')}, device={self.device}, dtype={str(self.dtype).removeprefix('torch.')})"
class TreeNode(TypedDict):
value: str

@@ -146,6 +146,7 @@ def tensor_to_image(tensor: Tensor) -> Image.Image:
assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}"
num_channels = tensor.shape[1]
tensor = tensor.clamp(0, 1).squeeze(0)
tensor = tensor.to(torch.float32) # to avoid numpy error with bfloat16
match num_channels:
case 1:
@@ -187,20 +188,26 @@ def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata:
def summarize_tensor(tensor: torch.Tensor, /) -> str:
return (
"Tensor("
+ ", ".join(
info_list = [
f"shape=({', '.join(map(str, tensor.shape))})",
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
f"device={tensor.device}",
]
if not tensor.is_complex():
info_list.extend(
[
f"shape=({', '.join(map(str, tensor.shape))})",
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
f"device={tensor.device}",
f"min={tensor.min():.2f}", # type: ignore
f"max={tensor.max():.2f}", # type: ignore
f"mean={tensor.mean():.2f}",
f"std={tensor.std():.2f}",
f"norm={norm(x=tensor):.2f}",
f"grad={tensor.requires_grad}",
]
)
+ ")"
info_list.extend(
[
f"mean={tensor.float().mean():.2f}",
f"std={tensor.float().std():.2f}",
f"norm={norm(x=tensor.float()):.2f}",
f"grad={tensor.requires_grad}",
]
)
return "Tensor(" + ", ".join(info_list) + ")"

@@ -1,15 +1,12 @@
import math
from enum import IntEnum
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from jaxtyping import Float
from PIL import Image
from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.layers.attentions import ScaledDotProductAttention
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, normalize
@@ -236,120 +233,99 @@ class PerceiverResampler(fl.Chain):
return {"perceiver_resampler": {"x": None}}
class _CrossAttnIndex(IntEnum):
TXT_CROSS_ATTN = 0 # text cross-attention
IMG_CROSS_ATTN = 1 # image cross-attention
class ImageCrossAttention(fl.Chain):
def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
self._scale = scale
super().__init__(
fl.Distribute(
fl.Identity(),
fl.Chain(
fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
fl.Linear(
in_features=text_cross_attention.key_embedding_dim,
out_features=text_cross_attention.inner_dim,
bias=text_cross_attention.use_bias,
device=text_cross_attention.device,
dtype=text_cross_attention.dtype,
),
),
fl.Chain(
fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
fl.Linear(
in_features=text_cross_attention.value_embedding_dim,
out_features=text_cross_attention.inner_dim,
bias=text_cross_attention.use_bias,
device=text_cross_attention.device,
dtype=text_cross_attention.dtype,
),
),
),
ScaledDotProductAttention(
num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal
),
fl.Multiply(self.scale),
)
@property
def scale(self) -> float:
return self._scale
class InjectionPoint(fl.Chain):
pass
@scale.setter
def scale(self, value: float) -> None:
self._scale = value
self.ensure_find(fl.Multiply).scale = value
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
def __init__(
self,
target: fl.Attention,
text_sequence_length: int = 77,
image_sequence_length: int = 4,
scale: float = 1.0,
) -> None:
self.text_sequence_length = text_sequence_length
self.image_sequence_length = image_sequence_length
self.scale = scale
self._scale = scale
with self.setup_adapter(target):
super().__init__(
fl.Distribute(
# Note: the same query is used for image cross-attention as for text cross-attention
InjectionPoint(), # Wq
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wk
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
bias=self.target.use_bias,
device=target.device,
dtype=target.dtype,
), # Wk'
),
),
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wv
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
bias=self.target.use_bias,
device=target.device,
dtype=target.dtype,
), # Wv'
),
),
),
fl.Sum(
fl.Chain(
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
),
fl.Chain(
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
fl.Lambda(func=self.scale_outputs),
),
clone = target.structural_copy()
scaled_dot_product = clone.ensure_find(ScaledDotProductAttention)
image_cross_attention = ImageCrossAttention(
text_cross_attention=clone,
scale=self.scale,
)
clone.replace(
old_module=scaled_dot_product,
new_module=fl.Sum(
scaled_dot_product,
image_cross_attention,
),
InjectionPoint(), # proj
)
super().__init__(
clone,
)
def select_qkv(
self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
) -> tuple[Tensor, Tensor, Tensor]:
return (query, keys[index.value], values[index.value])
def scale_outputs(self, x: Tensor) -> Tensor:
return x * self.scale
def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
def f(m: fl.Module, _: fl.Chain) -> bool:
if isinstance(m, Lora): # do not adapt LoRAs
raise StopIteration
return isinstance(m, k)
return f
def _target_linears(self) -> list[fl.Linear]:
return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)]
def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
linears = self._target_linears()
assert len(linears) == 4 # Wq, Wk, Wv and Proj
injection_points = list(self.layers(InjectionPoint))
assert len(injection_points) == 4
@property
def image_cross_attention(self) -> ImageCrossAttention:
return self.ensure_find(ImageCrossAttention)
for linear, ip in zip(linears, injection_points):
ip.append(linear)
assert len(ip) == 1
@property
def image_key_projection(self) -> fl.Linear:
return self.image_cross_attention.Distribute[1].Linear
return super().inject(parent)
@property
def image_value_projection(self) -> fl.Linear:
return self.image_cross_attention.Distribute[2].Linear
def eject(self) -> None:
injection_points = list(self.layers(InjectionPoint))
assert len(injection_points) == 4
@property
def scale(self) -> float:
return self._scale
for ip in injection_points:
ip.pop()
assert len(ip) == 0
@scale.setter
def scale(self, value: float) -> None:
self._scale = value
self.image_cross_attention.scale = value
super().eject()
def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
self.image_key_projection.weight = nn.Parameter(key_tensor)
self.image_value_projection.weight = nn.Parameter(value_tensor)
self.image_cross_attention.to(self.device, self.dtype)
class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
@@ -377,7 +353,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
self._image_proj = [image_proj]
self.sub_adapters = [
CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens)
CrossAttentionAdapter(target=cross_attn, scale=scale)
for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
]
@@ -388,14 +364,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
self.image_proj.load_state_dict(image_proj_state_dict)
for i, cross_attn in enumerate(self.sub_adapters):
cross_attn_state_dict: dict[str, Tensor] = {}
cross_attention_weights: list[Tensor] = []
for k, v in weights.items():
prefix = f"ip_adapter.{i:03d}."
if not k.startswith(prefix):
continue
cross_attn_state_dict[k.removeprefix(prefix)] = v
cross_attention_weights.append(v)
cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
assert len(cross_attention_weights) == 2
cross_attn.load_weights(*cross_attention_weights)
@property
def clip_image_encoder(self) -> CLIPImageEncoderH:
@@ -420,10 +397,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
adapter.eject()
super().eject()
@property
def scale(self) -> float:
return self.sub_adapters[0].scale
@scale.setter
def scale(self, value: float) -> None:
for cross_attn in self.sub_adapters:
cross_attn.scale = value
def set_scale(self, scale: float) -> None:
for cross_attn in self.sub_adapters:
cross_attn.scale = scale
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
# These should be concatenated to the CLIP text embedding before setting the UNet context
def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
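
A small sketch of the new scale plumbing (ip_adapter is a hypothetical, already-injected IPAdapter; both spellings below end up updating the fl.Multiply inside each ImageCrossAttention):

ip_adapter.scale = 0.6     # property setter added here, fanned out to every CrossAttentionAdapter
ip_adapter.set_scale(0.6)  # equivalent method kept alongside the property
assert all(cross_attn.scale == 0.6 for cross_attn in ip_adapter.sub_adapters)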

@@ -1,146 +1,140 @@
from enum import Enum
from pathlib import Path
from typing import Callable, Iterator
from warnings import warn
from torch import Tensor
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora, LoraAdapter
from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
CLIPTextEncoderL,
LatentDiffusionAutoencoder,
SD1UNet,
StableDiffusion_1,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
MODELS = ["unet", "text_encoder", "lda"]
class LoraTarget(str, Enum):
Self = "self"
Attention = "Attention"
SelfAttention = "SelfAttention"
CrossAttention = "CrossAttentionBlock2d"
FeedForward = "FeedForward"
TransformerLayer = "TransformerLayer"
def get_class(self) -> type[fl.Chain]:
match self:
case LoraTarget.Self:
return fl.Chain
case LoraTarget.Attention:
return fl.Attention
case LoraTarget.SelfAttention:
return fl.SelfAttention
case LoraTarget.CrossAttention:
return CrossAttentionBlock2d
case LoraTarget.FeedForward:
return FeedForward
case LoraTarget.TransformerLayer:
return TransformerLayer
def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
def f(m: fl.Module, _: fl.Chain) -> bool:
if isinstance(m, Lora): # do not adapt other LoRAs
raise StopIteration
if isinstance(m, Controlnet): # do not adapt Controlnet linears
raise StopIteration
return isinstance(m, k)
return f
def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]:
for m, p in module.walk(_predicate(fl.Linear)):
assert isinstance(m, fl.Linear)
yield (m, p)
def lora_targets(
module: fl.Chain,
target: LoraTarget | list[LoraTarget],
) -> Iterator[tuple[fl.Linear, fl.Chain]]:
if isinstance(target, list):
for t in target:
yield from lora_targets(module, t)
return
if target == LoraTarget.Self:
yield from _iter_linears(module)
return
for layer, _ in module.walk(_predicate(target.get_class())):
assert isinstance(layer, fl.Chain)
yield from _iter_linears(layer)
class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
metadata: dict[str, str] | None
tensors: dict[str, Tensor]
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
class SDLoraManager:
def __init__(
self,
target: StableDiffusion_1,
sub_targets: dict[str, list[LoraTarget]],
scale: float = 1.0,
weights: dict[str, Tensor] | None = None,
):
with self.setup_adapter(target):
super().__init__(target)
self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = []
for model_name in MODELS:
if not (model_targets := sub_targets.get(model_name, [])):
continue
model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)
lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
self.sub_adapters.append(
LoraAdapter[type(model)](
model,
sub_targets=lora_targets(model, model_targets),
scale=scale,
weights=lora_weights,
)
)
@classmethod
def from_safetensors(
cls,
target: StableDiffusion_1,
checkpoint_path: Path | str,
target: LatentDiffusionModel,
) -> None:
self.target = target
@property
def unet(self) -> fl.Chain:
unet = self.target.unet
assert isinstance(unet, fl.Chain)
return unet
@property
def clip_text_encoder(self) -> fl.Chain:
clip_text_encoder = self.target.clip_text_encoder
assert isinstance(clip_text_encoder, fl.Chain)
return clip_text_encoder
def load(
self,
tensors: dict[str, Tensor],
/,
scale: float = 1.0,
):
metadata = load_metadata_from_safetensors(checkpoint_path)
assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
tensors = load_from_safetensors(checkpoint_path, device=target.device)
sub_targets: dict[str, list[LoraTarget]] = {}
for model_name in MODELS:
if not (v := metadata.get(f"{model_name}_targets", "")):
continue
sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")]
return cls(
target,
sub_targets,
scale=scale,
weights=tensors,
) -> None:
"""Load the LoRA weights from a dictionary of tensors.
Expects the keys to be in the commonly found formats on CivitAI's hub.
"""
assert len(self.lora_adapters) == 0, "Loras already loaded"
loras = Lora.from_dict(
{key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
)
def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter":
for adapter in self.sub_adapters:
adapter.inject()
return super().inject(parent)
def eject(self) -> None:
for adapter in self.sub_adapters:
adapter.eject()
super().eject()
loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}
# if no key contains "unet" or "text", assume all keys are for the unet
if not "unet" in loras and not "text" in loras:
loras = {f"unet_{key}": loras[key] for key in loras.keys()}
self.load_unet(loras)
self.load_text_encoder(loras)
self.scale = scale
def load_text_encoder(self, loras: dict[str, Lora], /) -> None:
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
def load_unet(self, loras: dict[str, Lora], /) -> None:
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
exclude: list[str] = []
exclude = [
self.unet_exclusions[exclusion]
for exclusion in self.unet_exclusions
if all([exclusion not in key for key in unet_loras.keys()])
]
SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude)
def unload(self) -> None:
for lora_adapter in self.lora_adapters:
lora_adapter.eject()
@property
def loras(self) -> list[Lora]:
return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
@property
def lora_adapters(self) -> list[LoraAdapter]:
return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
@property
def unet_exclusions(self) -> dict[str, str]:
return {
"time": "TimestepEncoder",
"res": "ResidualBlock",
"downsample": "DownsampleBlock",
"upsample": "UpsampleBlock",
}
@property
def scale(self) -> float:
assert len(self.loras) > 0, "No loras found"
assert all([lora.scale == self.loras[0].scale for lora in self.loras])
return self.loras[0].scale
@scale.setter
def scale(self, value: float) -> None:
for lora in self.loras:
lora.scale = value
@staticmethod
def pad(input: str, /, padding_length: int = 2) -> str:
new_split: list[str] = []
for s in input.split("_"):
if s.isdigit():
new_split.append(s.zfill(padding_length))
else:
new_split.append(s)
return "_".join(new_split)
@staticmethod
def sort_keys(key: str, /) -> tuple[str, int]:
# out0 happens sometimes as an alias for out ; this dict might not be exhaustive
key_char_order = {"q": 1, "k": 2, "v": 3, "out": 4, "out0": 4}
for i, s in enumerate(key.split("_")):
if s in key_char_order:
prefix = SDLoraManager.pad("_".join(key.split("_")[:i]))
return (prefix, key_char_order[s])
return (SDLoraManager.pad(key), 5)
@staticmethod
def auto_attach(
loras: dict[str, Lora],
target: fl.Chain,
/,
exclude: list[str] | None = None,
) -> None:
failed_loras: dict[str, Lora] = {}
for key, lora in loras.items():
if attach := lora.auto_attach(target, exclude=exclude):
adapter, parent = attach
adapter.inject(parent)
else:
failed_loras[key] = lora
if failed_loras:
warn(f"failed to attach {len(failed_loras)}/{len(loras)} loras to {target.__class__.__name__}")
# TODO: add a stronger sanity check to make sure loras are attached correctly
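
A minimal sketch of the manager-based flow that replaces SD1LoraAdapter (sd stands in for a constructed StableDiffusion_1 or other LatentDiffusionModel, and the checkpoint filename is hypothetical; load_from_safetensors is the helper already imported above):

tensors = load_from_safetensors("pixel_art_lora.safetensors")  # hypothetical file
manager = SDLoraManager(sd)
manager.load(tensors, scale=0.8)  # builds Loras from the state dict, sorts keys, auto-attaches
manager.scale = 0.5               # rescales every attached Lora at once
manager.unload()                  # ejects every LoraAdapter again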

@@ -11,7 +11,6 @@ from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.sche
T = TypeVar("T", bound="fl.Module")
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
@@ -91,6 +90,8 @@ class LatentDiffusionModel(fl.Module, ABC):
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
# scale latents for schedulers that need it
latents = self.scheduler.scale_model_input(latents, step=step)
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
# classifier-free guidance

@@ -1,11 +1,7 @@
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
__all__ = [
"Scheduler",
"DPMSolver",
"DDPM",
"DDIM",
]
__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]

@@ -1,4 +1,4 @@
from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
@@ -34,7 +34,7 @@ class DDIM(Scheduler):
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
timestep, previous_timestep = (
self.timesteps[step],
(
@@ -52,6 +52,12 @@ class DDIM(Scheduler):
),
)
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
noise_factor = sqrt(1 - previous_scale_factor**2)
# Do not add noise at the last step to avoid visual artifacts.
if step == self.num_inference_steps - 1:
noise_factor = 0
denoised_x = previous_scale_factor * predicted_x + noise_factor * noise
return denoised_x

@@ -1,4 +1,4 @@
from torch import Tensor, arange, device as Device
from torch import Generator, Tensor, arange, device as Device
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
@@ -30,5 +30,5 @@ class DDPM(Scheduler):
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
raise NotImplementedError

@@ -1,15 +1,19 @@
from collections import deque
import numpy as np
from torch import Tensor, device as Device, dtype as Dtype, exp, float32, tensor
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
class DPMSolver(Scheduler):
"""Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
"""
Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
We only support noise prediction for now.
Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts
when used with SDXL and few steps. This parameter is a way to mitigate that
effect by using a first-order (Euler) update instead of a second-order update
for the last step of the diffusion.
"""
def __init__(
@@ -18,6 +22,7 @@ class DPMSolver(Scheduler):
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
last_step_first_order: bool = False,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: Dtype = float32,
@@ -32,7 +37,8 @@ class DPMSolver(Scheduler):
dtype=dtype,
)
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
self.initial_steps = 0
self.last_step_first_order = last_step_first_order
self._first_step_has_been_run = False
def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because:
@@ -45,57 +51,48 @@ class DPMSolver(Scheduler):
).flip(0)
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
timestep, previous_timestep = (
self.timesteps[step],
self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0],
)
previous_ratio, current_ratio = (
self.signal_to_noise_ratios[previous_timestep],
self.signal_to_noise_ratios[timestep],
)
current_timestep = self.timesteps[step]
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
current_ratio = self.signal_to_noise_ratios[current_timestep]
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_noise_std, current_noise_std = (
self.noise_std[previous_timestep],
self.noise_std[timestep],
)
previous_noise_std = self.noise_std[previous_timestep]
current_noise_std = self.noise_std[current_timestep]
factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
return denoised_x
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
previous_timestep, current_timestep, next_timestep = (
self.timesteps[step + 1] if step < len(self.timesteps) - 1 else tensor([0]),
self.timesteps[step],
self.timesteps[step - 1],
)
current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2]
previous_ratio, current_ratio, next_ratio = (
self.signal_to_noise_ratios[previous_timestep],
self.signal_to_noise_ratios[current_timestep],
self.signal_to_noise_ratios[next_timestep],
)
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
current_timestep = self.timesteps[step]
next_timestep = self.timesteps[step - 1]
current_data_estimation = self.estimated_data[-1]
next_data_estimation = self.estimated_data[-2]
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
current_ratio = self.signal_to_noise_ratios[current_timestep]
next_ratio = self.signal_to_noise_ratios[next_timestep]
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_std, current_std = (
self.noise_std[previous_timestep],
self.noise_std[current_timestep],
)
previous_noise_std = self.noise_std[previous_timestep]
current_noise_std = self.noise_std[current_timestep]
estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
)
factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (
(previous_std / current_std) * x
(previous_noise_std / current_noise_std) * x
- (factor * previous_scale_factor) * current_data_estimation
- 0.5 * (factor * previous_scale_factor) * estimation_delta
)
return denoised_x
def __call__(
self,
x: Tensor,
noise: Tensor,
step: int,
) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
@@ -107,11 +104,9 @@ class DPMSolver(Scheduler):
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
self.estimated_data.append(estimated_denoised_data)
denoised_x = (
self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
if (self.initial_steps == 0)
else self.multistep_dpm_solver_second_order_update(x=x, step=step)
)
if self.initial_steps < 2:
self.initial_steps += 1
return denoised_x
if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1) or not self._first_step_has_been_run:
self._first_step_has_been_run = True
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
return self.multistep_dpm_solver_second_order_update(x=x, step=step)
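
In practice the rewritten dispatch behaves as sketched below; the _first_step_has_been_run guard appears to be the local bugfix the commit message refers to, presumably so that runs starting at a step other than 0 still take a first-order update before any second-order one (import path as seen in the schedulers __init__ above):

from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver

solver = DPMSolver(num_inference_steps=30, last_step_first_order=True)
for step in solver.steps:
    # step 0 takes the first-order branch; steps 1..28 the second-order
    # multistep branch; step 29 drops back to first order because
    # last_step_first_order is set
    ...  # x = solver(x, noise=predicted_noise, step=step)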

@@ -0,0 +1,84 @@
import numpy as np
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
class EulerScheduler(Scheduler):
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: Dtype = float32,
):
if noise_schedule != NoiseSchedule.QUADRATIC:
raise NotImplementedError
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
device=device,
dtype=dtype,
)
self.sigmas = self._generate_sigmas()
@property
def init_noise_sigma(self) -> Tensor:
return self.sigmas.max()
def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because:
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# ...and we want the same result as the original codebase.
timesteps = torch.tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
).flip(0)
return timesteps
def _generate_sigmas(self) -> Tensor:
sigmas = self.noise_std / self.cumulative_scale_factors
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
sigmas = torch.cat([sigmas, tensor([0.0])])
return sigmas.to(device=self.device, dtype=self.dtype)
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
sigma = self.sigmas[step]
return x / ((sigma**2 + 1) ** 0.5)
def __call__(
self,
x: Tensor,
noise: Tensor,
step: int,
generator: Generator | None = None,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
) -> Tensor:
sigma = self.sigmas[step]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype)
eps = alt_noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
predicted_x = x - sigma_hat * noise
# 1st order Euler
derivative = (x - predicted_x) / sigma_hat
dt = self.sigmas[step + 1] - sigma_hat
denoised_x = x + derivative * dt
return denoised_x
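
A short usage sketch for the new scheduler, showing where scale_model_input (added to the Scheduler base class below and called from LatentDiffusionModel above) fits; the random tensors stand in for real latents and a real UNet prediction:

import torch
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler

scheduler = EulerScheduler(num_inference_steps=30)
x = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma  # Euler latents start sigma-scaled
generator = torch.Generator().manual_seed(42)
for step in scheduler.steps:
    scaled = scheduler.scale_model_input(x, step=step)  # fed to the UNet in real use
    noise = torch.randn_like(x)  # stand-in for the UNet's noise prediction
    x = scheduler(x, noise=noise, step=step, generator=generator)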

@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import TypeVar
from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
T = TypeVar("T", bound="Scheduler")
@@ -50,7 +50,7 @@ class Scheduler(ABC):
self.timesteps = self._generate_timesteps()
@abstractmethod
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.
@@ -71,6 +71,12 @@ class Scheduler(ABC):
def steps(self) -> list[int]:
return list(range(self.num_inference_steps))
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
"""
For compatibility with schedulers that need to scale the input according to the current timestep.
"""
return x
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
return (
linspace(

@@ -89,10 +89,18 @@ class StableDiffusion_1(LatentDiffusionModel):
classifier_free_guidance=True,
)
negative_embedding, _ = clip_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
negative_embedding, _ = clip_text_embedding.chunk(2)
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
degraded_noise = self.unet(degraded_latents)
if "ip_adapter" in self.unet.provider.contexts:
# this implementation is a bit hacky, it should be refactored in the future
ip_adapter_context = self.unet.use_context("ip_adapter")
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
degraded_noise = self.unet(degraded_latents)
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
else:
degraded_noise = self.unet(degraded_latents)
return sag.scale * (noise - degraded_noise)
@@ -160,14 +168,23 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
step=step,
classifier_free_guidance=True,
)
negative_embedding, _ = clip_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
x = torch.cat(
tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
dim=1,
)
degraded_noise = self.unet(x)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
negative_embedding, _ = clip_text_embedding.chunk(2)
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
if "ip_adapter" in self.unet.provider.contexts:
# this implementation is a bit hacky, it should be refactored in the future
ip_adapter_context = self.unet.use_context("ip_adapter")
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
degraded_noise = self.unet(x)
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
else:
degraded_noise = self.unet(x)
return sag.scale * (noise - degraded_noise)

@@ -138,17 +138,25 @@ class StableDiffusion_XL(LatentDiffusionModel):
classifier_free_guidance=True,
)
negative_embedding, _ = clip_text_embedding.chunk(2)
negative_text_embedding, _ = clip_text_embedding.chunk(2)
negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
time_ids, _ = time_ids.chunk(2)
self.set_unet_context(
timestep=timestep,
clip_text_embedding=negative_embedding,
clip_text_embedding=negative_text_embedding,
pooled_text_embedding=negative_pooled_embedding,
time_ids=time_ids,
**kwargs,
)
degraded_noise = self.unet(degraded_latents)
if "ip_adapter" in self.unet.provider.contexts:
# this implementation is a bit hacky, it should be refactored in the future
ip_adapter_context = self.unet.use_context("ip_adapter")
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
degraded_noise = self.unet(degraded_latents)
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
else:
degraded_noise = self.unet(degraded_latents)
return sag.scale * (noise - degraded_noise)

@@ -3,11 +3,12 @@ from typing import Sequence
import numpy as np
import torch
from jaxtyping import Float
from PIL import Image
from torch import Tensor, device as Device, dtype as DType
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad
from imaginairy.vendored.refiners.fluxion.utils import interpolate, no_grad, normalize, pad
from imaginairy.vendored.refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
from imaginairy.vendored.refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from imaginairy.vendored.refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
@@ -55,7 +56,7 @@ class SegmentAnything(fl.Module):
foreground_points: Sequence[tuple[float, float]] | None = None,
background_points: Sequence[tuple[float, float]] | None = None,
box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
masks: Sequence[Image.Image] | None = None,
low_res_mask: Float[Tensor, "1 1 256 256"] | None = None,
binarize: bool = True,
) -> tuple[Tensor, Tensor, Tensor]:
if isinstance(input, ImageEmbedding):
@@ -74,15 +75,13 @@ class SegmentAnything(fl.Module):
)
self.point_encoder.set_type_mask(type_mask=type_mask)
if masks is not None:
mask_tensor = torch.stack(
tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks]
)
mask_embedding = self.mask_encoder(mask_tensor)
if low_res_mask is not None:
mask_embedding = self.mask_encoder(low_res_mask)
else:
mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(
image_embedding_size=self.image_encoder.image_embedding_size
)
point_embedding = self.point_encoder(
self.normalize(coordinates, target_size=target_size, original_size=original_size)
)

@@ -1 +1 @@
vendored from git@github.com:finegrain-ai/refiners.git @ 20c229903f53d05dc1c44659ec97603660ef964c
vendored from git@github.com:finegrain-ai/refiners.git @ ce3035923ba71bcb5044708d2f1c37fd1d6722e9
