feature: updates refiners vendored library (#458)
* feature: updates refiners vendored library has a small bugfix that will soon be replaced by a better fix from upstream refiners Co-authored-by: Bryce <github20210803@accounts.brycedrennan.com>pull/461/head
parent
fbb16f6c62
commit
1bf53e47cf
@ -1,146 +1,140 @@
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterator
|
||||
from warnings import warn
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
import imaginairy.vendored.refiners.fluxion.layers as fl
|
||||
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
|
||||
from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora, LoraAdapter
|
||||
from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
|
||||
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
|
||||
CLIPTextEncoderL,
|
||||
LatentDiffusionAutoencoder,
|
||||
SD1UNet,
|
||||
StableDiffusion_1,
|
||||
)
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
|
||||
|
||||
MODELS = ["unet", "text_encoder", "lda"]
|
||||
|
||||
|
||||
class LoraTarget(str, Enum):
|
||||
Self = "self"
|
||||
Attention = "Attention"
|
||||
SelfAttention = "SelfAttention"
|
||||
CrossAttention = "CrossAttentionBlock2d"
|
||||
FeedForward = "FeedForward"
|
||||
TransformerLayer = "TransformerLayer"
|
||||
|
||||
def get_class(self) -> type[fl.Chain]:
|
||||
match self:
|
||||
case LoraTarget.Self:
|
||||
return fl.Chain
|
||||
case LoraTarget.Attention:
|
||||
return fl.Attention
|
||||
case LoraTarget.SelfAttention:
|
||||
return fl.SelfAttention
|
||||
case LoraTarget.CrossAttention:
|
||||
return CrossAttentionBlock2d
|
||||
case LoraTarget.FeedForward:
|
||||
return FeedForward
|
||||
case LoraTarget.TransformerLayer:
|
||||
return TransformerLayer
|
||||
|
||||
|
||||
def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
|
||||
def f(m: fl.Module, _: fl.Chain) -> bool:
|
||||
if isinstance(m, Lora): # do not adapt other LoRAs
|
||||
raise StopIteration
|
||||
if isinstance(m, Controlnet): # do not adapt Controlnet linears
|
||||
raise StopIteration
|
||||
return isinstance(m, k)
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]:
|
||||
for m, p in module.walk(_predicate(fl.Linear)):
|
||||
assert isinstance(m, fl.Linear)
|
||||
yield (m, p)
|
||||
|
||||
|
||||
def lora_targets(
|
||||
module: fl.Chain,
|
||||
target: LoraTarget | list[LoraTarget],
|
||||
) -> Iterator[tuple[fl.Linear, fl.Chain]]:
|
||||
if isinstance(target, list):
|
||||
for t in target:
|
||||
yield from lora_targets(module, t)
|
||||
return
|
||||
|
||||
if target == LoraTarget.Self:
|
||||
yield from _iter_linears(module)
|
||||
return
|
||||
|
||||
for layer, _ in module.walk(_predicate(target.get_class())):
|
||||
assert isinstance(layer, fl.Chain)
|
||||
yield from _iter_linears(layer)
|
||||
|
||||
|
||||
class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
|
||||
metadata: dict[str, str] | None
|
||||
tensors: dict[str, Tensor]
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||
|
||||
|
||||
class SDLoraManager:
|
||||
def __init__(
|
||||
self,
|
||||
target: StableDiffusion_1,
|
||||
sub_targets: dict[str, list[LoraTarget]],
|
||||
scale: float = 1.0,
|
||||
weights: dict[str, Tensor] | None = None,
|
||||
):
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = []
|
||||
|
||||
for model_name in MODELS:
|
||||
if not (model_targets := sub_targets.get(model_name, [])):
|
||||
continue
|
||||
model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)
|
||||
|
||||
lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
|
||||
self.sub_adapters.append(
|
||||
LoraAdapter[type(model)](
|
||||
model,
|
||||
sub_targets=lora_targets(model, model_targets),
|
||||
scale=scale,
|
||||
weights=lora_weights,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_safetensors(
|
||||
cls,
|
||||
target: StableDiffusion_1,
|
||||
checkpoint_path: Path | str,
|
||||
target: LatentDiffusionModel,
|
||||
) -> None:
|
||||
self.target = target
|
||||
|
||||
@property
|
||||
def unet(self) -> fl.Chain:
|
||||
unet = self.target.unet
|
||||
assert isinstance(unet, fl.Chain)
|
||||
return unet
|
||||
|
||||
@property
|
||||
def clip_text_encoder(self) -> fl.Chain:
|
||||
clip_text_encoder = self.target.clip_text_encoder
|
||||
assert isinstance(clip_text_encoder, fl.Chain)
|
||||
return clip_text_encoder
|
||||
|
||||
def load(
|
||||
self,
|
||||
tensors: dict[str, Tensor],
|
||||
/,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
metadata = load_metadata_from_safetensors(checkpoint_path)
|
||||
assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
|
||||
tensors = load_from_safetensors(checkpoint_path, device=target.device)
|
||||
|
||||
sub_targets: dict[str, list[LoraTarget]] = {}
|
||||
for model_name in MODELS:
|
||||
if not (v := metadata.get(f"{model_name}_targets", "")):
|
||||
continue
|
||||
sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")]
|
||||
|
||||
return cls(
|
||||
target,
|
||||
sub_targets,
|
||||
scale=scale,
|
||||
weights=tensors,
|
||||
) -> None:
|
||||
"""Load the LoRA weights from a dictionary of tensors.
|
||||
|
||||
Expects the keys to be in the commonly found formats on CivitAI's hub.
|
||||
"""
|
||||
assert len(self.lora_adapters) == 0, "Loras already loaded"
|
||||
loras = Lora.from_dict(
|
||||
{key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()}
|
||||
)
|
||||
|
||||
def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter":
|
||||
for adapter in self.sub_adapters:
|
||||
adapter.inject()
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
for adapter in self.sub_adapters:
|
||||
adapter.eject()
|
||||
super().eject()
|
||||
loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}
|
||||
|
||||
# if no key contains "unet" or "text", assume all keys are for the unet
|
||||
if not "unet" in loras and not "text" in loras:
|
||||
loras = {f"unet_{key}": loras[key] for key in loras.keys()}
|
||||
|
||||
self.load_unet(loras)
|
||||
self.load_text_encoder(loras)
|
||||
|
||||
self.scale = scale
|
||||
|
||||
def load_text_encoder(self, loras: dict[str, Lora], /) -> None:
|
||||
text_encoder_loras = {key: loras[key] for key in loras.keys() if "text" in key}
|
||||
SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder)
|
||||
|
||||
def load_unet(self, loras: dict[str, Lora], /) -> None:
|
||||
unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key}
|
||||
exclude: list[str] = []
|
||||
exclude = [
|
||||
self.unet_exclusions[exclusion]
|
||||
for exclusion in self.unet_exclusions
|
||||
if all([exclusion not in key for key in unet_loras.keys()])
|
||||
]
|
||||
SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude)
|
||||
|
||||
def unload(self) -> None:
|
||||
for lora_adapter in self.lora_adapters:
|
||||
lora_adapter.eject()
|
||||
|
||||
@property
|
||||
def loras(self) -> list[Lora]:
|
||||
return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora))
|
||||
|
||||
@property
|
||||
def lora_adapters(self) -> list[LoraAdapter]:
|
||||
return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter))
|
||||
|
||||
@property
|
||||
def unet_exclusions(self) -> dict[str, str]:
|
||||
return {
|
||||
"time": "TimestepEncoder",
|
||||
"res": "ResidualBlock",
|
||||
"downsample": "DownsampleBlock",
|
||||
"upsample": "UpsampleBlock",
|
||||
}
|
||||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
assert len(self.loras) > 0, "No loras found"
|
||||
assert all([lora.scale == self.loras[0].scale for lora in self.loras])
|
||||
return self.loras[0].scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, value: float) -> None:
|
||||
for lora in self.loras:
|
||||
lora.scale = value
|
||||
|
||||
@staticmethod
|
||||
def pad(input: str, /, padding_length: int = 2) -> str:
|
||||
new_split: list[str] = []
|
||||
for s in input.split("_"):
|
||||
if s.isdigit():
|
||||
new_split.append(s.zfill(padding_length))
|
||||
else:
|
||||
new_split.append(s)
|
||||
return "_".join(new_split)
|
||||
|
||||
@staticmethod
|
||||
def sort_keys(key: str, /) -> tuple[str, int]:
|
||||
# out0 happens sometimes as an alias for out ; this dict might not be exhaustive
|
||||
key_char_order = {"q": 1, "k": 2, "v": 3, "out": 4, "out0": 4}
|
||||
|
||||
for i, s in enumerate(key.split("_")):
|
||||
if s in key_char_order:
|
||||
prefix = SDLoraManager.pad("_".join(key.split("_")[:i]))
|
||||
return (prefix, key_char_order[s])
|
||||
|
||||
return (SDLoraManager.pad(key), 5)
|
||||
|
||||
@staticmethod
|
||||
def auto_attach(
|
||||
loras: dict[str, Lora],
|
||||
target: fl.Chain,
|
||||
/,
|
||||
exclude: list[str] | None = None,
|
||||
) -> None:
|
||||
failed_loras: dict[str, Lora] = {}
|
||||
for key, lora in loras.items():
|
||||
if attach := lora.auto_attach(target, exclude=exclude):
|
||||
adapter, parent = attach
|
||||
adapter.inject(parent)
|
||||
else:
|
||||
failed_loras[key] = lora
|
||||
|
||||
if failed_loras:
|
||||
warn(f"failed to attach {len(failed_loras)}/{len(loras)} loras to {target.__class__.__name__}")
|
||||
|
||||
# TODO: add a stronger sanity check to make sure loras are attached correctly
|
||||
|
@ -1,11 +1,7 @@
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
|
||||
__all__ = [
|
||||
"Scheduler",
|
||||
"DPMSolver",
|
||||
"DDPM",
|
||||
"DDIM",
|
||||
]
|
||||
__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]
|
||||
|
@ -0,0 +1,84 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
||||
|
||||
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
|
||||
|
||||
|
||||
class EulerScheduler(Scheduler):
|
||||
def __init__(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
num_train_timesteps: int = 1_000,
|
||||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
):
|
||||
if noise_schedule != NoiseSchedule.QUADRATIC:
|
||||
raise NotImplementedError
|
||||
super().__init__(
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
initial_diffusion_rate=initial_diffusion_rate,
|
||||
final_diffusion_rate=final_diffusion_rate,
|
||||
noise_schedule=noise_schedule,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.sigmas = self._generate_sigmas()
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self) -> Tensor:
|
||||
return self.sigmas.max()
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
# We need to use numpy here because:
|
||||
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
||||
# torch.linspace(0,999,31)[15] is 499.5
|
||||
# ...and we want the same result as the original codebase.
|
||||
timesteps = torch.tensor(
|
||||
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
|
||||
).flip(0)
|
||||
return timesteps
|
||||
|
||||
def _generate_sigmas(self) -> Tensor:
|
||||
sigmas = self.noise_std / self.cumulative_scale_factors
|
||||
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
|
||||
sigmas = torch.cat([sigmas, tensor([0.0])])
|
||||
return sigmas.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
||||
sigma = self.sigmas[step]
|
||||
return x / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: Tensor,
|
||||
noise: Tensor,
|
||||
step: int,
|
||||
generator: Generator | None = None,
|
||||
s_churn: float = 0.0,
|
||||
s_tmin: float = 0.0,
|
||||
s_tmax: float = float("inf"),
|
||||
s_noise: float = 1.0,
|
||||
) -> Tensor:
|
||||
sigma = self.sigmas[step]
|
||||
|
||||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
|
||||
|
||||
alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype)
|
||||
eps = alt_noise * s_noise
|
||||
sigma_hat = sigma * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
predicted_x = x - sigma_hat * noise
|
||||
|
||||
# 1st order Euler
|
||||
derivative = (x - predicted_x) / sigma_hat
|
||||
dt = self.sigmas[step + 1] - sigma_hat
|
||||
denoised_x = x + derivative * dt
|
||||
|
||||
return denoised_x
|
@ -1 +1 @@
|
||||
vendored from git@github.com:finegrain-ai/refiners.git @ 20c229903f53d05dc1c44659ec97603660ef964c
|
||||
vendored from git@github.com:finegrain-ai/refiners.git @ ce3035923ba71bcb5044708d2f1c37fd1d6722e9
|
||||
|
Loading…
Reference in New Issue