feature: updates refiners vendored library (#458)

* feature: updates refiners vendored library

Includes a small bugfix that will soon be replaced by a better fix from upstream refiners.

Co-authored-by: Bryce <github20210803@accounts.brycedrennan.com>
jaydrennan 2024-01-19 09:45:23 -07:00 committed by GitHub
parent fbb16f6c62
commit 1bf53e47cf
17 changed files with 666 additions and 426 deletions

View File

@ -210,7 +210,7 @@ vendorize_normal_map:
vendorize_refiners:
export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=20c229903f53d05dc1c44659ec97603660ef964c && \
export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=ce3035923ba71bcb5044708d2f1c37fd1d6722e9 && \
make download_repo REPO=$$REPO PKG=$$PKG COMMIT=$$COMMIT && \
mkdir -p ./imaginairy/vendored/$$PKG && \
rm -rf ./imaginairy/vendored/$$PKG/* && \

View File

@ -1,4 +1,5 @@
from typing import Any, Generic, Iterable, TypeVar
from abc import ABC, abstractmethod
from typing import Any
from torch import Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter
@ -6,125 +7,259 @@ from torch.nn.init import normal_, zeros_
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
T = TypeVar("T", bound=fl.Chain)
TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PEP 673)
from imaginairy.vendored.refiners.fluxion.layers.chain import Chain
class Lora(fl.Chain):
class Lora(fl.Chain, ABC):
def __init__(
self,
rank: int = 16,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.rank = rank
self._scale = scale
super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))
normal_(tensor=self.down.weight, std=1 / self.rank)
zeros_(tensor=self.up.weight)
@abstractmethod
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.WeightedModule, fl.WeightedModule]:
...
@property
def down(self) -> fl.WeightedModule:
down_layer = self[0]
assert isinstance(down_layer, fl.WeightedModule)
return down_layer
@property
def up(self) -> fl.WeightedModule:
up_layer = self[1]
assert isinstance(up_layer, fl.WeightedModule)
return up_layer
@property
def scale(self) -> float:
return self._scale
@scale.setter
def scale(self, value: float) -> None:
self._scale = value
self.ensure_find(fl.Multiply).scale = value
@classmethod
def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "Lora":
match (up.ndim, down.ndim):
case (2, 2):
return LinearLora.from_weights(up=up, down=down)
case (4, 4):
return Conv2dLora.from_weights(up=up, down=down)
case _:
raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")
@classmethod
def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]:
"""
Create a dictionary of LoRA layers from a state dict.
Expects the state dict to be a succession of down and up weights.
"""
state_dict = {k: v for k, v in state_dict.items() if ".weight" in k}
loras: dict[str, Lora] = {}
for down_key, down_tensor, up_tensor in zip(
list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
):
key = ".".join(down_key.split(".")[:-2])
loras[key] = cls.from_weights(down=down_tensor, up=up_tensor)
return loras
@abstractmethod
def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any:
...
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
assert down_weight.shape == self.down.weight.shape
assert up_weight.shape == self.up.weight.shape
self.down.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
self.up.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
class LinearLora(Lora):
def __init__(
self,
in_features: int,
out_features: int,
rank: int = 16,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_features = in_features
self.out_features = out_features
self.rank = rank
self.scale: float = 1.0
super().__init__(
fl.Linear(in_features=in_features, out_features=rank, bias=False, device=device, dtype=dtype),
fl.Linear(in_features=rank, out_features=out_features, bias=False, device=device, dtype=dtype),
fl.Lambda(func=self.scale_outputs),
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
@classmethod
def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "LinearLora":
assert up.ndim == 2 and down.ndim == 2
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
lora = cls(
in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype
)
lora.load_weights(down_weight=down, up_weight=up)
return lora
normal_(tensor=self.Linear_1.weight, std=1 / self.rank)
zeros_(tensor=self.Linear_2.weight)
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
for layer, parent in target.walk(fl.Linear):
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
continue
def scale_outputs(self, x: Tensor) -> Tensor:
return x * self.scale
if exclude is not None and any(
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
):
continue
def set_scale(self, scale: float) -> None:
self.scale = scale
if layer.in_features == self.in_features and layer.out_features == self.out_features:
return LoraAdapter(target=layer, lora=self), parent
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))
@property
def up_weight(self) -> Tensor:
return self.Linear_2.weight.data
@property
def down_weight(self) -> Tensor:
return self.Linear_1.weight.data
class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
def __init__(
self,
target: fl.Linear,
rank: int = 16,
scale: float = 1.0,
) -> None:
self.in_features = target.in_features
self.out_features = target.out_features
self.rank = rank
self.scale = scale
with self.setup_adapter(target):
super().__init__(
target,
Lora(
in_features=target.in_features,
out_features=target.out_features,
rank=rank,
device=target.device,
dtype=target.dtype,
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.Linear, fl.Linear]:
return (
fl.Linear(
in_features=self.in_features,
out_features=self.rank,
bias=False,
device=device,
dtype=dtype,
),
fl.Linear(
in_features=self.rank,
out_features=self.out_features,
bias=False,
device=device,
dtype=dtype,
),
)
self.Lora.set_scale(scale=scale)
class LoraAdapter(Generic[T], fl.Chain, Adapter[T]):
class Conv2dLora(Lora):
def __init__(
self,
target: T,
sub_targets: Iterable[tuple[fl.Linear, fl.Chain]],
rank: int | None = None,
in_channels: int,
out_channels: int,
rank: int = 16,
scale: float = 1.0,
weights: list[Tensor] | None = None,
kernel_size: tuple[int, int] = (1, 3),
stride: tuple[int, int] = (1, 1),
padding: tuple[int, int] = (0, 1),
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
@classmethod
def from_weights(
cls,
down: Tensor,
up: Tensor,
) -> "Conv2dLora":
assert up.ndim == 4 and down.ndim == 4
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
down_kernel_size, up_kernel_size = down.shape[2], up.shape[2]
down_padding = 1 if down_kernel_size == 3 else 0
up_padding = 1 if up_kernel_size == 3 else 0
lora = cls(
in_channels=down.shape[1],
out_channels=up.shape[0],
rank=down.shape[0],
kernel_size=(down_kernel_size, up_kernel_size),
padding=(down_padding, up_padding),
device=up.device,
dtype=up.dtype,
)
lora.load_weights(down_weight=down, up_weight=up)
return lora
def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
for layer, parent in target.walk(fl.Conv2d):
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
continue
if exclude is not None and any(
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
):
continue
if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels:
if layer.stride != (self.stride[0], self.stride[0]):
self.down.stride = layer.stride
return LoraAdapter(
target=layer,
lora=self,
), parent
def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.Conv2d, fl.Conv2d]:
return (
fl.Conv2d(
in_channels=self.in_channels,
out_channels=self.rank,
kernel_size=self.kernel_size[0],
stride=self.stride[0],
padding=self.padding[0],
use_bias=False,
device=device,
dtype=dtype,
),
fl.Conv2d(
in_channels=self.rank,
out_channels=self.out_channels,
kernel_size=self.kernel_size[1],
stride=self.stride[1],
padding=self.padding[1],
use_bias=False,
device=device,
dtype=dtype,
),
)
class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
def __init__(self, target: fl.WeightedModule, lora: Lora) -> None:
with self.setup_adapter(target):
super().__init__(target)
if weights is not None:
assert len(weights) % 2 == 0
weights_rank = weights[0].shape[1]
if rank is None:
rank = weights_rank
else:
assert rank == weights_rank
assert rank is not None, "either pass a rank or weights"
self.sub_targets = sub_targets
self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = []
for linear, parent in self.sub_targets:
self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent))
if weights is not None:
assert len(self.sub_adapters) == (len(weights) // 2)
for i, (adapter, _) in enumerate(self.sub_adapters):
lora = adapter.Lora
assert (
lora.rank == weights[i * 2].shape[1]
), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter:
for adapter, adapter_parent in self.sub_adapters:
adapter.inject(adapter_parent)
return super().inject(parent)
def eject(self) -> None:
for adapter, _ in self.sub_adapters:
adapter.eject()
super().eject()
super().__init__(target, lora)
@property
def weights(self) -> list[Tensor]:
return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]]
def lora(self) -> Lora:
return self.ensure_find(Lora)
@property
def scale(self) -> float:
return self.lora.scale
@scale.setter
def scale(self, value: float) -> None:
self.lora.scale = value
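
For orientation, a rough usage sketch of the reworked LoRA API above. It is not part of the commit; the layer name and shapes are made up, and it assumes the vendored import path shown in the diff.

import torch

from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora

# State-dict keys come in down/up pairs; 2D weights produce a LinearLora,
# 4D weights a Conv2dLora (see Lora.from_weights above).
state_dict = {
    "unet_block.down.weight": torch.randn(8, 320),  # (rank, in_features)
    "unet_block.up.weight": torch.zeros(320, 8),    # (out_features, rank)
}
loras = Lora.from_dict(state_dict)  # {"unet_block": LinearLora(rank=8, ...)}

# Each Lora can then look for a matching layer in a target chain and wrap it:
#   adapter, parent = lora.auto_attach(target_chain)
#   adapter.inject(parent)
# which is the same pattern SDLoraManager.auto_attach uses further down.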

View File

@ -164,6 +164,9 @@ class WeightedModule(Module):
def dtype(self) -> DType:
return self.weight.dtype
def __str__(self) -> str:
return f"{super().__str__().removesuffix(')')}, device={self.device}, dtype={str(self.dtype).removeprefix('torch.')})"
class TreeNode(TypedDict):
value: str

View File

@ -146,6 +146,7 @@ def tensor_to_image(tensor: Tensor) -> Image.Image:
assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}"
num_channels = tensor.shape[1]
tensor = tensor.clamp(0, 1).squeeze(0)
tensor = tensor.to(torch.float32) # to avoid numpy error with bfloat16
match num_channels:
case 1:
@ -187,20 +188,26 @@ def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata:
def summarize_tensor(tensor: torch.Tensor, /) -> str:
return (
"Tensor("
+ ", ".join(
[
info_list = [
f"shape=({', '.join(map(str, tensor.shape))})",
f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
f"device={tensor.device}",
]
if not tensor.is_complex():
info_list.extend(
[
f"min={tensor.min():.2f}", # type: ignore
f"max={tensor.max():.2f}", # type: ignore
f"mean={tensor.mean():.2f}",
f"std={tensor.std():.2f}",
f"norm={norm(x=tensor):.2f}",
]
)
info_list.extend(
[
f"mean={tensor.float().mean():.2f}",
f"std={tensor.float().std():.2f}",
f"norm={norm(x=tensor.float()):.2f}",
f"grad={tensor.requires_grad}",
]
)
+ ")"
)
return "Tensor(" + ", ".join(info_list) + ")"

View File

@ -1,15 +1,12 @@
import math
from enum import IntEnum
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from jaxtyping import Float
from PIL import Image
from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.layers.attentions import ScaledDotProductAttention
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, normalize
@ -236,120 +233,99 @@ class PerceiverResampler(fl.Chain):
return {"perceiver_resampler": {"x": None}}
class _CrossAttnIndex(IntEnum):
TXT_CROSS_ATTN = 0 # text cross-attention
IMG_CROSS_ATTN = 1 # image cross-attention
class ImageCrossAttention(fl.Chain):
def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
self._scale = scale
super().__init__(
fl.Distribute(
fl.Identity(),
fl.Chain(
fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
fl.Linear(
in_features=text_cross_attention.key_embedding_dim,
out_features=text_cross_attention.inner_dim,
bias=text_cross_attention.use_bias,
device=text_cross_attention.device,
dtype=text_cross_attention.dtype,
),
),
fl.Chain(
fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
fl.Linear(
in_features=text_cross_attention.value_embedding_dim,
out_features=text_cross_attention.inner_dim,
bias=text_cross_attention.use_bias,
device=text_cross_attention.device,
dtype=text_cross_attention.dtype,
),
),
),
ScaledDotProductAttention(
num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal
),
fl.Multiply(self.scale),
)
@property
def scale(self) -> float:
return self._scale
class InjectionPoint(fl.Chain):
pass
@scale.setter
def scale(self, value: float) -> None:
self._scale = value
self.ensure_find(fl.Multiply).scale = value
class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
def __init__(
self,
target: fl.Attention,
text_sequence_length: int = 77,
image_sequence_length: int = 4,
scale: float = 1.0,
) -> None:
self.text_sequence_length = text_sequence_length
self.image_sequence_length = image_sequence_length
self.scale = scale
self._scale = scale
with self.setup_adapter(target):
clone = target.structural_copy()
scaled_dot_product = clone.ensure_find(ScaledDotProductAttention)
image_cross_attention = ImageCrossAttention(
text_cross_attention=clone,
scale=self.scale,
)
clone.replace(
old_module=scaled_dot_product,
new_module=fl.Sum(
scaled_dot_product,
image_cross_attention,
),
)
super().__init__(
fl.Distribute(
# Note: the same query is used for image cross-attention as for text cross-attention
InjectionPoint(), # Wq
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wk
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
bias=self.target.use_bias,
device=target.device,
dtype=target.dtype,
), # Wk'
),
),
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wv
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
bias=self.target.use_bias,
device=target.device,
dtype=target.dtype,
), # Wv'
),
),
),
fl.Sum(
fl.Chain(
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
),
fl.Chain(
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
fl.Lambda(func=self.scale_outputs),
),
),
InjectionPoint(), # proj
clone,
)
def select_qkv(
self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
) -> tuple[Tensor, Tensor, Tensor]:
return (query, keys[index.value], values[index.value])
@property
def image_cross_attention(self) -> ImageCrossAttention:
return self.ensure_find(ImageCrossAttention)
def scale_outputs(self, x: Tensor) -> Tensor:
return x * self.scale
@property
def image_key_projection(self) -> fl.Linear:
return self.image_cross_attention.Distribute[1].Linear
def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
def f(m: fl.Module, _: fl.Chain) -> bool:
if isinstance(m, Lora): # do not adapt LoRAs
raise StopIteration
return isinstance(m, k)
@property
def image_value_projection(self) -> fl.Linear:
return self.image_cross_attention.Distribute[2].Linear
return f
@property
def scale(self) -> float:
return self._scale
def _target_linears(self) -> list[fl.Linear]:
return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)]
@scale.setter
def scale(self, value: float) -> None:
self._scale = value
self.image_cross_attention.scale = value
def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
linears = self._target_linears()
assert len(linears) == 4 # Wq, Wk, Wv and Proj
injection_points = list(self.layers(InjectionPoint))
assert len(injection_points) == 4
for linear, ip in zip(linears, injection_points):
ip.append(linear)
assert len(ip) == 1
return super().inject(parent)
def eject(self) -> None:
injection_points = list(self.layers(InjectionPoint))
assert len(injection_points) == 4
for ip in injection_points:
ip.pop()
assert len(ip) == 0
super().eject()
def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
self.image_key_projection.weight = nn.Parameter(key_tensor)
self.image_value_projection.weight = nn.Parameter(value_tensor)
self.image_cross_attention.to(self.device, self.dtype)
class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
@ -377,7 +353,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
self._image_proj = [image_proj]
self.sub_adapters = [
CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens)
CrossAttentionAdapter(target=cross_attn, scale=scale)
for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
]
@ -388,14 +364,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
self.image_proj.load_state_dict(image_proj_state_dict)
for i, cross_attn in enumerate(self.sub_adapters):
cross_attn_state_dict: dict[str, Tensor] = {}
cross_attention_weights: list[Tensor] = []
for k, v in weights.items():
prefix = f"ip_adapter.{i:03d}."
if not k.startswith(prefix):
continue
cross_attn_state_dict[k.removeprefix(prefix)] = v
cross_attention_weights.append(v)
cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
assert len(cross_attention_weights) == 2
cross_attn.load_weights(*cross_attention_weights)
@property
def clip_image_encoder(self) -> CLIPImageEncoderH:
@ -420,10 +397,22 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
adapter.eject()
super().eject()
@property
def scale(self) -> float:
return self.sub_adapters[0].scale
@scale.setter
def scale(self, value: float) -> None:
for cross_attn in self.sub_adapters:
cross_attn.scale = value
def set_scale(self, scale: float) -> None:
for cross_attn in self.sub_adapters:
cross_attn.scale = scale
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
# These should be concatenated to the CLIP text embedding before setting the UNet context
def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder

View File

@ -1,146 +1,140 @@
from enum import Enum
from pathlib import Path
from typing import Callable, Iterator
from warnings import warn
from torch import Tensor
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora, LoraAdapter
from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
CLIPTextEncoderL,
LatentDiffusionAutoencoder,
SD1UNet,
StableDiffusion_1,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
MODELS = ["unet", "text_encoder", "lda"]
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
class LoraTarget(str, Enum):
Self = "self"
Attention = "Attention"
SelfAttention = "SelfAttention"
CrossAttention = "CrossAttentionBlock2d"
FeedForward = "FeedForward"
TransformerLayer = "TransformerLayer"
def get_class(self) -> type[fl.Chain]:
match self:
case LoraTarget.Self:
return fl.Chain
case LoraTarget.Attention:
return fl.Attention
case LoraTarget.SelfAttention:
return fl.SelfAttention
case LoraTarget.CrossAttention:
return CrossAttentionBlock2d
case LoraTarget.FeedForward:
return FeedForward
case LoraTarget.TransformerLayer:
return TransformerLayer
def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
def f(m: fl.Module, _: fl.Chain) -> bool:
if isinstance(m, Lora): # do not adapt other LoRAs
raise StopIteration
if isinstance(m, Controlnet): # do not adapt Controlnet linears
raise StopIteration
return isinstance(m, k)
return f
def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]:
for m, p in module.walk(_predicate(fl.Linear)):
assert isinstance(m, fl.Linear)
yield (m, p)
def lora_targets(
module: fl.Chain,
target: LoraTarget | list[LoraTarget],
) -> Iterator[tuple[fl.Linear, fl.Chain]]:
if isinstance(target, list):
for t in target:
yield from lora_targets(module, t)
return
if target == LoraTarget.Self:
yield from _iter_linears(module)
return
for layer, _ in module.walk(_predicate(target.get_class())):
assert isinstance(layer, fl.Chain)
yield from _iter_linears(layer)
class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
metadata: dict[str, str] | None
tensors: dict[str, Tensor]
class SDLoraManager:
def __init__(
self,
target: StableDiffusion_1,
sub_targets: dict[str, list[LoraTarget]],
target: LatentDiffusionModel,
) -> None:
self.target = target
@property
def unet(self) -> fl.Chain:
unet = self.target.unet
assert isinstance(unet, fl.Chain)
return unet
@property
def clip_text_encoder(self) -> fl.Chain:
clip_text_encoder = self.target.clip_text_encoder
assert isinstance(clip_text_encoder, fl.Chain)
return clip_text_encoder
def load(
self,
tensors: dict[str, Tensor],
/,
scale: float = 1.0,
weights: dict[str, Tensor] | None = None,
):
with self.setup_adapter(target):
super().__init__(target)
) -> None:
"""Load the LoRA weights from a dictionary of tensors.
self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = []
for model_name in MODELS:
if not (model_targets := sub_targets.get(model_name, [])):
continue
model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)
lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
self.sub_adapters.append(
LoraAdapter[type(model)](
model,
sub_targets=lora_targets(model, model_targets),
scale=scale,
weights=lora_weights,
)
Expects the keys to be in the commonly found formats on CivitAI's hub.
"""
assert len(self.lora_adapters) == 0, "Loras already loaded"
loras = Lora.from_dict(
{key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
)
loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}
@classmethod
def from_safetensors(
cls,
target: StableDiffusion_1,
checkpoint_path: Path | str,
scale: float = 1.0,
):
metadata = load_metadata_from_safetensors(checkpoint_path)
assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
tensors = load_from_safetensors(checkpoint_path, device=target.device)
# if no key contains "unet" or "text", assume all keys are for the unet
if not "unet" in loras and not "text" in loras:
loras = {f"unet_{key}": loras[key] for key in loras.keys()}
sub_targets: dict[str, list[LoraTarget]] = {}
for model_name in MODELS:
if not (v := metadata.get(f"{model_name}_targets", "")):
continue
sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")]
self.load_unet(loras)
self.load_text_encoder(loras)
return cls(
target,
sub_targets,
scale=scale,
weights=tensors,
)
self.scale = scale
def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter":
for adapter in self.sub_adapters:
adapter.inject()
return super().inject(parent)
def load_text_encoder(self, loras: dict[str, Lora], /) -> None:
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
def eject(self) -> None:
for adapter in self.sub_adapters:
adapter.eject()
super().eject()
def load_unet(self, loras: dict[str, Lora], /) -> None:
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
exclude: list[str] = []
exclude = [
self.unet_exclusions[exclusion]
for exclusion in self.unet_exclusions
if all([exclusion not in key for key in unet_loras.keys()])
]
SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude)
def unload(self) -> None:
for lora_adapter in self.lora_adapters:
lora_adapter.eject()
@property
def loras(self) -> list[Lora]:
return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
@property
def lora_adapters(self) -> list[LoraAdapter]:
return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
@property
def unet_exclusions(self) -> dict[str, str]:
return {
"time": "TimestepEncoder",
"res": "ResidualBlock",
"downsample": "DownsampleBlock",
"upsample": "UpsampleBlock",
}
@property
def scale(self) -> float:
assert len(self.loras) > 0, "No loras found"
assert all([lora.scale == self.loras[0].scale for lora in self.loras])
return self.loras[0].scale
@scale.setter
def scale(self, value: float) -> None:
for lora in self.loras:
lora.scale = value
@staticmethod
def pad(input: str, /, padding_length: int = 2) -> str:
new_split: list[str] = []
for s in input.split("_"):
if s.isdigit():
new_split.append(s.zfill(padding_length))
else:
new_split.append(s)
return "_".join(new_split)
@staticmethod
def sort_keys(key: str, /) -> tuple[str, int]:
# out0 happens sometimes as an alias for out ; this dict might not be exhaustive
key_char_order = {"q": 1, "k": 2, "v": 3, "out": 4, "out0": 4}
for i, s in enumerate(key.split("_")):
if s in key_char_order:
prefix = SDLoraManager.pad("_".join(key.split("_")[:i]))
return (prefix, key_char_order[s])
return (SDLoraManager.pad(key), 5)
@staticmethod
def auto_attach(
loras: dict[str, Lora],
target: fl.Chain,
/,
exclude: list[str] | None = None,
) -> None:
failed_loras: dict[str, Lora] = {}
for key, lora in loras.items():
if attach := lora.auto_attach(target, exclude=exclude):
adapter, parent = attach
adapter.inject(parent)
else:
failed_loras[key] = lora
if failed_loras:
warn(f"failed to attach {len(failed_loras)}/{len(loras)} loras to {target.__class__.__name__}")
# TODO: add a stronger sanity check to make sure loras are attached correctly
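
A minimal sketch of how the new SDLoraManager is meant to be driven (not part of this commit; the module path, default-constructed pipeline, and checkpoint filename are assumptions for illustration):

from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors
from imaginairy.vendored.refiners.foundationals.latent_diffusion import StableDiffusion_1
from imaginairy.vendored.refiners.foundationals.latent_diffusion.lora import SDLoraManager

sd = StableDiffusion_1()  # assumes the default constructor builds a full pipeline
manager = SDLoraManager(sd)

# Keys are expected in the formats commonly found on CivitAI; the tensors are
# grouped into Lora layers, sorted, then attached to the UNet / text encoder.
manager.load(load_from_safetensors("some-style-lora.safetensors"), scale=0.8)

# ... run inference ...

manager.unload()  # ejects every injected LoraAdapter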

View File

@ -11,7 +11,6 @@ from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.sche
T = TypeVar("T", bound="fl.Module")
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
@ -91,6 +90,8 @@ class LatentDiffusionModel(fl.Module, ABC):
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)
latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
# scale latents for schedulers that need it
latents = self.scheduler.scale_model_input(latents, step=step)
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
# classifier-free guidance

View File

@ -1,11 +1,7 @@
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
__all__ = [
"Scheduler",
"DPMSolver",
"DDPM",
"DDIM",
]
__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]

View File

@ -1,4 +1,4 @@
from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
@ -34,7 +34,7 @@ class DDIM(Scheduler):
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
timestep, previous_timestep = (
self.timesteps[step],
(
@ -52,6 +52,12 @@ class DDIM(Scheduler):
),
)
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
noise_factor = sqrt(1 - previous_scale_factor**2)
# Do not add noise at the last step to avoid visual artifacts.
if step == self.num_inference_steps - 1:
noise_factor = 0
denoised_x = previous_scale_factor * predicted_x + noise_factor * noise
return denoised_x

View File

@ -1,4 +1,4 @@
from torch import Tensor, arange, device as Device
from torch import Generator, Tensor, arange, device as Device
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
@ -30,5 +30,5 @@ class DDPM(Scheduler):
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
raise NotImplementedError

View File

@ -1,15 +1,19 @@
from collections import deque
import numpy as np
from torch import Tensor, device as Device, dtype as Dtype, exp, float32, tensor
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
class DPMSolver(Scheduler):
"""Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
"""
Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
We only support noise prediction for now.
Regarding last_step_first_order: DPM-Solver++ is known to introduce artifacts
when used with SDXL and few steps. This parameter is a way to mitigate that
effect by using a first-order (Euler) update instead of a second-order update
for the last step of the diffusion.
"""
def __init__(
@ -18,6 +22,7 @@ class DPMSolver(Scheduler):
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
last_step_first_order: bool = False,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: Dtype = float32,
@ -32,7 +37,8 @@ class DPMSolver(Scheduler):
dtype=dtype,
)
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
self.initial_steps = 0
self.last_step_first_order = last_step_first_order
self._first_step_has_been_run = False
def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because:
@ -45,57 +51,48 @@ class DPMSolver(Scheduler):
).flip(0)
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
timestep, previous_timestep = (
self.timesteps[step],
self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0],
)
previous_ratio, current_ratio = (
self.signal_to_noise_ratios[previous_timestep],
self.signal_to_noise_ratios[timestep],
)
current_timestep = self.timesteps[step]
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
current_ratio = self.signal_to_noise_ratios[current_timestep]
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_noise_std, current_noise_std = (
self.noise_std[previous_timestep],
self.noise_std[timestep],
)
previous_noise_std = self.noise_std[previous_timestep]
current_noise_std = self.noise_std[current_timestep]
factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
return denoised_x
def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
previous_timestep, current_timestep, next_timestep = (
self.timesteps[step + 1] if step < len(self.timesteps) - 1 else tensor([0]),
self.timesteps[step],
self.timesteps[step - 1],
)
current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2]
previous_ratio, current_ratio, next_ratio = (
self.signal_to_noise_ratios[previous_timestep],
self.signal_to_noise_ratios[current_timestep],
self.signal_to_noise_ratios[next_timestep],
)
previous_timestep = self.timesteps[step + 1] if step < self.num_inference_steps - 1 else tensor([0])
current_timestep = self.timesteps[step]
next_timestep = self.timesteps[step - 1]
current_data_estimation = self.estimated_data[-1]
next_data_estimation = self.estimated_data[-2]
previous_ratio = self.signal_to_noise_ratios[previous_timestep]
current_ratio = self.signal_to_noise_ratios[current_timestep]
next_ratio = self.signal_to_noise_ratios[next_timestep]
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_std, current_std = (
self.noise_std[previous_timestep],
self.noise_std[current_timestep],
)
previous_noise_std = self.noise_std[previous_timestep]
current_noise_std = self.noise_std[current_timestep]
estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
)
factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (
(previous_std / current_std) * x
(previous_noise_std / current_noise_std) * x
- (factor * previous_scale_factor) * current_data_estimation
- 0.5 * (factor * previous_scale_factor) * estimation_delta
)
return denoised_x
def __call__(
self,
x: Tensor,
noise: Tensor,
step: int,
) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
@ -107,11 +104,9 @@ class DPMSolver(Scheduler):
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
self.estimated_data.append(estimated_denoised_data)
denoised_x = (
self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
if (self.initial_steps == 0)
else self.multistep_dpm_solver_second_order_update(x=x, step=step)
)
if self.initial_steps < 2:
self.initial_steps += 1
return denoised_x
if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1) or not self._first_step_has_been_run:
self._first_step_has_been_run = True
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
return self.multistep_dpm_solver_second_order_update(x=x, step=step)
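
In short, the rewritten dispatch runs a first-order (Euler-style) update on the first step that actually executes (and, when last_step_first_order is set, on the final step too, to soften SDXL artifacts), and the second-order multistep update everywhere else. A tiny illustration, not part of the diff:

from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver

solver = DPMSolver(num_inference_steps=4, last_step_first_order=True)
# step 0            -> dpm_solver_first_order_update (first step that runs)
# steps 1 and 2     -> multistep_dpm_solver_second_order_update
# step 3 (the last) -> dpm_solver_first_order_update again, via last_step_first_order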

View File

@ -0,0 +1,84 @@
import numpy as np
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
class EulerScheduler(Scheduler):
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: Dtype = float32,
):
if noise_schedule != NoiseSchedule.QUADRATIC:
raise NotImplementedError
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
device=device,
dtype=dtype,
)
self.sigmas = self._generate_sigmas()
@property
def init_noise_sigma(self) -> Tensor:
return self.sigmas.max()
def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because:
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# ...and we want the same result as the original codebase.
timesteps = torch.tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
).flip(0)
return timesteps
def _generate_sigmas(self) -> Tensor:
sigmas = self.noise_std / self.cumulative_scale_factors
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
sigmas = torch.cat([sigmas, tensor([0.0])])
return sigmas.to(device=self.device, dtype=self.dtype)
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
sigma = self.sigmas[step]
return x / ((sigma**2 + 1) ** 0.5)
def __call__(
self,
x: Tensor,
noise: Tensor,
step: int,
generator: Generator | None = None,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
) -> Tensor:
sigma = self.sigmas[step]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype)
eps = alt_noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
predicted_x = x - sigma_hat * noise
# 1st order Euler
derivative = (x - predicted_x) / sigma_hat
dt = self.sigmas[step + 1] - sigma_hat
denoised_x = x + derivative * dt
return denoised_x
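
A rough end-to-end sketch of the new EulerScheduler together with the scale_model_input hook added to the base Scheduler (the dummy denoiser and latent shape are placeholders, not from the diff):

import torch

from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers import EulerScheduler

scheduler = EulerScheduler(num_inference_steps=30)

def denoise(latents: torch.Tensor) -> torch.Tensor:
    return torch.randn_like(latents)  # stand-in for the real UNet call

x = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
for step in scheduler.steps:
    scaled = scheduler.scale_model_input(x, step=step)  # x / sqrt(sigma^2 + 1)
    noise = denoise(scaled)
    x = scheduler(x, noise, step=step)  # first-order Euler update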

View File

@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import TypeVar
from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
T = TypeVar("T", bound="Scheduler")
@ -50,7 +50,7 @@ class Scheduler(ABC):
self.timesteps = self._generate_timesteps()
@abstractmethod
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.
@ -71,6 +71,12 @@ class Scheduler(ABC):
def steps(self) -> list[int]:
return list(range(self.num_inference_steps))
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
"""
For compatibility with schedulers that need to scale the input according to the current timestep.
"""
return x
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
return (
linspace(

View File

@ -89,9 +89,17 @@ class StableDiffusion_1(LatentDiffusionModel):
classifier_free_guidance=True,
)
negative_embedding, _ = clip_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
negative_embedding, _ = clip_text_embedding.chunk(2)
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
if "ip_adapter" in self.unet.provider.contexts:
# this implementation is a bit hacky, it should be refactored in the future
ip_adapter_context = self.unet.use_context("ip_adapter")
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
degraded_noise = self.unet(degraded_latents)
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
else:
degraded_noise = self.unet(degraded_latents)
return sag.scale * (noise - degraded_noise)
@ -160,14 +168,23 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
step=step,
classifier_free_guidance=True,
)
negative_embedding, _ = clip_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
x = torch.cat(
tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
dim=1,
)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
negative_embedding, _ = clip_text_embedding.chunk(2)
self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
if "ip_adapter" in self.unet.provider.contexts:
# this implementation is a bit hacky, it should be refactored in the future
ip_adapter_context = self.unet.use_context("ip_adapter")
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
degraded_noise = self.unet(x)
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
else:
degraded_noise = self.unet(x)
return sag.scale * (noise - degraded_noise)

View File

@ -138,17 +138,25 @@ class StableDiffusion_XL(LatentDiffusionModel):
classifier_free_guidance=True,
)
negative_embedding, _ = clip_text_embedding.chunk(2)
negative_text_embedding, _ = clip_text_embedding.chunk(2)
negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
time_ids, _ = time_ids.chunk(2)
self.set_unet_context(
timestep=timestep,
clip_text_embedding=negative_embedding,
clip_text_embedding=negative_text_embedding,
pooled_text_embedding=negative_pooled_embedding,
time_ids=time_ids,
**kwargs,
)
if "ip_adapter" in self.unet.provider.contexts:
# this implementation is a bit hacky, it should be refactored in the future
ip_adapter_context = self.unet.use_context("ip_adapter")
image_embedding_copy = ip_adapter_context["clip_image_embedding"].clone()
ip_adapter_context["clip_image_embedding"], _ = ip_adapter_context["clip_image_embedding"].chunk(2)
degraded_noise = self.unet(degraded_latents)
ip_adapter_context["clip_image_embedding"] = image_embedding_copy
else:
degraded_noise = self.unet(degraded_latents)
return sag.scale * (noise - degraded_noise)

View File

@ -3,11 +3,12 @@ from typing import Sequence
import numpy as np
import torch
from jaxtyping import Float
from PIL import Image
from torch import Tensor, device as Device, dtype as DType
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad
from imaginairy.vendored.refiners.fluxion.utils import interpolate, no_grad, normalize, pad
from imaginairy.vendored.refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
from imaginairy.vendored.refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from imaginairy.vendored.refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
@ -55,7 +56,7 @@ class SegmentAnything(fl.Module):
foreground_points: Sequence[tuple[float, float]] | None = None,
background_points: Sequence[tuple[float, float]] | None = None,
box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
masks: Sequence[Image.Image] | None = None,
low_res_mask: Float[Tensor, "1 1 256 256"] | None = None,
binarize: bool = True,
) -> tuple[Tensor, Tensor, Tensor]:
if isinstance(input, ImageEmbedding):
@ -74,15 +75,13 @@ class SegmentAnything(fl.Module):
)
self.point_encoder.set_type_mask(type_mask=type_mask)
if masks is not None:
mask_tensor = torch.stack(
tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks]
)
mask_embedding = self.mask_encoder(mask_tensor)
if low_res_mask is not None:
mask_embedding = self.mask_encoder(low_res_mask)
else:
mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(
image_embedding_size=self.image_encoder.image_embedding_size
)
point_embedding = self.point_encoder(
self.normalize(coordinates, target_size=target_size, original_size=original_size)
)

View File

@ -1 +1 @@
vendored from git@github.com:finegrain-ai/refiners.git @ 20c229903f53d05dc1c44659ec97603660ef964c
vendored from git@github.com:finegrain-ai/refiners.git @ ce3035923ba71bcb5044708d2f1c37fd1d6722e9