mirror of https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00

build: vendorize refiners
so we can still work in conda envs

This commit is contained in:
parent f84406f12c
commit 55e27160f5

Makefile (13 lines changed)
@@ -201,6 +201,19 @@ vendorize_normal_map:
 	make af
 
+vendorize_refiners:
+	export REPO=git@github.com:finegrain-ai/refiners.git PKG=refiners COMMIT=20c229903f53d05dc1c44659ec97603660ef964c && \
+	make download_repo REPO=$$REPO PKG=$$PKG COMMIT=$$COMMIT && \
+	mkdir -p ./imaginairy/vendored/$$PKG && \
+	rm -rf ./imaginairy/vendored/$$PKG/* && \
+	cp -R ./downloads/refiners/src/refiners/* ./imaginairy/vendored/$$PKG/ && \
+	cp ./downloads/refiners/LICENSE ./imaginairy/vendored/$$PKG/ && \
+	rm -rf ./imaginairy/vendored/$$PKG/training_utils && \
+	echo "vendored from $$REPO @ $$COMMIT" | tee ./imaginairy/vendored/$$PKG/readme.txt
+	find ./imaginairy/vendored/refiners/ -type f -name "*.py" -exec sed -i '' 's/from refiners/from imaginairy.vendored.refiners/g' {} + &&\
+	find ./imaginairy/vendored/refiners/ -type f -name "*.py" -exec sed -i '' 's/import refiners/import imaginairy.vendored.refiners/g' {} + &&\
+	make af
+
 vendorize: ## vendorize a github repo. `make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip`
 	mkdir -p ./downloads
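Note on reproducing this locally: the new target pins refiners at the listed commit, copies its src tree under imaginairy/vendored/, and rewrites absolute imports with sed. A minimal smoke test, assuming `make vendorize_refiners` has already been run:

# Hedged smoke test: the vendored tree should import even when the upstream
# `refiners` pip package is not installed, which is what makes conda envs
# workable (the stated motivation for this commit).
import importlib.util

if importlib.util.find_spec("refiners") is not None:
    print("warning: upstream refiners is installed; this check is less meaningful")

import imaginairy.vendored.refiners.fluxion.layers as fl  # sed-rewritten path

print(fl.Chain)  # resolves from imaginairy/vendored, not from site-packages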
@@ -12,17 +12,17 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def _generate_single_image_compvis(
+def _generate_single_image(
     prompt: "ImaginePrompt",
     debug_img_callback=None,
     progress_img_callback=None,
     progress_img_interval_steps=3,
     progress_img_interval_min_s=0.1,
-    half_mode=None,
     add_caption=False,
     # controlnet, finetune, naive, auto
     inpaint_method="finetune",
     return_latent=False,
+    dtype=None,
 ):
     import torch.nn
     from PIL import Image, ImageOps
@@ -96,7 +96,7 @@ def _generate_single_image_compvis(
         weights_location=prompt.model_weights,
         config_path=prompt.model_architecture,
         control_weights_locations=control_modes,
-        half_mode=half_mode,
+        half_mode=dtype == torch.float16,
         for_inpainting=for_inpainting and inpaint_method == "finetune",
     )
     is_controlnet_model = hasattr(model, "control_key")
@@ -502,7 +502,6 @@ def _generate_composition_image(
 ):
     from PIL import Image
 
-    from imaginairy.api.generate_refiners import generate_single_image
     from imaginairy.utils import default, get_default_dtype
 
     cutoff = normalize_image_size(cutoff)
@@ -532,7 +531,7 @@ def _generate_composition_image(
         },
     )
 
-    result = generate_single_image(composition_prompt, dtype=dtype)
+    result = _generate_single_image(composition_prompt, dtype=dtype)
     img = result.images["generated"]
     while img.width < target_width:
         from imaginairy.enhancers.upscale_realesrgan import upscale_image
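The net effect of the first two hunks: callers pass a torch dtype instead of the old half_mode flag, and the flag handed to the model loader is derived from it. A small sketch of that derivation:

# Sketch of the new precision plumbing: half_mode is no longer a user-facing
# parameter; it is computed from the requested dtype.
import torch

dtype = torch.float16               # what callers now pass (None falls back to a default)
half_mode = dtype == torch.float16  # what the model loader still receives
print(half_mode)  # True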
@@ -27,7 +27,6 @@ def generate_single_image(
 ):
     import torch.nn
     from PIL import Image, ImageOps
-    from refiners.foundationals.latent_diffusion.schedulers import DDIM, DPMSolver
     from tqdm import tqdm
 
     from imaginairy.api.generate import (
@@ -61,6 +60,10 @@ def generate_single_image(
         prepare_image_for_outpaint,
     )
     from imaginairy.utils.safety import create_safety_score
+    from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers import (
+        DDIM,
+        DPMSolver,
+    )
 
     if dtype is None:
         dtype = torch.float16
@@ -513,7 +516,9 @@ def prep_control_input(
     if not control_config:
         msg = f"Unknown control mode: {control_input.mode}"
         raise ValueError(msg)
-    from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter
+    from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
+        SD1ControlnetAdapter,
+    )
 
     controlnet = SD1ControlnetAdapter(  # type: ignore
         name=control_input.mode,
@@ -5,42 +5,52 @@ import math
 from functools import lru_cache
 from typing import Any, List, Literal
 
-import refiners.fluxion.layers as fl
 import torch
-from refiners.fluxion.layers.attentions import ScaledDotProductAttention
-from refiners.fluxion.layers.chain import ChainError
-from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
-from refiners.foundationals.latent_diffusion.model import (
+from torch import Tensor, device as Device, dtype as DType, nn
+from torch.nn import functional as F
+
+import imaginairy.vendored.refiners.fluxion.layers as fl
+from imaginairy.schema import WeightedPrompt
+from imaginairy.utils.feather_tile import rebuild_image, tile_image
+from imaginairy.vendored.refiners.fluxion.layers.attentions import (
+    ScaledDotProductAttention,
+)
+from imaginairy.vendored.refiners.fluxion.layers.chain import ChainError
+from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
+    CLIPTextEncoderL,
+)
+from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import (
     TLatentDiffusionModel,
 )
-from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
-from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
-from refiners.foundationals.latent_diffusion.self_attention_guidance import (
+from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import (
+    DDIM,
+)
+from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import (
+    Scheduler,
+)
+from imaginairy.vendored.refiners.foundationals.latent_diffusion.self_attention_guidance import (
     SelfAttentionMap,
 )
-from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
+from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
     Controlnet,
     SD1ControlnetAdapter,
 )
-from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
+from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
     SD1Autoencoder,
     SD1UNet,
     StableDiffusion_1 as RefinerStableDiffusion_1,
     StableDiffusion_1_Inpainting as RefinerStableDiffusion_1_Inpainting,
 )
-from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import (
+from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import (
     SDXLAutoencoder,
     StableDiffusion_XL as RefinerStableDiffusion_XL,
 )
-from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import (
+from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import (
     DoubleTextEncoder,
 )
-from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
-from torch import Tensor, device as Device, dtype as DType, nn
-from torch.nn import functional as F
-
-from imaginairy.schema import WeightedPrompt
-from imaginairy.utils.feather_tile import rebuild_image, tile_image
+from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import (
+    SDXLUNet,
+)
 from imaginairy.weight_management.conversion import cast_weights
 
 logger = logging.getLogger(__name__)
@@ -375,8 +385,8 @@ class StableDiffusion_1_Inpainting(TileModeMixin, RefinerStableDiffusion_1_Inpai
         import torch
 
         total_weight = sum(wp.weight for wp in prompts)
-        if str(self.clip_text_encoder.device) == "cpu":
-            self.clip_text_encoder = self.clip_text_encoder.to(dtype=torch.float32)
+        if str(self.clip_text_encoder.device) == "cpu":  # type: ignore
+            self.clip_text_encoder = self.clip_text_encoder.to(dtype=torch.float32)  # type: ignore
         conditioning = sum(
             self.clip_text_encoder(wp.text) * (wp.weight / total_weight)
             for wp in prompts
@@ -16,9 +16,6 @@ from huggingface_hub import (
     try_to_load_from_cache,
 )
 from omegaconf import OmegaConf
-from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
-from refiners.foundationals.latent_diffusion import DoubleTextEncoder, SD1UNet, SDXLUNet
-from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
 from safetensors.torch import load_file
 
 from imaginairy import config as iconfig
@@ -29,6 +26,17 @@ from imaginairy.utils import clear_gpu_cache, get_device, instantiate_from_confi
 from imaginairy.utils.model_cache import memory_managed_model
 from imaginairy.utils.named_resolutions import normalize_image_size
 from imaginairy.utils.paths import PKG_ROOT
+from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
+    CLIPTextEncoderL,
+)
+from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
+    DoubleTextEncoder,
+    SD1UNet,
+    SDXLUNet,
+)
+from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import (
+    LatentDiffusionModel,
+)
 from imaginairy.weight_management import translators
 
 logger = logging.getLogger(__name__)
@@ -823,7 +831,7 @@ def open_weights(filepath, device=None):
         device = get_device()
 
     if "safetensor" in filepath.lower():
-        from refiners.fluxion.utils import safe_open
+        from imaginairy.vendored.refiners.fluxion.utils import safe_open
 
         with safe_open(path=filepath, framework="pytorch", device=device) as tensors:
             state_dict = {
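A hedged sketch of the safe_open branch above, assuming a local safetensors file ("weights.safetensors" is a placeholder path; keys()/get_tensor() are the standard safetensors handle methods, which the dict comprehension truncated in the hunk presumably iterates):

from imaginairy.vendored.refiners.fluxion.utils import safe_open

with safe_open(path="weights.safetensors", framework="pytorch", device="cpu") as tensors:
    state_dict = {key: tensors.get_tensor(key) for key in tensors.keys()}
print(len(state_dict))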
imaginairy/vendored/refiners/LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Lagon Technologies

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
imaginairy/vendored/refiners/fluxion/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors, manual_seed, norm, pad, save_to_safetensors

__all__ = ["norm", "manual_seed", "save_to_safetensors", "load_from_safetensors", "pad"]
imaginairy/vendored/refiners/fluxion/adapters/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter

__all__ = ["Adapter"]
imaginairy/vendored/refiners/fluxion/adapters/adapter.py (new file, 101 lines)
@@ -0,0 +1,101 @@
import contextlib
from typing import Any, Generic, Iterator, TypeVar

import imaginairy.vendored.refiners.fluxion.layers as fl

T = TypeVar("T", bound=fl.Module)
TAdapter = TypeVar("TAdapter", bound="Adapter[Any]")  # Self (see PEP 673)


class Adapter(Generic[T]):
    # we store _target into a one element list to avoid pytorch thinking it is a submodule
    _target: "list[T]"

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        assert issubclass(cls, fl.Chain), f"Adapter {cls.__name__} must be a Chain"

    @property
    def target(self) -> T:
        return self._target[0]

    @contextlib.contextmanager
    def setup_adapter(self, target: T) -> Iterator[None]:
        assert isinstance(self, fl.Chain)
        assert (not hasattr(self, "_modules")) or (
            len(self) == 0
        ), "Call the Chain constructor in the setup_adapter context."
        self._target = [target]

        if not isinstance(self.target, fl.ContextModule):
            yield
            return

        _old_can_refresh_parent = target._can_refresh_parent
        target._can_refresh_parent = False
        yield
        target._can_refresh_parent = _old_can_refresh_parent

    def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
        assert isinstance(self, fl.Chain)

        if (parent is None) and isinstance(self.target, fl.ContextModule):
            parent = self.target.parent
            if parent is not None:
                assert isinstance(parent, fl.Chain), f"{self.target} has invalid parent {parent}"

        target_parent = self.find_parent(self.target)

        if parent is None:
            if isinstance(self.target, fl.ContextModule):
                self.target._set_parent(target_parent)  # type: ignore[reportPrivateUsage]
            return self

        # In general, `true_parent` is `parent`. We do this to support multiple adaptation,
        # i.e. initializing two adapters before injecting them.
        true_parent = parent.ensure_find_parent(self.target)
        true_parent.replace(
            old_module=self.target,
            new_module=self,
            old_module_parent=target_parent,
        )
        return self

    def eject(self) -> None:
        assert isinstance(self, fl.Chain)

        # In general, the "actual target" is the target.
        # Here we deal with the edge case where the target
        # is part of the replacement block and has been adapted by
        # another adapter after this one. For instance, this is the
        # case when stacking Controlnets.
        actual_target = lookup_top_adapter(self, self.target)

        if (parent := self.parent) is None:
            if isinstance(actual_target, fl.ContextModule):
                actual_target._set_parent(None)  # type: ignore[reportPrivateUsage]
        else:
            parent.replace(old_module=self, new_module=actual_target)

    def _pre_structural_copy(self) -> None:
        if isinstance(self.target, fl.Chain):
            raise RuntimeError("Chain adapters typically cannot be copied, eject them first.")

    def _post_structural_copy(self: TAdapter, source: TAdapter) -> None:
        self._target = [source.target]


def lookup_top_adapter(top: fl.Chain, target: fl.Module) -> fl.Module:
    """Lookup and return last adapter in parents tree (or target if none)."""

    target_parent = top.find_parent(target)
    if (target_parent is None) or (target_parent == top):
        return target

    r, p = target, target_parent
    while p != top:
        if isinstance(p, Adapter):
            r = p
        assert p.parent, f"parent tree of {top} is broken"
        p = p.parent
    return r
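For readers unfamiliar with refiners' adapter pattern: an Adapter is a Chain that wraps its target and swaps itself into the target's parent on inject(), restoring the original on eject(). A minimal sketch, assuming the vendored package is importable; ScaleAdapter is a made-up illustration, not part of refiners:

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter


class ScaleAdapter(fl.Chain, Adapter[fl.Linear]):  # hypothetical example adapter
    def __init__(self, target: fl.Linear) -> None:
        with self.setup_adapter(target):  # records the target without registering it as a submodule
            super().__init__(target, fl.Multiply(scale=0.5))  # run target, then halve its output


model = fl.Chain(fl.Linear(in_features=4, out_features=4))
adapter = ScaleAdapter(model.Linear).inject(model)  # replaces the Linear inside `model`
adapter.eject()  # puts the original Linear back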
imaginairy/vendored/refiners/fluxion/adapters/lora.py (new file, 130 lines)
@@ -0,0 +1,130 @@
from typing import Any, Generic, Iterable, TypeVar

from torch import Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter
from torch.nn.init import normal_, zeros_

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter

T = TypeVar("T", bound=fl.Chain)
TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]")  # Self (see PEP 673)


class Lora(fl.Chain):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 16,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.scale: float = 1.0

        super().__init__(
            fl.Linear(in_features=in_features, out_features=rank, bias=False, device=device, dtype=dtype),
            fl.Linear(in_features=rank, out_features=out_features, bias=False, device=device, dtype=dtype),
            fl.Lambda(func=self.scale_outputs),
        )

        normal_(tensor=self.Linear_1.weight, std=1 / self.rank)
        zeros_(tensor=self.Linear_2.weight)

    def scale_outputs(self, x: Tensor) -> Tensor:
        return x * self.scale

    def set_scale(self, scale: float) -> None:
        self.scale = scale

    def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
        self.Linear_1.weight = TorchParameter(down_weight.to(device=self.device, dtype=self.dtype))
        self.Linear_2.weight = TorchParameter(up_weight.to(device=self.device, dtype=self.dtype))

    @property
    def up_weight(self) -> Tensor:
        return self.Linear_2.weight.data

    @property
    def down_weight(self) -> Tensor:
        return self.Linear_1.weight.data


class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
    def __init__(
        self,
        target: fl.Linear,
        rank: int = 16,
        scale: float = 1.0,
    ) -> None:
        self.in_features = target.in_features
        self.out_features = target.out_features
        self.rank = rank
        self.scale = scale
        with self.setup_adapter(target):
            super().__init__(
                target,
                Lora(
                    in_features=target.in_features,
                    out_features=target.out_features,
                    rank=rank,
                    device=target.device,
                    dtype=target.dtype,
                ),
            )
        self.Lora.set_scale(scale=scale)


class LoraAdapter(Generic[T], fl.Chain, Adapter[T]):
    def __init__(
        self,
        target: T,
        sub_targets: Iterable[tuple[fl.Linear, fl.Chain]],
        rank: int | None = None,
        scale: float = 1.0,
        weights: list[Tensor] | None = None,
    ) -> None:
        with self.setup_adapter(target):
            super().__init__(target)

        if weights is not None:
            assert len(weights) % 2 == 0
            weights_rank = weights[0].shape[1]
            if rank is None:
                rank = weights_rank
            else:
                assert rank == weights_rank

        assert rank is not None, "either pass a rank or weights"

        self.sub_targets = sub_targets
        self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = []

        for linear, parent in self.sub_targets:
            self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent))

        if weights is not None:
            assert len(self.sub_adapters) == (len(weights) // 2)
            for i, (adapter, _) in enumerate(self.sub_adapters):
                lora = adapter.Lora
                assert (
                    lora.rank == weights[i * 2].shape[1]
                ), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
                adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])

    def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter:
        for adapter, adapter_parent in self.sub_adapters:
            adapter.inject(adapter_parent)
        return super().inject(parent)

    def eject(self) -> None:
        for adapter, _ in self.sub_adapters:
            adapter.eject()
        super().eject()

    @property
    def weights(self) -> list[Tensor]:
        return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]]
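A minimal usage sketch of SingleLoraAdapter from above. Because the up-projection (Linear_2) is zero-initialized, a freshly injected LoRA is an identity delta, so outputs match the unadapted layer until weights are loaded or trained:

import torch

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.lora import SingleLoraAdapter

chain = fl.Chain(fl.Linear(in_features=8, out_features=8))
linear = chain.Linear
SingleLoraAdapter(target=linear, rank=4).inject(chain)

x = torch.randn(2, 8)
# Linear_2 starts as zeros, so the LoRA branch contributes nothing yet.
assert torch.allclose(chain(x), linear(x))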
imaginairy/vendored/refiners/fluxion/context.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from typing import Any

from torch import Tensor

Context = dict[str, Any]
Contexts = dict[str, Context]


class ContextProvider:
    def __init__(self) -> None:
        self.contexts: Contexts = {}

    def set_context(self, key: str, value: Context) -> None:
        self.contexts[key] = value

    def get_context(self, key: str) -> Any:
        return self.contexts.get(key)

    def update_contexts(self, new_contexts: Contexts) -> None:
        for key, value in new_contexts.items():
            if key not in self.contexts:
                self.contexts[key] = value
            else:
                self.contexts[key].update(value)

    @staticmethod
    def create(contexts: Contexts) -> "ContextProvider":
        provider = ContextProvider()
        provider.update_contexts(contexts)
        return provider

    def __add__(self, other: "ContextProvider") -> "ContextProvider":
        self.contexts.update(other.contexts)
        return self

    def __lshift__(self, other: "ContextProvider") -> "ContextProvider":
        other.contexts.update(self.contexts)
        return other

    def __bool__(self) -> bool:
        return bool(self.contexts)

    def _get_repr_for_value(self, value: Any) -> str:
        if isinstance(value, Tensor):
            return f"Tensor(shape={value.shape}, dtype={value.dtype}, device={value.device})"
        return repr(value)

    def _get_repr_for_dict(self, context_dict: Context) -> dict[str, str]:
        return {key: self._get_repr_for_value(value) for key, value in context_dict.items()}

    def __repr__(self) -> str:
        contexts_repr = {key: self._get_repr_for_dict(value) for key, value in self.contexts.items()}
        return f"{self.__class__.__name__}(contexts={contexts_repr})"
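The merge semantics of update_contexts above (new keys are inserted, existing contexts are updated in place) in a short sketch:

from imaginairy.vendored.refiners.fluxion.context import ContextProvider

provider = ContextProvider.create({"sampling": {"num_steps": 30}})
provider.update_contexts({"sampling": {"step": 0}, "reshape": {"height": None}})
print(provider.get_context("sampling"))  # {'num_steps': 30, 'step': 0}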
imaginairy/vendored/refiners/fluxion/layers/__init__.py (new file, 110 lines)
@@ -0,0 +1,110 @@
from imaginairy.vendored.refiners.fluxion.layers.activations import GLU, ApproximateGeLU, GeLU, ReLU, Sigmoid, SiLU
from imaginairy.vendored.refiners.fluxion.layers.attentions import Attention, SelfAttention, SelfAttention2d
from imaginairy.vendored.refiners.fluxion.layers.basics import (
    Buffer,
    Chunk,
    Cos,
    Flatten,
    GetArg,
    Identity,
    Multiply,
    Parameter,
    Permute,
    Reshape,
    Sin,
    Slicing,
    Squeeze,
    Transpose,
    Unbind,
    Unflatten,
    Unsqueeze,
    View,
)
from imaginairy.vendored.refiners.fluxion.layers.chain import (
    Breakpoint,
    Chain,
    Concatenate,
    Distribute,
    Lambda,
    Matmul,
    Parallel,
    Passthrough,
    Residual,
    Return,
    SetContext,
    Sum,
    UseContext,
)
from imaginairy.vendored.refiners.fluxion.layers.conv import Conv2d, ConvTranspose2d
from imaginairy.vendored.refiners.fluxion.layers.converter import Converter
from imaginairy.vendored.refiners.fluxion.layers.embedding import Embedding
from imaginairy.vendored.refiners.fluxion.layers.linear import Linear, MultiLinear
from imaginairy.vendored.refiners.fluxion.layers.maxpool import MaxPool1d, MaxPool2d
from imaginairy.vendored.refiners.fluxion.layers.module import ContextModule, Module, WeightedModule
from imaginairy.vendored.refiners.fluxion.layers.norm import GroupNorm, InstanceNorm2d, LayerNorm, LayerNorm2d
from imaginairy.vendored.refiners.fluxion.layers.padding import ReflectionPad2d
from imaginairy.vendored.refiners.fluxion.layers.pixelshuffle import PixelUnshuffle
from imaginairy.vendored.refiners.fluxion.layers.sampling import Downsample, Interpolate, Upsample

__all__ = [
    "Embedding",
    "LayerNorm",
    "GroupNorm",
    "LayerNorm2d",
    "InstanceNorm2d",
    "GeLU",
    "GLU",
    "SiLU",
    "ReLU",
    "ApproximateGeLU",
    "Sigmoid",
    "Attention",
    "SelfAttention",
    "SelfAttention2d",
    "Identity",
    "GetArg",
    "View",
    "Flatten",
    "Unflatten",
    "Transpose",
    "Permute",
    "Squeeze",
    "Unsqueeze",
    "Reshape",
    "Slicing",
    "Parameter",
    "Sin",
    "Cos",
    "Chunk",
    "Multiply",
    "Unbind",
    "Matmul",
    "Buffer",
    "Lambda",
    "Return",
    "Sum",
    "Residual",
    "Chain",
    "UseContext",
    "SetContext",
    "Parallel",
    "Distribute",
    "Passthrough",
    "Breakpoint",
    "Concatenate",
    "Conv2d",
    "ConvTranspose2d",
    "Linear",
    "MultiLinear",
    "Downsample",
    "Upsample",
    "Module",
    "WeightedModule",
    "ContextModule",
    "Interpolate",
    "ReflectionPad2d",
    "PixelUnshuffle",
    "Converter",
    "MaxPool1d",
    "MaxPool2d",
]
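This __init__.py is what makes the `fl` shorthand used throughout the vendored code work; a layer stack can be composed entirely from the re-exported names. A small sketch:

import imaginairy.vendored.refiners.fluxion.layers as fl

mlp = fl.Chain(
    fl.Linear(in_features=16, out_features=32),
    fl.GeLU(),
    fl.Linear(in_features=32, out_features=16),
)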
imaginairy/vendored/refiners/fluxion/layers/activations.py (new file, 77 lines)
@@ -0,0 +1,77 @@
from torch import Tensor, sigmoid
from torch.nn.functional import (
    gelu,  # type: ignore
    silu,
)

from imaginairy.vendored.refiners.fluxion.layers.module import Module


class Activation(Module):
    def __init__(self) -> None:
        super().__init__()


class SiLU(Activation):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return silu(x)  # type: ignore


class ReLU(Activation):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return x.relu()


class GeLU(Activation):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return gelu(x)  # type: ignore


class ApproximateGeLU(Activation):
    """
    The approximate form of Gaussian Error Linear Unit (GELU)
    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return x * sigmoid(1.702 * x)


class Sigmoid(Activation):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return x.sigmoid()


class GLU(Activation):
    """
    Gated Linear Unit activation layer.

    See https://arxiv.org/abs/2002.05202v1 for details.
    """

    def __init__(self, activation: Activation) -> None:
        super().__init__()
        self.activation = activation

    def __repr__(self):
        return f"{self.__class__.__name__}(activation={self.activation})"

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[-1] % 2 == 0, "Non-batch input dimension must be divisible by 2"
        output, gate = x.chunk(2, dim=-1)
        return output * self.activation(gate)
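A quick sketch of GLU's gating from above: the input's last dimension is split in half, one half gates the other, so the output is half the width of the input:

import torch

from imaginairy.vendored.refiners.fluxion.layers.activations import GLU, GeLU

glu = GLU(activation=GeLU())
x = torch.randn(2, 10, 8)  # last dim must be even
y = glu(x)                 # output, gate = x.chunk(2, dim=-1); y = output * gelu(gate)
print(y.shape)             # torch.Size([2, 10, 4])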
imaginairy/vendored/refiners/fluxion/layers/attentions.py (new file, 246 lines)
@@ -0,0 +1,246 @@
import math

import torch
from jaxtyping import Float
from torch import Tensor, device as Device, dtype as DType
from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention  # type: ignore

from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.layers.basics import Identity
from imaginairy.vendored.refiners.fluxion.layers.chain import Chain, Distribute, Lambda, Parallel
from imaginairy.vendored.refiners.fluxion.layers.linear import Linear
from imaginairy.vendored.refiners.fluxion.layers.module import Module


def scaled_dot_product_attention(
    query: Float[Tensor, "batch source_sequence_length dim"],
    key: Float[Tensor, "batch target_sequence_length dim"],
    value: Float[Tensor, "batch target_sequence_length dim"],
    is_causal: bool = False,
) -> Float[Tensor, "batch source_sequence_length dim"]:
    return _scaled_dot_product_attention(query, key, value, is_causal=is_causal)  # type: ignore


def sparse_dot_product_attention_non_optimized(
    query: Float[Tensor, "batch source_sequence_length dim"],
    key: Float[Tensor, "batch target_sequence_length dim"],
    value: Float[Tensor, "batch target_sequence_length dim"],
    is_causal: bool = False,
) -> Float[Tensor, "batch source_sequence_length dim"]:
    if is_causal:
        # TODO: implement causal attention
        raise NotImplementedError("Causal attention for non_optimized attention is not yet implemented")
    _, _, _, dim = query.shape
    attention = query @ key.permute(0, 1, 3, 2)
    attention = attention / math.sqrt(dim)
    attention = torch.softmax(input=attention, dim=-1)
    return attention @ value


class ScaledDotProductAttention(Module):
    def __init__(
        self,
        num_heads: int = 1,
        is_causal: bool | None = None,
        is_optimized: bool = True,
        slice_size: int | None = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.is_causal = is_causal
        self.is_optimized = is_optimized
        self.slice_size = slice_size
        self.dot_product = (
            scaled_dot_product_attention if self.is_optimized else sparse_dot_product_attention_non_optimized
        )

    def forward(
        self,
        query: Float[Tensor, "batch num_queries embedding_dim"],
        key: Float[Tensor, "batch num_keys embedding_dim"],
        value: Float[Tensor, "batch num_values embedding_dim"],
        is_causal: bool | None = None,
    ) -> Float[Tensor, "batch num_queries dim"]:
        if self.slice_size is None:
            return self._process_attention(query, key, value, is_causal)

        return self._sliced_attention(query, key, value, is_causal=is_causal, slice_size=self.slice_size)

    def _sliced_attention(
        self,
        query: Float[Tensor, "batch num_queries embedding_dim"],
        key: Float[Tensor, "batch num_keys embedding_dim"],
        value: Float[Tensor, "batch num_values embedding_dim"],
        slice_size: int,
        is_causal: bool | None = None,
    ) -> Float[Tensor, "batch num_queries dim"]:
        _, num_queries, _ = query.shape
        output = torch.zeros_like(query)
        for start_idx in range(0, num_queries, slice_size):
            end_idx = min(start_idx + slice_size, num_queries)
            output[:, start_idx:end_idx, :] = self._process_attention(
                query[:, start_idx:end_idx, :], key, value, is_causal
            )
        return output

    def _process_attention(
        self,
        query: Float[Tensor, "batch num_queries embedding_dim"],
        key: Float[Tensor, "batch num_keys embedding_dim"],
        value: Float[Tensor, "batch num_values embedding_dim"],
        is_causal: bool | None = None,
    ) -> Float[Tensor, "batch num_queries dim"]:
        return self.merge_multi_head(
            x=self.dot_product(
                query=self.split_to_multi_head(query),
                key=self.split_to_multi_head(key),
                value=self.split_to_multi_head(value),
                is_causal=(
                    is_causal if is_causal is not None else (self.is_causal if self.is_causal is not None else False)
                ),
            )
        )

    def split_to_multi_head(
        self, x: Float[Tensor, "batch_size sequence_length embedding_dim"]
    ) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]:
        assert (
            len(x.shape) == 3
        ), f"Expected tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}"
        assert (
            x.shape[-1] % self.num_heads == 0
        ), f"Embedding dim (x.shape[-1]={x.shape[-1]}) must be divisible by num heads"
        return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2)

    def merge_multi_head(
        self, x: Float[Tensor, "batch_size num_heads sequence_length heads_dim"]
    ) -> Float[Tensor, "batch_size sequence_length heads_dim * num_heads"]:
        return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], self.num_heads * x.shape[-1])


class Attention(Chain):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int = 1,
        key_embedding_dim: int | None = None,
        value_embedding_dim: int | None = None,
        inner_dim: int | None = None,
        use_bias: bool = True,
        is_causal: bool | None = None,
        is_optimized: bool = True,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        assert (
            embedding_dim % num_heads == 0
        ), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.heads_dim = embedding_dim // num_heads
        self.key_embedding_dim = key_embedding_dim or embedding_dim
        self.value_embedding_dim = value_embedding_dim or embedding_dim
        self.inner_dim = inner_dim or embedding_dim
        self.use_bias = use_bias
        self.is_causal = is_causal
        self.is_optimized = is_optimized
        super().__init__(
            Distribute(
                Linear(
                    in_features=self.embedding_dim,
                    out_features=self.inner_dim,
                    bias=self.use_bias,
                    device=device,
                    dtype=dtype,
                ),
                Linear(
                    in_features=self.key_embedding_dim,
                    out_features=self.inner_dim,
                    bias=self.use_bias,
                    device=device,
                    dtype=dtype,
                ),
                Linear(
                    in_features=self.value_embedding_dim,
                    out_features=self.inner_dim,
                    bias=self.use_bias,
                    device=device,
                    dtype=dtype,
                ),
            ),
            ScaledDotProductAttention(num_heads=num_heads, is_causal=is_causal, is_optimized=is_optimized),
            Linear(
                in_features=self.inner_dim,
                out_features=self.embedding_dim,
                bias=True,
                device=device,
                dtype=dtype,
            ),
        )


class SelfAttention(Attention):
    def __init__(
        self,
        embedding_dim: int,
        inner_dim: int | None = None,
        num_heads: int = 1,
        use_bias: bool = True,
        is_causal: bool | None = None,
        is_optimized: bool = True,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(
            embedding_dim=embedding_dim,
            inner_dim=inner_dim,
            num_heads=num_heads,
            use_bias=use_bias,
            is_causal=is_causal,
            is_optimized=is_optimized,
            device=device,
            dtype=dtype,
        )
        self.insert(0, Parallel(Identity(), Identity(), Identity()))


class SelfAttention2d(SelfAttention):
    def __init__(
        self,
        channels: int,
        num_heads: int = 1,
        use_bias: bool = True,
        is_causal: bool | None = None,
        is_optimized: bool = True,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        assert channels % num_heads == 0, f"channels {channels} must be divisible by num_heads {num_heads}"
        self.channels = channels
        super().__init__(
            embedding_dim=channels,
            num_heads=num_heads,
            use_bias=use_bias,
            is_causal=is_causal,
            is_optimized=is_optimized,
            device=device,
            dtype=dtype,
        )
        self.insert(0, Lambda(self.tensor_2d_to_sequence))
        self.append(Lambda(self.sequence_to_tensor_2d))

    def init_context(self) -> Contexts:
        return {"reshape": {"height": None, "width": None}}

    def tensor_2d_to_sequence(
        self, x: Float[Tensor, "batch channels height width"]
    ) -> Float[Tensor, "batch height*width channels"]:
        height, width = x.shape[-2:]
        self.set_context(context="reshape", value={"height": height, "width": width})
        return x.reshape(x.shape[0], x.shape[1], height * width).transpose(1, 2)

    def sequence_to_tensor_2d(
        self, x: Float[Tensor, "batch sequence_length channels"]
    ) -> Float[Tensor, "batch channels height width"]:
        height, width = self.use_context("reshape").values()
        return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], height, width)
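Putting the attention pieces together: SelfAttention prepends a Parallel of three Identity layers so a single input feeds the Q/K/V projections. A minimal sketch:

import torch

from imaginairy.vendored.refiners.fluxion.layers.attentions import SelfAttention

attn = SelfAttention(embedding_dim=64, num_heads=8)
x = torch.randn(2, 16, 64)  # (batch, sequence_length, embedding_dim)
print(attn(x).shape)        # torch.Size([2, 16, 64])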
imaginairy/vendored/refiners/fluxion/layers/basics.py (new file, 207 lines)
@@ -0,0 +1,207 @@
import torch
from torch import Size, Tensor, device as Device, dtype as DType, randn
from torch.nn import Parameter as TorchParameter

from imaginairy.vendored.refiners.fluxion.layers.module import Module, WeightedModule


class Identity(Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return x


class View(Module):
    def __init__(self, *shape: int) -> None:
        super().__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return x.view(*self.shape)


class GetArg(Module):
    def __init__(self, index: int) -> None:
        super().__init__()
        self.index = index

    def forward(self, *args: Tensor) -> Tensor:
        return args[self.index]


class Flatten(Module):
    def __init__(self, start_dim: int = 0, end_dim: int = -1) -> None:
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, x: Tensor) -> Tensor:
        return x.flatten(self.start_dim, self.end_dim)


class Unflatten(Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor, sizes: Size) -> Tensor:
        return x.unflatten(self.dim, sizes)  # type: ignore


class Reshape(Module):
    """
    Reshape the input tensor to the given shape. The shape must be compatible with the input tensor shape. The batch
    dimension is preserved.
    """

    def __init__(self, *shape: int) -> None:
        super().__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return x.reshape(x.shape[0], *self.shape)


class Transpose(Module):
    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: Tensor) -> Tensor:
        return x.transpose(self.dim0, self.dim1)


class Permute(Module):
    def __init__(self, *dims: int) -> None:
        super().__init__()
        self.dims = dims

    def forward(self, x: Tensor) -> Tensor:
        return x.permute(*self.dims)


class Slicing(Module):
    def __init__(self, dim: int = 0, start: int = 0, end: int | None = None, step: int = 1) -> None:
        super().__init__()
        self.dim = dim
        self.start = start
        self.end = end
        self.step = step

    def forward(self, x: Tensor) -> Tensor:
        dim_size = x.shape[self.dim]
        start = self.start if self.start >= 0 else dim_size + self.start
        end = self.end or dim_size
        end = end if end >= 0 else dim_size + end
        start = max(min(start, dim_size), 0)
        end = max(min(end, dim_size), 0)
        if start >= end:
            return self.get_empty_slice(x)
        indices = torch.arange(start=start, end=end, step=self.step, device=x.device)
        return x.index_select(self.dim, indices)

    def get_empty_slice(self, x: Tensor) -> Tensor:
        """
        Return an empty slice of the same shape as the input tensor to mimic PyTorch's slicing behavior.
        """
        shape = list(x.shape)
        shape[self.dim] = 0
        return torch.empty(*shape, device=x.device)


class Squeeze(Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        return x.squeeze(self.dim)


class Unsqueeze(Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        return x.unsqueeze(self.dim)


class Unbind(Module):
    def __init__(self, dim: int = 0) -> None:
        self.dim = dim
        super().__init__()

    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
        return x.unbind(dim=self.dim)  # type: ignore


class Chunk(Module):
    def __init__(self, chunks: int, dim: int = 0) -> None:
        self.chunks = chunks
        self.dim = dim
        super().__init__()

    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
        return x.chunk(chunks=self.chunks, dim=self.dim)  # type: ignore


class Sin(Module):
    def forward(self, x: Tensor) -> Tensor:
        return torch.sin(input=x)


class Cos(Module):
    def forward(self, x: Tensor) -> Tensor:
        return torch.cos(input=x)


class Multiply(Module):
    def __init__(self, scale: float = 1.0, bias: float = 0.0) -> None:
        super().__init__()
        self.scale = scale
        self.bias = bias

    def forward(self, x: Tensor) -> Tensor:
        return self.scale * x + self.bias


class Parameter(WeightedModule):
    """
    A layer that wraps a tensor as a parameter. This is useful to create a parameter that is not a weight or a bias.
    """

    def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__()
        self.dims = dims
        self.weight = TorchParameter(randn(*dims, device=device, dtype=dtype))

    def forward(self, x: Tensor) -> Tensor:
        return self.weight.expand(x.shape[0], *self.dims)


class Buffer(WeightedModule):
    """
    A layer that wraps a tensor as a buffer. This is useful to create a buffer that is not a weight or a bias.

    Buffers are not trainable.
    """

    def __init__(self, *dims: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__()
        self.dims = dims
        self.register_buffer("buffer", randn(*dims, device=device, dtype=dtype))

    @property
    def device(self) -> Device:
        return self.buffer.device

    @property
    def dtype(self) -> DType:
        return self.buffer.dtype

    def forward(self, _: Tensor) -> Tensor:
        return self.buffer
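Slicing above normalizes negative indices and clamps out-of-range bounds before index_select; a small sketch of that behavior:

import torch

from imaginairy.vendored.refiners.fluxion.layers.basics import Slicing

keep_last_two = Slicing(dim=1, start=-2)  # start=-2 resolves to dim_size - 2
x = torch.arange(12).reshape(1, 6, 2)
print(keep_last_two(x).shape)  # torch.Size([1, 2, 2])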
586
imaginairy/vendored/refiners/fluxion/layers/chain.py
Normal file
586
imaginairy/vendored/refiners/fluxion/layers/chain.py
Normal file
@ -0,0 +1,586 @@
|
|||||||
|
import inspect
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any, Callable, Iterable, Iterator, TypeVar, cast, overload
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, cat, device as Device, dtype as DType
|
||||||
|
|
||||||
|
from imaginairy.vendored.refiners.fluxion.context import ContextProvider, Contexts
|
||||||
|
from imaginairy.vendored.refiners.fluxion.layers.module import ContextModule, Module, ModuleTree, WeightedModule
|
||||||
|
from imaginairy.vendored.refiners.fluxion.utils import summarize_tensor
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=Module)
|
||||||
|
TChain = TypeVar("TChain", bound="Chain") # because Self (PEP 673) is not in 3.10
|
||||||
|
|
||||||
|
|
||||||
|
class Lambda(Module):
|
||||||
|
"""Lambda is a wrapper around a callable object that allows it to be used as a PyTorch module."""
|
||||||
|
|
||||||
|
def __init__(self, func: Callable[..., Any]) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.func = func
|
||||||
|
|
||||||
|
def forward(self, *args: Any) -> Any:
|
||||||
|
return self.func(*args)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
func_name = getattr(self.func, "__name__", "partial_function")
|
||||||
|
return f"Lambda({func_name}{str(inspect.signature(self.func))})"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_unique_names(
|
||||||
|
modules: tuple[Module, ...],
|
||||||
|
) -> dict[str, Module]:
|
||||||
|
class_counts: dict[str, int] = {}
|
||||||
|
unique_names: list[tuple[str, Module]] = []
|
||||||
|
for module in modules:
|
||||||
|
class_name = module.__class__.__name__
|
||||||
|
class_counts[class_name] = class_counts.get(class_name, 0) + 1
|
||||||
|
name_counter: dict[str, int] = {}
|
||||||
|
for module in modules:
|
||||||
|
class_name = module.__class__.__name__
|
||||||
|
name_counter[class_name] = name_counter.get(class_name, 0) + 1
|
||||||
|
unique_name = f"{class_name}_{name_counter[class_name]}" if class_counts[class_name] > 1 else class_name
|
||||||
|
unique_names.append((unique_name, module))
|
||||||
|
return dict(unique_names)
|
||||||
|
|
||||||
|
|
||||||
|
class UseContext(ContextModule):
|
||||||
|
def __init__(self, context: str, key: str) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.context = context
|
||||||
|
self.key = key
|
||||||
|
self.func: Callable[[Any], Any] = lambda x: x
|
||||||
|
|
||||||
|
def __call__(self, *args: Any) -> Any:
|
||||||
|
context = self.use_context(self.context)
|
||||||
|
assert context, f"context {self.context} is unset"
|
||||||
|
value = context.get(self.key)
|
||||||
|
assert value is not None, f"context entry {self.context}.{self.key} is unset"
|
||||||
|
return self.func(value)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
|
||||||
|
|
||||||
|
def compose(self, func: Callable[[Any], Any]) -> "UseContext":
|
||||||
|
self.func = func
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class SetContext(ContextModule):
|
||||||
|
"""A Module that sets a context value when executed.
|
||||||
|
|
||||||
|
The context need to pre exist in the context provider.
|
||||||
|
#TODO Is there a way to create the context if it doesn't exist?
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, context: str, key: str, callback: Callable[[Any, Any], Any] | None = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.context = context
|
||||||
|
self.key = key
|
||||||
|
self.callback = callback
|
||||||
|
|
||||||
|
def __call__(self, x: Tensor) -> Tensor:
|
||||||
|
if context := self.use_context(self.context):
|
||||||
|
if not self.callback:
|
||||||
|
context.update({self.key: x})
|
||||||
|
else:
|
||||||
|
self.callback(context[self.key], x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}(context={repr(self.context)}, key={repr(self.key)})"
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnException(Exception):
|
||||||
|
"""Exception raised when a Return module is encountered."""
|
||||||
|
|
||||||
|
def __init__(self, value: Tensor):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
|
||||||
|
class Return(Module):
|
||||||
|
"""A Module that stops the execution of a Chain when encountered."""
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
raise ReturnException(x)
|
||||||
|
|
||||||
|
|
||||||
|
def structural_copy(m: T) -> T:
|
||||||
|
return m.structural_copy() if isinstance(m, ContextModule) else m
|
||||||
|
|
||||||
|
|
||||||
|
class ChainError(RuntimeError):
|
||||||
|
"""Exception raised when an error occurs during the execution of a Chain."""
|
||||||
|
|
||||||
|
def __init__(self, message: str, /) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class Chain(ContextModule):
    _modules: dict[str, Module]
    _provider: ContextProvider
    _tag = "CHAIN"

    def __init__(self, *args: Module | Iterable[Module]) -> None:
        super().__init__()
        self._provider = ContextProvider()
        modules = cast(
            tuple[Module],
            (
                tuple(args[0])
                if len(args) == 1 and isinstance(args[0], Iterable) and not isinstance(args[0], Chain)
                else tuple(args)
            ),
        )

        for module in modules:
            # Violating this would mean a ContextModule ends up in two chains,
            # with a single one correctly set as its parent.
            assert (
                (not isinstance(module, ContextModule))
                or (not module._can_refresh_parent)
                or (module.parent is None)
                or (module.parent == self)
            ), f"{module.__class__.__name__} already has parent {module.parent.__class__.__name__}"

        self._regenerate_keys(modules)
        self._reset_context()

        for module in self:
            if isinstance(module, ContextModule) and module.parent != self:
                module._set_parent(self)

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, torch.nn.Module):
            raise ValueError(
                "Chain does not support setting modules by attribute. Instead, use a mutation method like `append` or"
                " wrap it within a single element list to prevent pytorch from registering it as a submodule."
            )
        super().__setattr__(name, value)

    @property
    def provider(self) -> ContextProvider:
        return self._provider

    def init_context(self) -> Contexts:
        return {}

    def _register_provider(self, context: Contexts | None = None) -> None:
        if context:
            self._provider.update_contexts(context)

        for module in self:
            if isinstance(module, Chain):
                module._register_provider(context=self._provider.contexts)

    def _reset_context(self) -> None:
        self._register_provider(self.init_context())

    def set_context(self, context: str, value: Any) -> None:
        self._provider.set_context(context, value)
        self._register_provider()

    def _show_error_in_tree(self, name: str, /, max_lines: int = 20) -> str:
        tree = ModuleTree(module=self)
        classname_counter: dict[str, int] = defaultdict(int)
        first_ancestor = self.get_parents()[-1] if self.get_parents() else self

        def find_state_dict_key(module: Module, /) -> str | None:
            for key, layer in module.named_modules():
                if layer == self:
                    return ".".join((key, name))
            return None

        for child in tree:
            classname, count = name.rsplit(sep="_", maxsplit=1) if "_" in name else (name, "1")
            if child["class_name"] == classname:
                classname_counter[classname] += 1
                if classname_counter[classname] == int(count):
                    state_dict_key = find_state_dict_key(first_ancestor)
                    child["value"] = f">>> {child['value']} | {state_dict_key}"
                    break

        tree_repr = tree._generate_tree_repr(tree.root, depth=3)  # type: ignore[reportPrivateUsage]

        lines = tree_repr.split(sep="\n")
        error_line_idx = next((idx for idx, line in enumerate(iterable=lines) if line.startswith(">>>")), 0)

        return ModuleTree.shorten_tree_repr(tree_repr, line_index=error_line_idx, max_lines=max_lines)

    @staticmethod
    def _pretty_print_args(*args: Any) -> str:
        """
        Flatten nested tuples and print tensors with their shape and other information.
        """

        def _flatten_tuple(t: Tensor | tuple[Any, ...], /) -> list[Any]:
            if isinstance(t, tuple):
                return [item for subtuple in t for item in _flatten_tuple(subtuple)]
            else:
                return [t]

        flat_args = _flatten_tuple(args)

        return "\n".join(
            [
                f"{idx}: {summarize_tensor(arg) if isinstance(arg, Tensor) else arg}"
                for idx, arg in enumerate(iterable=flat_args)
            ]
        )

    def _filter_traceback(self, *frames: traceback.FrameSummary) -> list[traceback.FrameSummary]:
        patterns_to_exclude = [
            (r"torch/nn/modules/", r"^_call_impl$"),
            (r"torch/nn/functional\.py", r""),
            (r"refiners/fluxion/layers/", r"^_call_layer$"),
            (r"refiners/fluxion/layers/", r"^forward$"),
            (r"refiners/fluxion/layers/chain\.py", r""),
            (r"", r"^_"),
        ]

        def should_exclude(frame: traceback.FrameSummary, /) -> bool:
            for filename_pattern, name_pattern in patterns_to_exclude:
                if re.search(pattern=filename_pattern, string=frame.filename) and re.search(
                    pattern=name_pattern, string=frame.name
                ):
                    return True
            return False

        return [frame for frame in frames if not should_exclude(frame)]

    def _call_layer(self, layer: Module, name: str, /, *args: Any) -> Any:
        try:
            return layer(*args)
        except Exception as e:
            exc_type, _, exc_traceback = sys.exc_info()
            assert exc_type
            tb_list = traceback.extract_tb(tb=exc_traceback)
            filtered_tb_list = self._filter_traceback(*tb_list)
            formatted_tb = "".join(traceback.format_list(extracted_list=filtered_tb_list))
            pretty_args = Chain._pretty_print_args(args)
            error_tree = self._show_error_in_tree(name)

            exception_str = re.sub(pattern=r"\n\s*\n", repl="\n", string=str(object=e))
            message = f"{formatted_tb}\n{exception_str}\n---------------\n{error_tree}\n{pretty_args}"
            if "Error" not in exception_str:
                message = f"{exc_type.__name__}:\n {message}"

            raise ChainError(message) from None

    def forward(self, *args: Any) -> Any:
        result: tuple[Any] | Any = None
        intermediate_args: tuple[Any, ...] = args
        for name, layer in self._modules.items():
            result = self._call_layer(layer, name, *intermediate_args)
            intermediate_args = (result,) if not isinstance(result, tuple) else result

        self._reset_context()
        return result

    def _regenerate_keys(self, modules: Iterable[Module]) -> None:
        self._modules = generate_unique_names(tuple(modules))  # type: ignore

    def __add__(self, other: "Chain | Module | list[Module]") -> "Chain":
        if isinstance(other, Module):
            other = Chain(other)
        if isinstance(other, list):
            other = Chain(*other)
        return Chain(*self, *other)

    @overload
    def __getitem__(self, key: int) -> Module:
        ...

    @overload
    def __getitem__(self, key: str) -> Module:
        ...

    @overload
    def __getitem__(self, key: slice) -> "Chain":
        ...

    def __getitem__(self, key: int | str | slice) -> Module:
        if isinstance(key, slice):
            copy = self.structural_copy()
            copy._regenerate_keys(modules=list(copy)[key])
            return copy
        elif isinstance(key, str):
            return self._modules[key]
        else:
            return list(self)[key]

    def __iter__(self) -> Iterator[Module]:
        return iter(self._modules.values())

    def __len__(self) -> int:
        return len(self._modules)

    @property
    def device(self) -> Device | None:
        wm = self.find(WeightedModule)
        return None if wm is None else wm.device

    @property
    def dtype(self) -> DType | None:
        wm = self.find(WeightedModule)
        return None if wm is None else wm.dtype

    def _walk(
        self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
    ) -> Iterator[tuple[Module, "Chain"]]:
        if predicate is None:
            predicate = lambda _m, _p: True
        for module in self:
            try:
                p = predicate(module, self)
            except StopIteration:
                continue
            if p:
                yield (module, self)
                if not recurse:
                    continue
            if isinstance(module, Chain):
                yield from module.walk(predicate, recurse)

    @overload
    def walk(
        self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
    ) -> Iterator[tuple[Module, "Chain"]]:
        ...

    @overload
    def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]:
        ...

    def walk(
        self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
    ) -> Iterator[tuple[T, "Chain"]] | Iterator[tuple[Module, "Chain"]]:
        if isinstance(predicate, type):
            return self._walk(lambda m, _: isinstance(m, predicate), recurse)
        else:
            return self._walk(predicate, recurse)

    def layers(self, layer_type: type[T], recurse: bool = False) -> Iterator[T]:
        for module, _ in self.walk(layer_type, recurse):
            yield module

    def find(self, layer_type: type[T]) -> T | None:
        return next(self.layers(layer_type=layer_type), None)

    def ensure_find(self, layer_type: type[T]) -> T:
        r = self.find(layer_type)
        assert r is not None, f"could not find {layer_type} in {self}"
        return r

    def find_parent(self, module: Module) -> "Chain | None":
        if module in self:  # avoid DFS-crawling the whole tree
            return self
        for _, parent in self.walk(lambda m, _: m == module):
            return parent
        return None

    def ensure_find_parent(self, module: Module) -> "Chain":
        r = self.find_parent(module)
        assert r is not None, f"could not find {module} in {self}"
        return r

    def insert(self, index: int, module: Module) -> None:
        if index < 0:
            index = max(0, len(self._modules) + index + 1)
        modules = list(self)
        modules.insert(index, module)
        self._regenerate_keys(modules)
        if isinstance(module, ContextModule):
            module._set_parent(self)
        self._register_provider()

    def insert_before_type(self, module_type: type[Module], new_module: Module) -> None:
        for i, module in enumerate(self):
            if isinstance(module, module_type):
                self.insert(i, new_module)
                return
        raise ValueError(f"No module of type {module_type.__name__} found in the chain.")

    def insert_after_type(self, module_type: type[Module], new_module: Module) -> None:
        for i, module in enumerate(self):
            if isinstance(module, module_type):
                self.insert(i + 1, new_module)
                return
        raise ValueError(f"No module of type {module_type.__name__} found in the chain.")

    def append(self, module: Module) -> None:
        self.insert(-1, module)

    def pop(self, index: int = -1) -> Module | tuple[Module]:
        modules = list(self)
        if index < 0:
            index = len(modules) + index
        if index < 0 or index >= len(modules):
            raise IndexError("Index out of range.")
        removed_module = modules.pop(index)
        if isinstance(removed_module, ContextModule):
            removed_module._set_parent(None)
        self._regenerate_keys(modules)
        return removed_module

    def remove(self, module: Module) -> None:
        """Remove a module from the chain."""
        modules = list(self)
        try:
            modules.remove(module)
        except ValueError:
            raise ValueError(f"{module} is not in {self}")
        self._regenerate_keys(modules)
        if isinstance(module, ContextModule):
            module._set_parent(None)

    def replace(
        self,
        old_module: Module,
        new_module: Module,
        old_module_parent: "Chain | None" = None,
    ) -> None:
        """Replace a module in the chain with a new module."""
        modules = list(self)
        try:
            modules[modules.index(old_module)] = new_module
        except ValueError:
            raise ValueError(f"{old_module} is not in {self}")
        self._regenerate_keys(modules)
        if isinstance(new_module, ContextModule):
            new_module._set_parent(self)
        if isinstance(old_module, ContextModule):
            old_module._set_parent(old_module_parent)

    def structural_copy(self: TChain) -> TChain:
        """Copy the structure of the Chain tree.

        This method returns a recursive copy of the Chain tree where all inner nodes
        (instances of Chain and its subclasses) are duplicated and all leaves
        (regular Modules) are not.

        Such copies can be adapted without disrupting the base model, but do not
        require extra GPU memory since the weights are in the leaves and hence not copied.
        """
        if hasattr(self, "_pre_structural_copy"):
            self._pre_structural_copy()

        modules = [structural_copy(m) for m in self]
        clone = super().structural_copy()
        clone._provider = ContextProvider.create(clone.init_context())

        for module in modules:
            clone.append(module=module)

        if hasattr(clone, "_post_structural_copy"):
            clone._post_structural_copy(self)

        return clone

    def _show_only_tag(self) -> bool:
        return self.__class__ == Chain
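
A short illustrative sketch (not part of the vendored file) of the container behaviour defined above, key-based access, slicing via `structural_copy`, and typed traversal; the `fl` alias and toy sizes are assumptions:

```python
import imaginairy.vendored.refiners.fluxion.layers as fl

chain = fl.Chain(
    fl.Linear(in_features=8, out_features=8),
    fl.ReLU(),
    fl.Linear(in_features=8, out_features=2),
)

first = chain[0]                    # positional access
second_linear = chain["Linear_2"]   # duplicated classes get numbered keys
head = chain[:2]                    # slicing returns a structural copy (weights shared)
linears = list(chain.layers(fl.Linear))  # typed iteration, len(linears) == 2
chain.append(fl.ReLU())             # mutation regenerates the key mapping
```
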
class Parallel(Chain):
    _tag = "PAR"

    def forward(self, *args: Any) -> tuple[Tensor, ...]:
        return tuple([self._call_layer(module, name, *args) for name, module in self._modules.items()])

    def _show_only_tag(self) -> bool:
        return self.__class__ == Parallel


class Distribute(Chain):
    _tag = "DISTR"

    def forward(self, *args: Any) -> tuple[Tensor, ...]:
        n, m = len(args), len(self._modules)
        assert n == m, f"Number of positional arguments ({n}) must match number of sub-modules ({m})."
        return tuple([self._call_layer(module, name, arg) for arg, (name, module) in zip(args, self._modules.items())])

    def _show_only_tag(self) -> bool:
        return self.__class__ == Distribute


class Passthrough(Chain):
    _tag = "PASS"

    def forward(self, *inputs: Any) -> Any:
        super().forward(*inputs)
        return inputs

    def _show_only_tag(self) -> bool:
        return self.__class__ == Passthrough


class Sum(Chain):
    _tag = "SUM"

    def forward(self, *inputs: Any) -> Any:
        output = None
        for layer in self:
            layer_output: Any = layer(*inputs)
            if isinstance(layer_output, tuple):
                layer_output = sum(layer_output)  # type: ignore
            output = layer_output if output is None else output + layer_output
        return output

    def _show_only_tag(self) -> bool:
        return self.__class__ == Sum


class Residual(Chain):
    _tag = "RES"

    def forward(self, *inputs: Any) -> Any:
        assert len(inputs) == 1, "Residual connection can only be used with a single input."
        return super().forward(*inputs) + inputs[0]


class Breakpoint(ContextModule):
    def __init__(self, vscode: bool = True):
        super().__init__()
        self.vscode = vscode

    def forward(self, *args: Any):
        if self.vscode:
            import debugpy  # type: ignore

            debugpy.breakpoint()  # type: ignore
        else:
            breakpoint()
        return args[0] if len(args) == 1 else args


class Concatenate(Chain):
    _tag = "CAT"

    def __init__(self, *modules: Module, dim: int = 0) -> None:
        super().__init__(*modules)
        self.dim = dim

    def forward(self, *args: Any) -> Tensor:
        outputs = [module(*args) for module in self]
        return cat([output for output in outputs if output is not None], dim=self.dim)

    def _show_only_tag(self) -> bool:
        return self.__class__ == Concatenate


class Matmul(Chain):
    _tag = "MATMUL"

    def __init__(self, input: Module, other: Module) -> None:
        super().__init__(
            input,
            other,
        )

    def forward(self, *args: Tensor) -> Tensor:
        return torch.matmul(input=self[0](*args), other=self[1](*args))
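
To make the combinators above concrete, a hedged sketch of how `Parallel`, `Distribute`, `Sum`, `Residual` and `Concatenate` route tensors (illustrative only; the `fl` alias and toy shapes are assumptions):

```python
import torch

import imaginairy.vendored.refiners.fluxion.layers as fl

x = torch.randn(2, 4)

both = fl.Parallel(fl.Identity(), fl.Identity())       # same input to every branch
a, b = both(x)                                         # -> tuple of branch outputs

paired = fl.Distribute(fl.Identity(), fl.Identity())   # i-th argument to i-th branch
c, d = paired(x, x)

summed = fl.Sum(fl.Identity(), fl.Identity())(x)       # elementwise sum of branches
res = fl.Residual(fl.Linear(4, 4))(x)                  # Linear(x) + x
cat = fl.Concatenate(fl.Identity(), fl.Identity(), dim=1)(x)  # shape (2, 8)
```
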
96
imaginairy/vendored/refiners/fluxion/layers/conv.py
Normal file
@ -0,0 +1,96 @@
from torch import device as Device, dtype as DType, nn

from imaginairy.vendored.refiners.fluxion.layers.module import WeightedModule


class Conv2d(nn.Conv2d, WeightedModule):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int | tuple[int, int] = (1, 1),
        padding: int | tuple[int, int] | str = (0, 0),
        groups: int = 1,
        use_bias: bool = True,
        dilation: int | tuple[int, int] = (1, 1),
        padding_mode: str = "zeros",
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(  # type: ignore
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            use_bias,
            padding_mode,
            device,
            dtype,
        )
        self.use_bias = use_bias


class Conv1d(nn.Conv1d, WeightedModule):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int],
        stride: int | tuple[int] = 1,
        padding: int | tuple[int] | str = 0,
        groups: int = 1,
        use_bias: bool = True,
        dilation: int | tuple[int] = 1,
        padding_mode: str = "zeros",
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(  # type: ignore
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            use_bias,
            padding_mode,
            device,
            dtype,
        )


class ConvTranspose2d(nn.ConvTranspose2d, WeightedModule):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int | tuple[int, int] = 1,
        padding: int | tuple[int, int] = 0,
        output_padding: int | tuple[int, int] = 0,
        groups: int = 1,
        use_bias: bool = True,
        dilation: int | tuple[int, int] = 1,
        padding_mode: str = "zeros",
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(  # type: ignore
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
            bias=use_bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
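
One design note on the wrappers above: they rename torch's `bias` flag to `use_bias`, and `Conv2d` stores it as a public attribute so it appears in the compact summary that `Module.basic_attributes` produces. A small illustrative check, assuming the vendored package imports as `fl`:

```python
import imaginairy.vendored.refiners.fluxion.layers as fl

conv = fl.Conv2d(in_channels=3, out_channels=8, kernel_size=3, use_bias=False)
assert conv.bias is None  # use_bias maps onto nn.Conv2d's `bias` parameter
print(str(conv))          # one-line summary, including use_bias=False
```
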
45
imaginairy/vendored/refiners/fluxion/layers/converter.py
Normal file
@ -0,0 +1,45 @@
from torch import Tensor

from imaginairy.vendored.refiners.fluxion.layers.module import ContextModule


class Converter(ContextModule):
    """
    A Converter class that adjusts tensor properties based on a parent module's settings.

    This class inherits from `ContextModule` and provides functionality to adjust
    the device and dtype of input tensor(s) to match the parent module's attributes.

    Attributes:
        set_device (bool): If True, matches the device of the input tensor(s) to the parent's device.
        set_dtype (bool): If True, matches the dtype of the input tensor(s) to the parent's dtype.

    Note:
        Ensure the parent module has `device` and `dtype` attributes if `set_device` or `set_dtype` are set to True.
    """

    def __init__(self, set_device: bool = True, set_dtype: bool = True) -> None:
        super().__init__()
        self.set_device = set_device
        self.set_dtype = set_dtype

    def forward(self, *inputs: Tensor) -> tuple[Tensor, ...]:
        parent = self.ensure_parent
        converted_tensors: list[Tensor] = []

        for x in inputs:
            if self.set_device:
                device = parent.device
                assert device is not None, "parent has no device"
                x = x.to(device=device)
            if self.set_dtype:
                dtype = parent.dtype
                assert dtype is not None, "parent has no dtype"
                x = x.to(dtype=dtype)

            converted_tensors.append(x)

        return tuple(converted_tensors)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(set_device={self.set_device}, set_dtype={self.set_dtype})"
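
A sketch of where `Converter` typically sits (illustrative; the half-precision linear and toy sizes are assumptions): placed at the head of a chain, it coerces incoming tensors to whatever device/dtype the chain's weights use.

```python
import torch

import imaginairy.vendored.refiners.fluxion.layers as fl

chain = fl.Chain(
    fl.Converter(set_device=False, set_dtype=True),  # match input dtype to the weights
    fl.Linear(4, 4, dtype=torch.float16),
)
y = chain(torch.randn(2, 4))  # float32 input is cast to float16 before Linear
assert y.dtype == torch.float16
```
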
21
imaginairy/vendored/refiners/fluxion/layers/embedding.py
Normal file
@ -0,0 +1,21 @@
from jaxtyping import Float, Int
from torch import Tensor, device as Device, dtype as DType
from torch.nn import Embedding as _Embedding

from imaginairy.vendored.refiners.fluxion.layers.module import WeightedModule


class Embedding(_Embedding, WeightedModule):  # type: ignore
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ):
        _Embedding.__init__(  # type: ignore
            self, num_embeddings=num_embeddings, embedding_dim=embedding_dim, device=device, dtype=dtype
        )

    def forward(self, x: Int[Tensor, "batch length"]) -> Float[Tensor, "batch length embedding_dim"]:  # type: ignore
        return super().forward(x)
49
imaginairy/vendored/refiners/fluxion/layers/linear.py
Normal file
@ -0,0 +1,49 @@
from jaxtyping import Float
from torch import Tensor, device as Device, dtype as DType
from torch.nn import Linear as _Linear

from imaginairy.vendored.refiners.fluxion.layers.activations import ReLU
from imaginairy.vendored.refiners.fluxion.layers.chain import Chain
from imaginairy.vendored.refiners.fluxion.layers.module import Module, WeightedModule


class Linear(_Linear, WeightedModule):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.in_features = in_features
        self.out_features = out_features
        super().__init__(  # type: ignore
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    def forward(self, x: Float[Tensor, "batch in_features"]) -> Float[Tensor, "batch out_features"]:  # type: ignore
        return super().forward(x)


class MultiLinear(Chain):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        inner_dim: int,
        num_layers: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        layers: list[Module] = []
        for i in range(num_layers - 1):
            layers.append(Linear(input_dim if i == 0 else inner_dim, inner_dim, device=device, dtype=dtype))
            layers.append(ReLU())
        layers.append(Linear(inner_dim, output_dim, device=device, dtype=dtype))

        super().__init__(layers)
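
For `MultiLinear` above, a short illustrative sketch (toy sizes assumed): `num_layers` counts Linear layers, with a `ReLU` after each hidden one.

```python
import torch

import imaginairy.vendored.refiners.fluxion.layers as fl

mlp = fl.MultiLinear(input_dim=16, output_dim=2, inner_dim=32, num_layers=3)
# Structure: Linear(16->32), ReLU, Linear(32->32), ReLU, Linear(32->2)
out = mlp(torch.randn(5, 16))
assert out.shape == (5, 2)
```
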
43
imaginairy/vendored/refiners/fluxion/layers/maxpool.py
Normal file
@ -0,0 +1,43 @@
from torch import nn

from imaginairy.vendored.refiners.fluxion.layers.module import Module


class MaxPool1d(nn.MaxPool1d, Module):
    def __init__(
        self,
        kernel_size: int,
        stride: int | None = None,
        padding: int = 0,
        dilation: int = 1,
        return_indices: bool = False,
        ceil_mode: bool = False,
    ) -> None:
        super().__init__(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            return_indices=return_indices,
            ceil_mode=ceil_mode,
        )


class MaxPool2d(nn.MaxPool2d, Module):
    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        stride: int | tuple[int, int] | None = None,
        padding: int | tuple[int, int] = (0, 0),
        dilation: int | tuple[int, int] = (1, 1),
        return_indices: bool = False,
        ceil_mode: bool = False,
    ) -> None:
        super().__init__(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,  # type: ignore
            dilation=dilation,
            return_indices=return_indices,
            ceil_mode=ceil_mode,
        )
264
imaginairy/vendored/refiners/fluxion/layers/module.py
Normal file
@ -0,0 +1,264 @@
import sys
from collections import defaultdict
from inspect import Parameter, signature
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any, DefaultDict, Generator, Sequence, TypedDict, TypeVar, cast

from torch import device as Device, dtype as DType
from torch.nn.modules.module import Module as TorchModule

from imaginairy.vendored.refiners.fluxion.context import Context, ContextProvider
from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors

if TYPE_CHECKING:
    from imaginairy.vendored.refiners.fluxion.layers.chain import Chain

T = TypeVar("T", bound="Module")
TContextModule = TypeVar("TContextModule", bound="ContextModule")
BasicType = str | float | int | bool


class Module(TorchModule):
    _parameters: dict[str, Any]
    _buffers: dict[str, Any]
    _tag: str = ""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)  # type: ignore[reportUnknownMemberType]

    def __getattr__(self, name: str) -> Any:
        return super().__getattr__(name=name)

    def __setattr__(self, name: str, value: Any) -> None:
        return super().__setattr__(name=name, value=value)

    def load_from_safetensors(self, tensors_path: str | Path, strict: bool = True) -> "Module":
        state_dict = load_from_safetensors(tensors_path)
        self.load_state_dict(state_dict, strict=strict)
        return self

    def named_modules(self, *args: Any, **kwargs: Any) -> "Generator[tuple[str, Module], None, None]":  # type: ignore
        return super().named_modules(*args)  # type: ignore

    def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T:  # type: ignore
        return super().to(device=device, dtype=dtype)  # type: ignore

    def __str__(self) -> str:
        basic_attributes_str = ", ".join(
            f"{key}={value}" for key, value in self.basic_attributes(init_attrs_only=True).items()
        )
        result = f"{self.__class__.__name__}({basic_attributes_str})"
        return result

    def __repr__(self) -> str:
        tree = ModuleTree(module=self)
        return repr(tree)

    def pretty_print(self, depth: int = -1) -> None:
        tree = ModuleTree(module=self)
        print(tree._generate_tree_repr(tree.root, is_root=True, depth=depth))  # type: ignore[reportPrivateUsage]

    def basic_attributes(self, init_attrs_only: bool = False) -> dict[str, BasicType]:
        """Return a dictionary of basic attributes of the module.

        Basic attributes are public attributes made of basic types (int, float, str, bool) or a sequence of basic types.
        """
        sig = signature(obj=self.__init__)
        init_params = set(sig.parameters.keys()) - {"self"}
        default_values = {k: v.default for k, v in sig.parameters.items() if v.default is not Parameter.empty}

        def is_basic_attribute(key: str, value: Any) -> bool:
            if key.startswith("_"):
                return False

            if isinstance(value, BasicType):
                return True

            if isinstance(value, Sequence) and all(isinstance(y, BasicType) for y in cast(Sequence[Any], value)):
                return True

            return False

        return {
            key: str(object=value)
            for key, value in self.__dict__.items()
            if is_basic_attribute(key=key, value=value)
            and (not init_attrs_only or (key in init_params and value != default_values.get(key)))
        }

    def _show_only_tag(self) -> bool:
        """Whether to show only the tag when printing the module.

        This is useful to distinguish between Chain subclasses that override their forward from one another.
        """
        return False


class ContextModule(Module):
    # we store parent into a one element list to avoid pytorch thinking it's a submodule
    _parent: "list[Chain]"
    _can_refresh_parent: bool = True  # see usage in Adapter and Chain

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._parent = []

    @property
    def parent(self) -> "Chain | None":
        return self._parent[0] if self._parent else None

    @property
    def ensure_parent(self) -> "Chain":
        assert self._parent, "module is not bound to a Chain"
        return self._parent[0]

    def _set_parent(self, parent: "Chain | None") -> None:
        if not self._can_refresh_parent:
            return
        if parent is None:
            self._parent = []
            return
        # Always insert the module in the Chain first to avoid inconsistencies.
        assert self in iter(parent), f"{self} not in {parent}"
        self._parent = [parent]

    @property
    def provider(self) -> ContextProvider:
        return self.ensure_parent.provider

    def get_parents(self) -> "list[Chain]":
        return self._parent + self._parent[0].get_parents() if self._parent else []

    def use_context(self, context_name: str) -> Context:
        """Retrieve the context object from the module's context provider."""
        context = self.provider.get_context(context_name)
        assert context is not None, f"Context {context_name} not found."
        return context

    def structural_copy(self: TContextModule) -> TContextModule:
        clone = object.__new__(self.__class__)

        not_torch_attributes = [
            key
            for key, value in self.__dict__.items()
            if not key.startswith("_")
            and isinstance(sys.modules.get(type(value).__module__), ModuleType)
            and "torch" not in sys.modules[type(value).__module__].__name__
        ]

        for k in not_torch_attributes:
            setattr(clone, k, getattr(self, k))

        ContextModule.__init__(self=clone)

        return clone


class WeightedModule(Module):
    @property
    def device(self) -> Device:
        return self.weight.device

    @property
    def dtype(self) -> DType:
        return self.weight.dtype


class TreeNode(TypedDict):
    value: str
    class_name: str
    children: list["TreeNode"]


class ModuleTree:
    def __init__(self, module: Module) -> None:
        self.root: TreeNode = self._module_to_tree(module=module)
        self._fold_successive_identical(node=self.root)

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(root={self.root['value']})"

    def __repr__(self) -> str:
        return self._generate_tree_repr(self.root, is_root=True, depth=7)

    def __iter__(self) -> Generator[TreeNode, None, None]:
        for child in self.root["children"]:
            yield child

    @classmethod
    def shorten_tree_repr(cls, tree_repr: str, /, line_index: int = 0, max_lines: int = 20) -> str:
        """Shorten the tree representation to a given number of lines around a given line index."""
        lines = tree_repr.split(sep="\n")
        start_idx = max(0, line_index - max_lines // 2)
        end_idx = min(len(lines), line_index + max_lines // 2 + 1)
        return "\n".join(lines[start_idx:end_idx])

    def _generate_tree_repr(
        self, node: TreeNode, /, *, prefix: str = "", is_last: bool = True, is_root: bool = True, depth: int = -1
    ) -> str:
        if depth == 0 and node["children"]:
            return f"{prefix}{'└── ' if is_last else '├── '}{node['value']} ..."

        if depth > 0:
            depth -= 1

        tree_icon: str = "" if is_root else ("└── " if is_last else "├── ")
        counts: DefaultDict[str, int] = defaultdict(int)

        for child in node["children"]:
            counts[child["class_name"]] += 1

        instance_counts: DefaultDict[str, int] = defaultdict(int)
        lines = [f"{prefix}{tree_icon}{node['value']}"]
        new_prefix: str = "    " if is_last else "│   "

        for i, child in enumerate(iterable=node["children"]):
            instance_counts[child["class_name"]] += 1

            if counts[child["class_name"]] > 1:
                child_value = f"{child['value']} #{instance_counts[child['class_name']]}"
            else:
                child_value = child["value"]

            child_str = self._generate_tree_repr(
                {"value": child_value, "class_name": child["class_name"], "children": child["children"]},
                prefix=prefix + new_prefix,
                is_last=i == len(node["children"]) - 1,
                is_root=False,
                depth=depth,
            )

            if child_str:
                lines.append(child_str)

        return "\n".join(lines)

    def _module_to_tree(self, module: Module) -> TreeNode:
        match (module._tag, module._show_only_tag()):  # pyright: ignore[reportPrivateUsage]
            case ("", False):
                value = str(module)
            case (_, True):
                value = f"({module._tag})"  # pyright: ignore[reportPrivateUsage]
            case (_, False):
                value = f"({module._tag}) {module}"  # pyright: ignore[reportPrivateUsage]

        class_name = module.__class__.__name__

        node: TreeNode = {"value": value, "class_name": class_name, "children": []}
        for child in module.children():
            node["children"].append(self._module_to_tree(module=child))  # type: ignore
        return node

    def _fold_successive_identical(self, node: TreeNode) -> None:
        i = 0
        while i < len(node["children"]):
            j = i
            while j < len(node["children"]) and node["children"][i] == node["children"][j]:
                j += 1
            count = j - i
            if count > 1:
                node["children"][i]["value"] += f" (x{count})"
                del node["children"][i + 1 : j]
            self._fold_successive_identical(node=node["children"][i])
            i += 1
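
To illustrate the plumbing defined above (a sketch only; the `fl` alias is an assumption): a `ContextModule` learns its parent when placed in a `Chain`, and `ModuleTree` powers the folded tree views.

```python
import imaginairy.vendored.refiners.fluxion.layers as fl

inner = fl.Chain(fl.Identity(), fl.Identity())
outer = fl.Chain(inner, fl.Identity())

assert inner.parent is outer            # stored in a one-element list, not a submodule
assert inner.get_parents() == [outer]   # ancestor chain, innermost first
outer.pretty_print(depth=2)             # tree view; identical siblings fold as "(x2)"
```
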
88
imaginairy/vendored/refiners/fluxion/layers/norm.py
Normal file
@ -0,0 +1,88 @@
from jaxtyping import Float
from torch import Tensor, device as Device, dtype as DType, nn, ones, sqrt, zeros

from imaginairy.vendored.refiners.fluxion.layers.module import Module, WeightedModule


class LayerNorm(nn.LayerNorm, WeightedModule):
    def __init__(
        self,
        normalized_shape: int | list[int],
        eps: float = 0.00001,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(  # type: ignore
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=True,  # otherwise not a WeightedModule
            device=device,
            dtype=dtype,
        )


class GroupNorm(nn.GroupNorm, WeightedModule):
    def __init__(
        self,
        channels: int,
        num_groups: int,
        eps: float = 1e-5,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(  # type: ignore
            num_groups=num_groups,
            num_channels=channels,
            eps=eps,
            affine=True,  # otherwise not a WeightedModule
            device=device,
            dtype=dtype,
        )
        self.channels = channels
        self.num_groups = num_groups
        self.eps = eps


class LayerNorm2d(WeightedModule):
    """
    2D Layer Normalization module.

    Parameters:
        channels (int): Number of channels in the input tensor.
        eps (float, optional): A small constant for numerical stability. Default: 1e-6.
    """

    def __init__(
        self,
        channels: int,
        eps: float = 1e-6,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(ones(channels, device=device, dtype=dtype))
        self.bias = nn.Parameter(zeros(channels, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
        x_mean = x.mean(1, keepdim=True)
        x_var = (x - x_mean).pow(2).mean(1, keepdim=True)
        x_norm = (x - x_mean) / sqrt(x_var + self.eps)
        x_out = self.weight.unsqueeze(-1).unsqueeze(-1) * x_norm + self.bias.unsqueeze(-1).unsqueeze(-1)
        return x_out


class InstanceNorm2d(nn.InstanceNorm2d, Module):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-05,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(  # type: ignore
            num_features=num_features,
            eps=eps,
            device=device,
            dtype=dtype,
        )
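
A quick numerical sanity sketch for `LayerNorm2d` above (illustrative, not part of the diff): it normalizes over the channel dimension at each spatial location, which matches `F.layer_norm` applied with channels moved to the last axis.

```python
import torch
import torch.nn.functional as F

import imaginairy.vendored.refiners.fluxion.layers as fl

ln = fl.LayerNorm2d(channels=8, eps=1e-6)
x = torch.randn(2, 8, 4, 4)

# Reference: layer_norm over the (now last) channel axis, then permute back.
expected = F.layer_norm(
    x.permute(0, 2, 3, 1), normalized_shape=(8,), weight=ln.weight, bias=ln.bias, eps=1e-6
).permute(0, 3, 1, 2)
assert torch.allclose(ln(x), expected, atol=1e-5)
```
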
8
imaginairy/vendored/refiners/fluxion/layers/padding.py
Normal file
@ -0,0 +1,8 @@
from torch import nn

from imaginairy.vendored.refiners.fluxion.layers.module import Module


class ReflectionPad2d(nn.ReflectionPad2d, Module):
    def __init__(self, padding: int) -> None:
        super().__init__(padding=padding)
@ -0,0 +1,8 @@
from torch.nn import PixelUnshuffle as _PixelUnshuffle

from imaginairy.vendored.refiners.fluxion.layers.module import Module


class PixelUnshuffle(_PixelUnshuffle, Module):
    def __init__(self, downscale_factor: int):
        _PixelUnshuffle.__init__(self, downscale_factor=downscale_factor)
99
imaginairy/vendored/refiners/fluxion/layers/sampling.py
Normal file
@ -0,0 +1,99 @@
from typing import Callable

from torch import Size, Tensor, device as Device, dtype as DType
from torch.nn.functional import pad

from imaginairy.vendored.refiners.fluxion.layers.basics import Identity
from imaginairy.vendored.refiners.fluxion.layers.chain import Chain, Lambda, Parallel, SetContext, UseContext
from imaginairy.vendored.refiners.fluxion.layers.conv import Conv2d
from imaginairy.vendored.refiners.fluxion.layers.module import Module
from imaginairy.vendored.refiners.fluxion.utils import interpolate


class Downsample(Chain):
    def __init__(
        self,
        channels: int,
        scale_factor: int,
        padding: int = 0,
        register_shape: bool = True,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ):
        """Downsamples the input by the given scale factor.

        If register_shape is True, the input shape is registered in the "sampling" context. It will throw an error if
        that context is not set or if it does not contain a list under the "shapes" key.
        """
        self.channels = channels
        self.in_channels = channels
        self.out_channels = channels
        self.scale_factor = scale_factor
        self.padding = padding
        super().__init__(
            Conv2d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=3,
                stride=scale_factor,
                padding=padding,
                device=device,
                dtype=dtype,
            ),
        )
        if padding == 0:
            zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
            self.insert(0, Lambda(zero_pad))
        if register_shape:
            self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape))

    def register_shape(self, shapes: list[Size], x: Tensor) -> None:
        shapes.append(x.shape[2:])


class Interpolate(Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, shape: Size) -> Tensor:
        return interpolate(x, shape)


class Upsample(Chain):
    def __init__(
        self,
        channels: int,
        upsample_factor: int | None = None,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ):
        """Upsamples the input by the given scale factor.

        If upsample_factor is None, the target shape is taken from the "sampling" context. It will throw an error if
        that context is not set or if its shape list is empty (in that case you should use the shape-registering
        version of Downsample).
        """
        self.channels = channels
        self.upsample_factor = upsample_factor
        super().__init__(
            Parallel(
                Identity(),
                (
                    Lambda(self._get_static_shape)
                    if upsample_factor is not None
                    else UseContext(context="sampling", key="shapes").compose(lambda x: x.pop())
                ),
            ),
            Interpolate(),
            Conv2d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=3,
                padding=1,
                device=device,
                dtype=dtype,
            ),
        )

    def _get_static_shape(self, x: Tensor) -> Size:
        assert self.upsample_factor is not None
        return Size([size * self.upsample_factor for size in x.shape[2:]])
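
To tie the two classes above together, an illustrative sketch (the `fl` alias and toy shape are assumptions): `Downsample(register_shape=True)` pushes each input's spatial shape onto the `sampling.shapes` list, and a factor-less `Upsample` pops it back, so odd sizes are restored exactly.

```python
import torch

import imaginairy.vendored.refiners.fluxion.layers as fl

chain = fl.Chain(
    fl.Downsample(channels=4, scale_factor=2, register_shape=True),
    fl.Upsample(channels=4),  # no upsample_factor: reads the stored shape back
)
chain.set_context("sampling", {"shapes": []})  # the context must pre-exist

x = torch.randn(1, 4, 7, 7)   # odd size: 7 -> downsampled -> back to exactly 7
assert chain(x).shape == (1, 4, 7, 7)
```
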
644
imaginairy/vendored/refiners/fluxion/model_converter.py
Normal file
@ -0,0 +1,644 @@
from collections import defaultdict
from enum import Enum, auto
from pathlib import Path
from typing import Any, DefaultDict, TypedDict

import torch
from torch import Tensor, nn
from torch.utils.hooks import RemovableHandle

from imaginairy.vendored.refiners.fluxion.utils import no_grad, norm, save_to_safetensors

TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.ConvTranspose1d,
    nn.ConvTranspose2d,
    nn.ConvTranspose3d,
    nn.Linear,
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.LayerNorm,
    nn.GroupNorm,
    nn.Embedding,
    nn.MaxPool2d,
    nn.AvgPool2d,
    nn.AdaptiveAvgPool2d,
]


ModelTypeShape = tuple[type[nn.Module], tuple[torch.Size, ...]]


class ModuleArgsDict(TypedDict):
    """Represents positional and keyword arguments passed to a module.

    - `positional`: A tuple of positional arguments.
    - `keyword`: A dictionary of keyword arguments.
    """

    positional: tuple[Any, ...]
    keyword: dict[str, Any]


class ConversionStage(Enum):
    """Represents the current stage of the conversion process.

    - `INIT`: The conversion process has not started.
    - `BASIC_LAYERS_MATCH`: The source and target models have the same number of basic layers.
    - `SHAPE_AND_LAYERS_MATCH`: The shapes and layers of the source and target models match.
    - `MODELS_OUTPUT_AGREE`: The outputs of the source and target models agree.
    """

    INIT = auto()
    BASIC_LAYERS_MATCH = auto()
    SHAPE_AND_LAYERS_MATCH = auto()
    MODELS_OUTPUT_AGREE = auto()


class ModelConverter:
    ModuleArgs = tuple[Any, ...] | dict[str, Any] | ModuleArgsDict
    stage: ConversionStage = ConversionStage.INIT
    _stored_mapping: dict[str, str] | None = None

    def __init__(
        self,
        source_model: nn.Module,
        target_model: nn.Module,
        source_keys_to_skip: list[str] | None = None,
        target_keys_to_skip: list[str] | None = None,
        custom_layer_mapping: dict[type[nn.Module], type[nn.Module]] | None = None,
        threshold: float = 1e-5,
        skip_output_check: bool = False,
        skip_init_check: bool = False,
        verbose: bool = True,
    ) -> None:
        """
        Create a ModelConverter.

        - `source_model`: The model to convert from.
        - `target_model`: The model to convert to.
        - `source_keys_to_skip`: A list of keys to skip when tracing the source model.
        - `target_keys_to_skip`: A list of keys to skip when tracing the target model.
        - `custom_layer_mapping`: A dictionary mapping custom layer types between the source and target models.
        - `threshold`: The threshold for comparing outputs between the source and target models.
        - `skip_output_check`: Whether to skip comparing the outputs of the source and target models.
        - `skip_init_check`: Whether to skip checking that the source and target models have the same number of basic
            layers.
        - `verbose`: Whether to print messages during the conversion process.

        The conversion process consists of four stages:

        1. Verify that the source and target models have the same number of basic layers.
        2. Find matching shapes and layers between the source and target models.
        3. Convert the source model's state_dict to match the target model's state_dict.
        4. Compare the outputs of the source and target models.

        The conversion process can be run multiple times, and will resume from the last stage.

        ### Example:
        ```
        converter = ModelConverter(source_model=source, target_model=target, threshold=0.1, verbose=False)
        is_converted = converter(args)
        if is_converted:
            converter.save_to_safetensors(path="test.pt")
        ```
        """
        self.source_model = source_model
        self.target_model = target_model
        self.source_keys_to_skip = source_keys_to_skip or []
        self.target_keys_to_skip = target_keys_to_skip or []
        self.custom_layer_mapping = custom_layer_mapping or {}
        self.threshold = threshold
        self.skip_output_check = skip_output_check
        self.skip_init_check = skip_init_check
        self.verbose = verbose

    def __repr__(self) -> str:
        return (
            f"ModelConverter(source_model={self.source_model.__class__.__name__},"
            f" target_model={self.target_model.__class__.__name__}, stage={self.stage})"
        )

    def __bool__(self) -> bool:
        return self.stage.value >= 2 if self.skip_output_check else self.stage.value >= 3

    def run(self, source_args: ModuleArgs, target_args: ModuleArgs | None = None) -> bool:
        """
        Run the conversion process.

        - `source_args`: The arguments to pass to the source model; it can be either a tuple of positional arguments,
            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
            is not provided, these arguments will also be passed to the target model.
        - `target_args`: The arguments to pass to the target model; it can be either a tuple of positional arguments,
            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.

        ### Returns:

        - `True` if the conversion process is done and the models agree.

        The conversion process consists of four stages:

        1. Verify that the source and target models have the same number of basic layers.
        2. Find matching shapes and layers between the source and target models.
        3. Convert the source model's state_dict to match the target model's state_dict.
        4. Compare the outputs of the source and target models.

        The conversion process can be run multiple times, and will resume from the last stage.
        """
        if target_args is None:
            target_args = source_args

        match self.stage:
            case ConversionStage.MODELS_OUTPUT_AGREE:
                self._increment_stage()
                return True

            case ConversionStage.SHAPE_AND_LAYERS_MATCH if self._run_shape_and_layers_match_stage(
                source_args=source_args, target_args=target_args
            ):
                self._increment_stage()
                return True

            case ConversionStage.BASIC_LAYERS_MATCH if self._run_basic_layers_match_stage(
                source_args=source_args, target_args=target_args
            ):
                self._increment_stage()
                return self.run(source_args=source_args, target_args=target_args)

            case ConversionStage.INIT if self._run_init_stage():
                self._increment_stage()
                return self.run(source_args=source_args, target_args=target_args)

            case _:
                self._log(message=f"Conversion failed at stage {self.stage.value}")
                return False
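
Since the three `ModuleArgs` shapes accepted by `run` are easy to mix up, a hedged sketch follows; the `source`/`target` models, the `skip_output_check` setting, and `x` are hypothetical names for illustration, not from the upstream source.

```python
import torch
from torch import nn

from imaginairy.vendored.refiners.fluxion.model_converter import ModelConverter

# Hypothetical single-layer models; the keyword name `input` is nn.Linear's.
source, target = nn.Linear(8, 8), nn.Linear(8, 8)
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True)
x = torch.randn(1, 8)

converter.run(source_args=(x,))                                  # tuple -> positional args
converter.run(source_args={"input": x})                          # plain dict -> keyword args
converter.run(source_args={"positional": (x,), "keyword": {}})   # explicit ModuleArgsDict
```
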
    def _increment_stage(self) -> None:
        """Increment the stage of the conversion process."""
        match self.stage:
            case ConversionStage.INIT:
                self.stage = ConversionStage.BASIC_LAYERS_MATCH
                self._log(
                    message=(
                        "Stage 0 -> 1 - Models have the same number of basic layers. Finding matching shapes and"
                        " layers..."
                    )
                )
            case ConversionStage.BASIC_LAYERS_MATCH:
                self.stage = ConversionStage.SHAPE_AND_LAYERS_MATCH
                self._log(
                    message=(
                        "Stage 1 -> 2 - Shape of both models agree. Applying state_dict to target model. Comparing"
                        " models..."
                    )
                )

            case ConversionStage.SHAPE_AND_LAYERS_MATCH:
                if self.skip_output_check:
                    self._log(
                        message=(
                            "Stage 2 - Nothing to do. Skipping output check. If you want to compare the outputs, set"
                            " `skip_output_check` to `False`"
                        )
                    )
                else:
                    self.stage = ConversionStage.MODELS_OUTPUT_AGREE
                    self._log(
                        message=(
                            "Stage 2 -> 3 - Conversion is done and source and target models agree: you can export the"
                            " converted model using `save_to_safetensors`"
                        )
                    )
            case ConversionStage.MODELS_OUTPUT_AGREE:
                self._log(
                    message=(
                        "Stage 3 - Nothing to do. Conversion is done and source and target models agree: you can export"
                        " the converted model using `save_to_safetensors`"
                    )
                )

    def get_state_dict(self) -> dict[str, Tensor]:
        """Get the converted state_dict."""
        if not self:
            raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
        return self.target_model.state_dict()

    def get_mapping(self) -> dict[str, str]:
        """Get the mapping between the source and target models' state_dicts."""
        if not self:
            raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
        assert self._stored_mapping is not None, "Mapping is not stored"
        return self._stored_mapping

    def save_to_safetensors(self, path: Path | str, metadata: dict[str, str] | None = None, half: bool = False) -> None:
        """Save the converted model to a SafeTensors file.

        This method can only be called after the conversion process is done.

        - `path`: The path to save the converted model to.
        - `metadata`: Metadata to save with the converted model.
        - `half`: Whether to save the converted model as half precision.

        ### Raises:
        - `ValueError` if the conversion process is not done yet. Run `converter(args)` first.
        """
        if not self:
            raise ValueError("The conversion process is not done yet. Run `converter(args)` first.")
        state_dict = self.get_state_dict()
        if half:
            state_dict = {key: value.half() for key, value in state_dict.items()}
        save_to_safetensors(path=path, tensors=state_dict, metadata=metadata)

    def map_state_dicts(
        self,
        source_args: ModuleArgs,
        target_args: ModuleArgs | None = None,
    ) -> dict[str, str] | None:
        """
        Find a mapping between the source and target models' state_dicts.

        - `source_args`: The arguments to pass to the source model; it can be either a tuple of positional arguments,
            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
            is not provided, these arguments will also be passed to the target model.
        - `target_args`: The arguments to pass to the target model; it can be either a tuple of positional arguments,
            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.

        ### Returns:
        - A dictionary mapping keys in the target model's state_dict to keys in the source model's state_dict.
        """
        if target_args is None:
            target_args = source_args

        source_order = self._trace_module_execution_order(
            module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
        )
        target_order = self._trace_module_execution_order(
            module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
        )

        if not self._assert_shapes_aligned(source_order=source_order, target_order=target_order):
            return None

        mapping: dict[str, str] = {}
        for source_type_shape in source_order:
            source_keys = source_order[source_type_shape]
            target_type_shape = source_type_shape
            if not self._is_torch_basic_layer(module_type=source_type_shape[0]):
                for source_custom_type, target_custom_type in self.custom_layer_mapping.items():
                    if source_custom_type == source_type_shape[0]:
                        target_type_shape = (target_custom_type, source_type_shape[1])
                        break

            target_keys = target_order[target_type_shape]
            mapping.update(zip(target_keys, source_keys))

        return mapping

    def compare_models(
        self,
        source_args: ModuleArgs,
        target_args: ModuleArgs | None = None,
        threshold: float = 1e-5,
    ) -> bool:
        """
        Compare the outputs of the source and target models.

        - `source_args`: The arguments to pass to the source model; it can be either a tuple of positional arguments,
            a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys. If `target_args`
|
||||||
|
is not provided, these arguments will also be passed to the target model.
|
||||||
|
- `target_args`: The arguments to pass to the target model it can be either a tuple of positional arguments,
|
||||||
|
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||||
|
- `threshold`: The threshold for comparing outputs between the source and target models.
|
||||||
|
"""
|
||||||
|
if target_args is None:
|
||||||
|
target_args = source_args
|
||||||
|
|
||||||
|
source_outputs = self._collect_layers_outputs(
|
||||||
|
module=self.source_model, args=source_args, keys_to_skip=self.source_keys_to_skip
|
||||||
|
)
|
||||||
|
target_outputs = self._collect_layers_outputs(
|
||||||
|
module=self.target_model, args=target_args, keys_to_skip=self.target_keys_to_skip
|
||||||
|
)
|
||||||
|
|
||||||
|
diff, prev_source_key, prev_target_key = None, None, None
|
||||||
|
for (source_key, source_output), (target_key, target_output) in zip(source_outputs, target_outputs):
|
||||||
|
diff = norm(source_output - target_output.reshape(shape=source_output.shape)).item()
|
||||||
|
if diff > threshold:
|
||||||
|
self._log(
|
||||||
|
f"Models diverged between {prev_source_key} and {source_key}, and between {prev_target_key} and"
|
||||||
|
f" {target_key}, difference in norm: {diff}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
prev_source_key, prev_target_key = source_key, target_key
|
||||||
|
|
||||||
|
self._log(message=f"Models agree. Difference in norm: {diff}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _run_init_stage(self) -> bool:
|
||||||
|
"""Run the init stage of the conversion process."""
|
||||||
|
if self.skip_init_check:
|
||||||
|
self._log(
|
||||||
|
message=(
|
||||||
|
"Skipping init check. If you want to check the number of basic layers, set `skip_init_check` to"
|
||||||
|
" `False`"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
is_count_correct = self._verify_basic_layers_count()
|
||||||
|
is_not_missing_layers = self._verify_missing_basic_layers()
|
||||||
|
|
||||||
|
return is_count_correct and is_not_missing_layers
|
||||||
|
|
||||||
|
def _run_basic_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
|
||||||
|
"""Run the basic layers match stage of the conversion process."""
|
||||||
|
mapping = self.map_state_dicts(source_args=source_args, target_args=target_args)
|
||||||
|
self._stored_mapping = mapping
|
||||||
|
if mapping is None:
|
||||||
|
self._log(message="Models do not have matching shapes.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
source_state_dict = self.source_model.state_dict()
|
||||||
|
target_state_dict = self.target_model.state_dict()
|
||||||
|
converted_state_dict = self._convert_state_dict(
|
||||||
|
source_state_dict=source_state_dict, target_state_dict=target_state_dict, state_dict_mapping=mapping
|
||||||
|
)
|
||||||
|
self.target_model.load_state_dict(state_dict=converted_state_dict)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _run_shape_and_layers_match_stage(self, source_args: ModuleArgs, target_args: ModuleArgs | None) -> bool:
|
||||||
|
"""Run the shape and layers match stage of the conversion process."""
|
||||||
|
if self.skip_output_check:
|
||||||
|
self._log(
|
||||||
|
message="Skipping output check. If you want to compare the outputs, set `skip_output_check` to `False`"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.compare_models(source_args=source_args, target_args=target_args, threshold=self.threshold):
|
||||||
|
self._log(message="Models agree. You can export the converted model using `save_to_safetensors`")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self._log(message="Models do not agree. Try to increase the threshold or modify the models.")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
self._log(message=f"An error occurred while comparing the models: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _log(self, message: str) -> None:
|
||||||
|
"""Print a message if `verbose` is `True`."""
|
||||||
|
if self.verbose:
|
||||||
|
print(message)
|
||||||
|
|
||||||
|
def _debug_print_shapes(
|
||||||
|
self,
|
||||||
|
shape: ModelTypeShape,
|
||||||
|
source_keys: list[str],
|
||||||
|
target_keys: list[str],
|
||||||
|
) -> None:
|
||||||
|
"""Print the shapes of the sub-modules in `source_keys` and `target_keys`."""
|
||||||
|
self._log(message=f"{shape}")
|
||||||
|
max_len = max(len(source_keys), len(target_keys))
|
||||||
|
for i in range(max_len):
|
||||||
|
source_key = source_keys[i] if i < len(source_keys) else "---"
|
||||||
|
target_key = target_keys[i] if i < len(target_keys) else "---"
|
||||||
|
self._log(f"\t{source_key}\t{target_key}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unpack_module_args(module_args: ModuleArgs) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
||||||
|
"""Unpack the positional and keyword arguments passed to a module."""
|
||||||
|
match module_args:
|
||||||
|
case tuple(positional_args):
|
||||||
|
keyword_args: dict[str, Any] = {}
|
||||||
|
case {"positional": positional_args, "keyword": keyword_args}:
|
||||||
|
pass
|
||||||
|
case _:
|
||||||
|
positional_args = ()
|
||||||
|
keyword_args = dict(**module_args)
|
||||||
|
|
||||||
|
return positional_args, keyword_args
|
||||||
|
|
||||||
|
def _is_torch_basic_layer(self, module_type: type[nn.Module]) -> bool:
|
||||||
|
"""Check if a module type is a subclass of a torch basic layer."""
|
||||||
|
return any(issubclass(module_type, torch_basic_layer) for torch_basic_layer in TORCH_BASIC_LAYERS)
|
||||||
|
|
||||||
|
def _infer_basic_layer_type(self, module: nn.Module) -> type[nn.Module] | None:
|
||||||
|
"""Infer the type of a basic layer."""
|
||||||
|
layer_types = (
|
||||||
|
set(self.custom_layer_mapping.keys()) | set(self.custom_layer_mapping.values()) | set(TORCH_BASIC_LAYERS)
|
||||||
|
)
|
||||||
|
for layer_type in layer_types:
|
||||||
|
if isinstance(module, layer_type):
|
||||||
|
return layer_type
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_module_signature(self, module: nn.Module) -> ModelTypeShape:
|
||||||
|
"""Get the signature of a module."""
|
||||||
|
layer_type = self._infer_basic_layer_type(module=module)
|
||||||
|
assert layer_type is not None, f"Module {module} is not a basic layer"
|
||||||
|
param_shapes = [p.shape for p in module.parameters()]
|
||||||
|
return (layer_type, tuple(param_shapes))
|
||||||
|
|
||||||
|
def _count_basic_layers(self, module: nn.Module) -> dict[type[nn.Module], int]:
|
||||||
|
"""Count the number of basic layers in a module."""
|
||||||
|
count: DefaultDict[type[nn.Module], int] = defaultdict(int)
|
||||||
|
for submodule in module.modules():
|
||||||
|
layer_type = self._infer_basic_layer_type(module=submodule)
|
||||||
|
if layer_type is not None:
|
||||||
|
count[layer_type] += 1
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
def _verify_basic_layers_count(self) -> bool:
|
||||||
|
"""Verify that the source and target models have the same number of basic layers."""
|
||||||
|
source_layers = self._count_basic_layers(module=self.source_model)
|
||||||
|
target_layers = self._count_basic_layers(module=self.target_model)
|
||||||
|
|
||||||
|
reverse_mapping = {v: k for k, v in self.custom_layer_mapping.items()}
|
||||||
|
|
||||||
|
diff: dict[type[nn.Module], tuple[int, int]] = {}
|
||||||
|
for layer_type, source_count in source_layers.items():
|
||||||
|
target_type = self.custom_layer_mapping.get(layer_type, layer_type)
|
||||||
|
target_count = target_layers.get(target_type, 0)
|
||||||
|
|
||||||
|
if source_count != target_count:
|
||||||
|
diff[layer_type] = (source_count, target_count)
|
||||||
|
|
||||||
|
for layer_type, target_count in target_layers.items():
|
||||||
|
source_type = reverse_mapping.get(layer_type, layer_type)
|
||||||
|
source_count = source_layers.get(source_type, 0)
|
||||||
|
|
||||||
|
if source_count != target_count:
|
||||||
|
diff[layer_type] = (source_count, target_count)
|
||||||
|
|
||||||
|
if diff:
|
||||||
|
message = "Models do not have the same number of basic layers:\n"
|
||||||
|
for layer_type, counts in diff.items():
|
||||||
|
message += f" {layer_type}: Source {counts[0]} - Target {counts[1]}\n"
|
||||||
|
self._log(message=message.strip())
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _is_weighted_leaf_module(self, module: nn.Module) -> bool:
|
||||||
|
"""Check if a module is a leaf module with weights."""
|
||||||
|
return next(module.parameters(), None) is not None and next(module.children(), None) is None
|
||||||
|
|
||||||
|
def _check_for_missing_basic_layers(self, module: nn.Module) -> list[type[nn.Module]]:
|
||||||
|
"""Check if a module has weighted leaf modules that are not basic layers."""
|
||||||
|
return [
|
||||||
|
type(submodule)
|
||||||
|
for submodule in module.modules()
|
||||||
|
if self._is_weighted_leaf_module(module=submodule) and not self._infer_basic_layer_type(module=submodule)
|
||||||
|
]
|
||||||
|
|
||||||
|
def _verify_missing_basic_layers(self) -> bool:
|
||||||
|
"""Verify that the source and target models do not have missing basic layers."""
|
||||||
|
missing_source_layers = self._check_for_missing_basic_layers(module=self.source_model)
|
||||||
|
missing_target_layers = self._check_for_missing_basic_layers(module=self.target_model)
|
||||||
|
|
||||||
|
if missing_source_layers or missing_target_layers:
|
||||||
|
self._log(
|
||||||
|
message=(
|
||||||
|
"Models might have missing basic layers. If you want to skip this check, set"
|
||||||
|
f" `skip_init_check` to `True`: {missing_source_layers}, {missing_target_layers}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def _trace_module_execution_order(
|
||||||
|
self,
|
||||||
|
module: nn.Module,
|
||||||
|
args: ModuleArgs,
|
||||||
|
keys_to_skip: list[str],
|
||||||
|
) -> dict[ModelTypeShape, list[str]]:
|
||||||
|
"""
|
||||||
|
Execute a forward pass and store the order of execution of specific sub-modules.
|
||||||
|
|
||||||
|
- `module`: The module to trace.
|
||||||
|
- `args`: The arguments to pass to the module it can be either a tuple of positional arguments,
|
||||||
|
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||||
|
- `keys_to_skip`: A list of keys to skip when tracing the module.
|
||||||
|
|
||||||
|
### Returns:
|
||||||
|
- A dictionary mapping the signature of each sub-module to a list of keys in the module's `named_modules`
|
||||||
|
"""
|
||||||
|
submodule_to_key: dict[nn.Module, str] = {}
|
||||||
|
execution_order: defaultdict[ModelTypeShape, list[str]] = defaultdict(list)
|
||||||
|
|
||||||
|
def collect_execution_order_hook(layer: nn.Module, *_: Any) -> None:
|
||||||
|
layer_signature = self.get_module_signature(module=layer)
|
||||||
|
execution_order[layer_signature].append(submodule_to_key[layer])
|
||||||
|
|
||||||
|
hooks: list[RemovableHandle] = []
|
||||||
|
named_modules: list[tuple[str, nn.Module]] = module.named_modules() # type: ignore
|
||||||
|
for name, submodule in named_modules:
|
||||||
|
if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
|
||||||
|
submodule_to_key[submodule] = name # type: ignore
|
||||||
|
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
|
||||||
|
hooks.append(hook)
|
||||||
|
|
||||||
|
positional_args, keyword_args = self._unpack_module_args(module_args=args)
|
||||||
|
module(*positional_args, **keyword_args)
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
|
return dict(execution_order)
|
||||||
|
|
||||||
|
def _assert_shapes_aligned(
|
||||||
|
self, source_order: dict[ModelTypeShape, list[str]], target_order: dict[ModelTypeShape, list[str]]
|
||||||
|
) -> bool:
|
||||||
|
"""Assert that the shapes of the sub-modules in `source_order` and `target_order` are aligned."""
|
||||||
|
model_type_shapes = set(source_order.keys()) | set(target_order.keys())
|
||||||
|
|
||||||
|
default_type_shapes = [
|
||||||
|
type_shape for type_shape in model_type_shapes if self._is_torch_basic_layer(module_type=type_shape[0])
|
||||||
|
]
|
||||||
|
|
||||||
|
shape_mismatched = False
|
||||||
|
|
||||||
|
for model_type_shape in default_type_shapes:
|
||||||
|
source_keys = source_order.get(model_type_shape, [])
|
||||||
|
target_keys = target_order.get(model_type_shape, [])
|
||||||
|
|
||||||
|
if len(source_keys) != len(target_keys):
|
||||||
|
shape_mismatched = True
|
||||||
|
self._debug_print_shapes(shape=model_type_shape, source_keys=source_keys, target_keys=target_keys)
|
||||||
|
|
||||||
|
for source_custom_type in self.custom_layer_mapping.keys():
|
||||||
|
# iterate over all type_shapes that have the same type as source_custom_type
|
||||||
|
for source_type_shape in [
|
||||||
|
type_shape for type_shape in model_type_shapes if type_shape[0] == source_custom_type
|
||||||
|
]:
|
||||||
|
source_keys = source_order.get(source_type_shape, [])
|
||||||
|
target_custom_type = self.custom_layer_mapping[source_custom_type]
|
||||||
|
target_type_shape = (target_custom_type, source_type_shape[1])
|
||||||
|
target_keys = target_order.get(target_type_shape, [])
|
||||||
|
|
||||||
|
if len(source_keys) != len(target_keys):
|
||||||
|
shape_mismatched = True
|
||||||
|
self._debug_print_shapes(shape=source_type_shape, source_keys=source_keys, target_keys=target_keys)
|
||||||
|
|
||||||
|
return not shape_mismatched
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_state_dict(
|
||||||
|
source_state_dict: dict[str, Tensor], target_state_dict: dict[str, Tensor], state_dict_mapping: dict[str, str]
|
||||||
|
) -> dict[str, Tensor]:
|
||||||
|
"""Convert the source model's state_dict to match the target model's state_dict."""
|
||||||
|
converted_state_dict: dict[str, Tensor] = {}
|
||||||
|
for target_key in target_state_dict:
|
||||||
|
target_prefix, suffix = target_key.rsplit(sep=".", maxsplit=1)
|
||||||
|
source_prefix = state_dict_mapping[target_prefix]
|
||||||
|
source_key = ".".join([source_prefix, suffix])
|
||||||
|
converted_state_dict[target_key] = source_state_dict[source_key]
|
||||||
|
|
||||||
|
return converted_state_dict
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def _collect_layers_outputs(
|
||||||
|
self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
|
||||||
|
) -> list[tuple[str, Tensor]]:
|
||||||
|
"""
|
||||||
|
Execute a forward pass and store the output of specific sub-modules.
|
||||||
|
|
||||||
|
- `module`: The module to trace.
|
||||||
|
- `args`: The arguments to pass to the module it can be either a tuple of positional arguments,
|
||||||
|
a dictionary of keyword arguments, or a dictionary with `positional` and `keyword` keys.
|
||||||
|
- `keys_to_skip`: A list of keys to skip when tracing the module.
|
||||||
|
|
||||||
|
### Returns:
|
||||||
|
- A list of tuples containing the key of each sub-module and its output.
|
||||||
|
|
||||||
|
### Note:
|
||||||
|
- The output of each sub-module is cloned to avoid memory leaks.
|
||||||
|
"""
|
||||||
|
submodule_to_key: dict[nn.Module, str] = {}
|
||||||
|
execution_order: list[tuple[str, Tensor]] = []
|
||||||
|
|
||||||
|
def collect_execution_order_hook(layer: nn.Module, _: Any, output: Tensor) -> None:
|
||||||
|
execution_order.append((submodule_to_key[layer], output.clone()))
|
||||||
|
|
||||||
|
hooks: list[RemovableHandle] = []
|
||||||
|
named_modules: list[tuple[str, nn.Module]] = module.named_modules() # type: ignore
|
||||||
|
for name, submodule in named_modules:
|
||||||
|
if (self._infer_basic_layer_type(module=submodule) is not None) and name not in keys_to_skip:
|
||||||
|
submodule_to_key[submodule] = name # type: ignore
|
||||||
|
hook = submodule.register_forward_hook(hook=collect_execution_order_hook)
|
||||||
|
hooks.append(hook)
|
||||||
|
|
||||||
|
positional_args, keyword_args = self._unpack_module_args(module_args=args)
|
||||||
|
module(*positional_args, **keyword_args)
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
|
return execution_order
|
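For orientation, a minimal sketch of how this converter class is meant to be driven end to end. The two toy models, the dummy input, and the output path are hypothetical illustration values; the constructor keywords and module path are assumed from the attributes used above and the vendoring layout, while `run(source_args=...)` and `save_to_safetensors(...)` appear verbatim in the code.

# Hypothetical usage sketch of the vendored converter; the two toy models,
# the dummy input, and the output path are illustration values only.
import torch
from torch import nn

from imaginairy.vendored.refiners.fluxion.model_converter import ModelConverter

source_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
target_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

# constructor keywords assumed from the attributes the class uses above
converter = ModelConverter(source_model=source_model, target_model=target_model)

dummy_input = torch.randn(1, 8)
if converter.run(source_args=(dummy_input,)):
    # stage 3 reached: source and target outputs agree within the threshold
    converter.save_to_safetensors("converted.safetensors")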
206 imaginairy/vendored/refiners/fluxion/utils.py Normal file
@@ -0,0 +1,206 @@
from pathlib import Path
from typing import Any, Iterable, Literal, TypeVar

import torch
from jaxtyping import Float
from numpy import array, float32
from PIL import Image
from safetensors import safe_open as _safe_open  # type: ignore
from safetensors.torch import save_file as _save_file  # type: ignore
from torch import (
    Tensor,
    device as Device,
    dtype as DType,
    manual_seed as _manual_seed,  # type: ignore
    no_grad as _no_grad,  # type: ignore
    norm as _norm,  # type: ignore
)
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad  # type: ignore

T = TypeVar("T")
E = TypeVar("E")


def norm(x: Tensor) -> Tensor:
    return _norm(x)  # type: ignore


def manual_seed(seed: int) -> None:
    _manual_seed(seed)


class no_grad(_no_grad):
    def __new__(cls, orig_func: Any | None = None) -> "no_grad":  # type: ignore
        return object.__new__(cls)


def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant") -> Tensor:
    return _pad(input=x, pad=pad, value=value, mode=mode)  # type: ignore


def interpolate(x: Tensor, factor: float | torch.Size, mode: str = "nearest") -> Tensor:
    return (
        _interpolate(x, scale_factor=factor, mode=mode)
        if isinstance(factor, float | int)
        else _interpolate(x, size=factor, mode=mode)
    )  # type: ignore


# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
def normalize(
    tensor: Float[Tensor, "*batch channels height width"], mean: list[float], std: list[float]
) -> Float[Tensor, "*batch channels height width"]:
    assert tensor.is_floating_point()
    assert tensor.ndim >= 3

    dtype = tensor.dtype
    pixel_mean = torch.tensor(mean, dtype=dtype, device=tensor.device).view(-1, 1, 1)
    pixel_std = torch.tensor(std, dtype=dtype, device=tensor.device).view(-1, 1, 1)
    if (pixel_std == 0).any():
        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")

    return (tensor - pixel_mean) / pixel_std


# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
def gaussian_blur(
    tensor: Float[Tensor, "*batch channels height width"],
    kernel_size: int | tuple[int, int],
    sigma: float | tuple[float, float] | None = None,
) -> Float[Tensor, "*batch channels height width"]:
    assert torch.is_floating_point(tensor)

    def get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Float[Tensor, "kernel_size"]:
        ksize_half = (kernel_size - 1) * 0.5
        x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
        pdf = torch.exp(-0.5 * (x / sigma).pow(2))
        kernel1d = pdf / pdf.sum()
        return kernel1d

    def get_gaussian_kernel2d(
        kernel_size_x: int, kernel_size_y: int, sigma_x: float, sigma_y: float, dtype: DType, device: Device
    ) -> Float[Tensor, "kernel_size_y kernel_size_x"]:
        kernel1d_x = get_gaussian_kernel1d(kernel_size_x, sigma_x).to(device, dtype=dtype)
        kernel1d_y = get_gaussian_kernel1d(kernel_size_y, sigma_y).to(device, dtype=dtype)
        kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
        return kernel2d

    def default_sigma(kernel_size: int) -> float:
        return kernel_size * 0.15 + 0.35

    if isinstance(kernel_size, int):
        kx, ky = kernel_size, kernel_size
    else:
        kx, ky = kernel_size

    if sigma is None:
        sx, sy = default_sigma(kx), default_sigma(ky)
    elif isinstance(sigma, float):
        sx, sy = sigma, sigma
    else:
        assert isinstance(sigma, tuple)
        sx, sy = sigma

    channels = tensor.shape[-3]
    kernel = get_gaussian_kernel2d(kx, ky, sx, sy, dtype=tensor.dtype, device=tensor.device)
    kernel = kernel.expand(channels, 1, kernel.shape[0], kernel.shape[1])

    # pad = (left, right, top, bottom)
    tensor = pad(tensor, pad=(kx // 2, kx // 2, ky // 2, ky // 2), mode="reflect")
    tensor = conv2d(tensor, weight=kernel, groups=channels)

    return tensor


def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtype: DType | None = None) -> Tensor:
    """
    Convert a PIL Image to a Tensor.

    If the image is in mode `RGB` the tensor will have shape `[3, H, W]`, otherwise
    `[1, H, W]` for mode `L` (grayscale) or `[4, H, W]` for mode `RGBA`.

    Values are clamped to the range `[0, 1]`.
    """
    image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)

    match image.mode:
        case "L":
            image_tensor = image_tensor.unsqueeze(0)
        case "RGBA" | "RGB":
            image_tensor = image_tensor.permute(2, 0, 1)
        case _:
            raise ValueError(f"Unsupported image mode: {image.mode}")

    return image_tensor.unsqueeze(0)


def tensor_to_image(tensor: Tensor) -> Image.Image:
    """
    Convert a Tensor to a PIL Image.

    The tensor must have shape `[1, channels, height, width]` where the number of
    channels is either 1 (grayscale) or 3 (RGB) or 4 (RGBA).

    Expected values are in the range `[0, 1]` and are clamped to this range.
    """
    assert tensor.ndim == 4 and tensor.shape[0] == 1, f"Unsupported tensor shape: {tensor.shape}"
    num_channels = tensor.shape[1]
    tensor = tensor.clamp(0, 1).squeeze(0)

    match num_channels:
        case 1:
            tensor = tensor.squeeze(0)
        case 3 | 4:
            tensor = tensor.permute(1, 2, 0)
        case _:
            raise ValueError(f"Unsupported number of channels: {num_channels}")

    return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8"))  # type: ignore[reportUnknownType]


def safe_open(
    path: Path | str,
    framework: Literal["pytorch", "tensorflow", "flax", "numpy"],
    device: Device | str = "cpu",
) -> dict[str, Tensor]:
    framework_mapping = {
        "pytorch": "pt",
        "tensorflow": "tf",
        "flax": "flax",
        "numpy": "numpy",
    }
    return _safe_open(str(path), framework=framework_mapping[framework], device=str(device))  # type: ignore


def load_from_safetensors(path: Path | str, device: Device | str = "cpu") -> dict[str, Tensor]:
    with safe_open(path=path, framework="pytorch", device=device) as tensors:  # type: ignore
        return {key: tensors.get_tensor(key) for key in tensors.keys()}  # type: ignore


def load_metadata_from_safetensors(path: Path | str) -> dict[str, str] | None:
    with safe_open(path=path, framework="pytorch") as tensors:  # type: ignore
        return tensors.metadata()  # type: ignore


def save_to_safetensors(path: Path | str, tensors: dict[str, Tensor], metadata: dict[str, str] | None = None) -> None:
    _save_file(tensors, path, metadata)  # type: ignore


def summarize_tensor(tensor: torch.Tensor, /) -> str:
    return (
        "Tensor("
        + ", ".join(
            [
                f"shape=({', '.join(map(str, tensor.shape))})",
                f"dtype={str(object=tensor.dtype).removeprefix('torch.')}",
                f"device={tensor.device}",
                f"min={tensor.min():.2f}",  # type: ignore
                f"max={tensor.max():.2f}",  # type: ignore
                f"mean={tensor.mean():.2f}",
                f"std={tensor.std():.2f}",
                f"norm={norm(x=tensor):.2f}",
                f"grad={tensor.requires_grad}",
            ]
        )
        + ")"
    )
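A quick sketch of how the image helpers above chain together, using only the signatures defined in this file; the file paths are placeholders.

# Hypothetical round-trip through the vendored fluxion helpers; "input.png"
# and "blurred.png" are placeholder paths.
from PIL import Image

from imaginairy.vendored.refiners.fluxion.utils import (
    gaussian_blur,
    image_to_tensor,
    tensor_to_image,
)

image = Image.open("input.png").convert("RGB")
tensor = image_to_tensor(image)  # shape [1, 3, H, W], values in [0, 1]
blurred = gaussian_blur(tensor, kernel_size=9)  # sigma defaults to k * 0.15 + 0.35
tensor_to_image(blurred).save("blurred.png")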
Binary file not shown.
48 imaginairy/vendored/refiners/foundationals/clip/common.py Normal file
@@ -0,0 +1,48 @@
from torch import Tensor, arange, device as Device, dtype as DType

import imaginairy.vendored.refiners.fluxion.layers as fl


class PositionalEncoder(fl.Chain):
    def __init__(
        self,
        max_sequence_length: int,
        embedding_dim: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.max_sequence_length = max_sequence_length
        self.embedding_dim = embedding_dim
        super().__init__(
            fl.Lambda(func=self.get_position_ids),
            fl.Embedding(
                num_embeddings=max_sequence_length,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            ),
        )

    @property
    def position_ids(self) -> Tensor:
        return arange(end=self.max_sequence_length, device=self.device).reshape(1, -1)

    def get_position_ids(self, x: Tensor) -> Tensor:
        return self.position_ids[:, : x.shape[1]]


class FeedForward(fl.Chain):
    def __init__(
        self,
        embedding_dim: int,
        feedforward_dim: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.feedforward_dim = feedforward_dim
        super().__init__(
            fl.Linear(in_features=embedding_dim, out_features=feedforward_dim, device=device, dtype=dtype),
            fl.GeLU(),
            fl.Linear(in_features=feedforward_dim, out_features=embedding_dim, device=device, dtype=dtype),
        )
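A minimal sketch exercising these two building blocks on a dummy sequence, assuming the usual sequential call semantics of `fl.Chain`; the batch and CLIP-L-style sizes are illustration values.

# Sketch only: sizes are arbitrary illustration values.
import torch

from imaginairy.vendored.refiners.foundationals.clip.common import FeedForward, PositionalEncoder

x = torch.randn(2, 77, 768)  # [batch, sequence, embedding_dim]
ff = FeedForward(embedding_dim=768, feedforward_dim=3072)
assert ff(x).shape == (2, 77, 768)  # projected up to 3072 and back down

pe = PositionalEncoder(max_sequence_length=77, embedding_dim=768)
pos = pe(x)  # one learned embedding per token position, truncated to x.shape[1]
assert pos.shape == (1, 77, 768)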
148 imaginairy/vendored/refiners/foundationals/clip/concepts.py Normal file
@@ -0,0 +1,148 @@
import re
from typing import cast

import torch.nn.functional as F
from torch import Tensor, cat, zeros
from torch.nn import Parameter

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
from imaginairy.vendored.refiners.foundationals.clip.tokenizer import CLIPTokenizer


class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
    old_weight: Parameter
    new_weight: Parameter

    def __init__(
        self,
        target: TokenEncoder,
    ) -> None:
        with self.setup_adapter(target):
            super().__init__(fl.Lambda(func=self.lookup))
        self.old_weight = cast(Parameter, target.weight)
        self.new_weight = Parameter(
            zeros([0, target.embedding_dim], device=target.device, dtype=target.dtype)
        )  # requires_grad=True by default

    # Use F.embedding instead of nn.Embedding to make sure that gradients can only be computed for the new embeddings
    def lookup(self, x: Tensor) -> Tensor:
        # Concatenate old and new weights for dynamic embedding updates during training
        return F.embedding(x, cat([self.old_weight, self.new_weight]))

    def add_embedding(self, embedding: Tensor) -> None:
        assert embedding.shape == (self.old_weight.shape[1],)
        self.new_weight = Parameter(
            cat([self.new_weight, embedding.unsqueeze(0).to(self.new_weight.device, self.new_weight.dtype)])
        )

    @property
    def num_embeddings(self) -> int:
        return self.old_weight.shape[0] + self.new_weight.shape[0]


class TokenExtender(fl.Chain, Adapter[CLIPTokenizer]):
    def __init__(self, target: CLIPTokenizer) -> None:
        with self.setup_adapter(target):
            super().__init__(
                CLIPTokenizer(
                    vocabulary_path=target.vocabulary_path,
                    sequence_length=target.sequence_length,
                    start_of_text_token_id=target.start_of_text_token_id,
                    end_of_text_token_id=target.end_of_text_token_id,
                    pad_token_id=target.pad_token_id,
                )
            )

    def add_token(self, token: str, token_id: int) -> None:
        token = token.lower()
        tokenizer = self.ensure_find(CLIPTokenizer)
        assert token_id not in tokenizer.token_to_id_mapping.values()
        tokenizer.token_to_id_mapping[token] = token_id
        current_pattern = tokenizer.token_pattern.pattern
        new_pattern = re.escape(token) + "|" + current_pattern
        tokenizer.token_pattern = re.compile(new_pattern, re.IGNORECASE)
        # Define the keyword as its own smallest subtoken
        tokenizer.byte_pair_encoding_cache[token] = token


class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
    """
    Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.

    Example:
        import torch
        from imaginairy.vendored.refiners.foundationals.clip.concepts import ConceptExtender
        from imaginairy.vendored.refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
        from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors

        encoder = CLIPTextEncoderL(device="cuda")
        tensors = load_from_safetensors("CLIPTextEncoderL.safetensors")
        encoder.load_state_dict(tensors)

        cat_embedding = torch.load("cat_embedding.bin")["<this-cat>"]
        dog_embedding = torch.load("dog_embedding.bin")["<that-dog>"]

        extender = ConceptExtender(encoder)
        extender.add_concept(token="<this-cat>", embedding=cat_embedding)
        extender.inject()
        # New concepts can be added at any time
        extender.add_concept(token="<that-dog>", embedding=dog_embedding)

        # Now the encoder can be used with the new concepts
    """

    def __init__(self, target: CLIPTextEncoder) -> None:
        with self.setup_adapter(target):
            super().__init__(target)

        try:
            token_encoder, token_encoder_parent = next(target.walk(TokenEncoder))
            self._token_encoder_parent = [token_encoder_parent]

        except StopIteration:
            raise RuntimeError("TokenEncoder not found.")

        try:
            clip_tokenizer, clip_tokenizer_parent = next(target.walk(CLIPTokenizer))
            self._clip_tokenizer_parent = [clip_tokenizer_parent]
        except StopIteration:
            raise RuntimeError("Tokenizer not found.")

        self._embedding_extender = [EmbeddingExtender(token_encoder)]
        self._token_extender = [TokenExtender(clip_tokenizer)]

    @property
    def embedding_extender(self) -> EmbeddingExtender:
        assert len(self._embedding_extender) == 1, "EmbeddingExtender not found."
        return self._embedding_extender[0]

    @property
    def token_extender(self) -> TokenExtender:
        assert len(self._token_extender) == 1, "TokenExtender not found."
        return self._token_extender[0]

    @property
    def token_encoder_parent(self) -> fl.Chain:
        assert len(self._token_encoder_parent) == 1, "TokenEncoder parent not found."
        return self._token_encoder_parent[0]

    @property
    def clip_tokenizer_parent(self) -> fl.Chain:
        assert len(self._clip_tokenizer_parent) == 1, "Tokenizer parent not found."
        return self._clip_tokenizer_parent[0]

    def add_concept(self, token: str, embedding: Tensor) -> None:
        self.embedding_extender.add_embedding(embedding)
        self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)

    def inject(self: "ConceptExtender", parent: fl.Chain | None = None) -> "ConceptExtender":
        self.embedding_extender.inject(self.token_encoder_parent)
        self.token_extender.inject(self.clip_tokenizer_parent)
        return super().inject(parent)

    def eject(self) -> None:
        self.embedding_extender.eject()
        self.token_extender.eject()
        super().eject()
179 imaginairy/vendored/refiners/foundationals/clip/image_encoder.py Normal file
@@ -0,0 +1,179 @@
from typing import Callable

from torch import Tensor, device as Device, dtype as DType

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.foundationals.clip.common import FeedForward, PositionalEncoder


class ClassToken(fl.Chain):
    def __init__(self, embedding_dim: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.embedding_dim = embedding_dim
        super().__init__(fl.Parameter(1, embedding_dim, device=device, dtype=dtype))


class PatchEncoder(fl.Chain):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int = 16,
        use_bias: bool = True,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size
        self.use_bias = use_bias
        super().__init__(
            fl.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=(self.patch_size, self.patch_size),
                stride=(self.patch_size, self.patch_size),
                use_bias=self.use_bias,
                device=device,
                dtype=dtype,
            ),
            fl.Permute(0, 2, 3, 1),
        )


class TransformerLayer(fl.Chain):
    def __init__(
        self,
        embedding_dim: int = 768,
        feedforward_dim: int = 3072,
        num_attention_heads: int = 12,
        layer_norm_eps: float = 1e-5,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.feedforward_dim = feedforward_dim
        self.num_attention_heads = num_attention_heads
        self.layer_norm_eps = layer_norm_eps
        super().__init__(
            fl.Residual(
                fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
                fl.SelfAttention(
                    embedding_dim=embedding_dim, num_heads=num_attention_heads, device=device, dtype=dtype
                ),
            ),
            fl.Residual(
                fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
                FeedForward(embedding_dim=embedding_dim, feedforward_dim=feedforward_dim, device=device, dtype=dtype),
            ),
        )


class ViTEmbeddings(fl.Chain):
    def __init__(
        self,
        image_size: int = 224,
        embedding_dim: int = 768,
        patch_size: int = 32,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.image_size = image_size
        self.embedding_dim = embedding_dim
        self.patch_size = patch_size
        super().__init__(
            fl.Concatenate(
                ClassToken(embedding_dim, device=device, dtype=dtype),
                fl.Chain(
                    PatchEncoder(
                        in_channels=3,
                        out_channels=embedding_dim,
                        patch_size=patch_size,
                        use_bias=False,
                        device=device,
                        dtype=dtype,
                    ),
                    fl.Reshape((image_size // patch_size) ** 2, embedding_dim),
                ),
                dim=1,
            ),
            fl.Residual(
                PositionalEncoder(
                    max_sequence_length=(image_size // patch_size) ** 2 + 1,
                    embedding_dim=embedding_dim,
                    device=device,
                    dtype=dtype,
                ),
            ),
        )


class CLIPImageEncoder(fl.Chain):
    def __init__(
        self,
        image_size: int = 224,
        embedding_dim: int = 768,
        output_dim: int = 512,
        patch_size: int = 32,
        num_layers: int = 12,
        num_attention_heads: int = 12,
        feedforward_dim: int = 3072,
        layer_norm_eps: float = 1e-5,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.image_size = image_size
        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.patch_size = patch_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.feedforward_dim = feedforward_dim
        cls_token_pooling: Callable[[Tensor], Tensor] = lambda x: x[:, 0, :]
        super().__init__(
            ViTEmbeddings(
                image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype
            ),
            fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
            fl.Chain(
                TransformerLayer(
                    embedding_dim=embedding_dim,
                    feedforward_dim=feedforward_dim,
                    num_attention_heads=num_attention_heads,
                    layer_norm_eps=layer_norm_eps,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(num_layers)
            ),
            fl.Lambda(func=cls_token_pooling),
            fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
            fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype),
        )


class CLIPImageEncoderH(CLIPImageEncoder):
    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__(
            embedding_dim=1280,
            output_dim=1024,
            patch_size=14,
            num_layers=32,
            num_attention_heads=16,
            feedforward_dim=5120,
            device=device,
            dtype=dtype,
        )


class CLIPImageEncoderG(CLIPImageEncoder):
    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__(
            embedding_dim=1664,
            output_dim=1280,
            patch_size=14,
            num_layers=48,
            num_attention_heads=16,
            feedforward_dim=8192,
            device=device,
            dtype=dtype,
        )
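A minimal sketch of pushing a dummy batch through the H variant defined above; the model is randomly initialized here, so only the output shape is meaningful.

# Sketch only: random weights, shape check only. With patch_size=14 and the
# default image_size=224, the sequence is 16 * 16 patches plus a class token.
import torch

from imaginairy.vendored.refiners.foundationals.clip.image_encoder import CLIPImageEncoderH

encoder = CLIPImageEncoderH()
image = torch.randn(1, 3, 224, 224)  # [batch, channels, image_size, image_size]
embedding = encoder(image)
assert embedding.shape == (1, 1024)  # output_dim of the H variant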
195 imaginairy/vendored/refiners/foundationals/clip/text_encoder.py Normal file
@@ -0,0 +1,195 @@
from torch import device as Device, dtype as DType

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.foundationals.clip.common import FeedForward, PositionalEncoder
from imaginairy.vendored.refiners.foundationals.clip.tokenizer import CLIPTokenizer


class TokenEncoder(fl.Embedding):
    def __init__(
        self,
        vocabulary_size: int,
        embedding_dim: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.vocabulary_size = vocabulary_size
        self.embedding_dim = embedding_dim
        super().__init__(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
        )


class TransformerLayer(fl.Chain):
    def __init__(
        self,
        embedding_dim: int,
        feedforward_dim: int,
        num_attention_heads: int = 1,
        layer_norm_eps: float = 1e-5,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.num_attention_heads = num_attention_heads
        self.feedforward_dim = feedforward_dim
        self.layer_norm_eps = layer_norm_eps
        super().__init__(
            fl.Residual(
                fl.LayerNorm(
                    normalized_shape=embedding_dim,
                    eps=layer_norm_eps,
                    device=device,
                    dtype=dtype,
                ),
                fl.SelfAttention(
                    embedding_dim=embedding_dim,
                    num_heads=num_attention_heads,
                    is_causal=True,
                    device=device,
                    dtype=dtype,
                ),
            ),
            fl.Residual(
                fl.LayerNorm(
                    normalized_shape=embedding_dim,
                    eps=layer_norm_eps,
                    device=device,
                    dtype=dtype,
                ),
                FeedForward(
                    embedding_dim=embedding_dim,
                    feedforward_dim=feedforward_dim,
                    device=device,
                    dtype=dtype,
                ),
            ),
        )


class CLIPTextEncoder(fl.Chain):
    def __init__(
        self,
        embedding_dim: int = 768,
        max_sequence_length: int = 77,
        vocabulary_size: int = 49408,
        num_layers: int = 12,
        num_attention_heads: int = 12,
        feedforward_dim: int = 3072,
        layer_norm_eps: float = 1e-5,
        use_quick_gelu: bool = False,
        tokenizer: CLIPTokenizer | None = None,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.max_sequence_length = max_sequence_length
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.feedforward_dim = feedforward_dim
        self.layer_norm_eps = layer_norm_eps
        self.use_quick_gelu = use_quick_gelu
        super().__init__(
            tokenizer or CLIPTokenizer(sequence_length=max_sequence_length),
            fl.Converter(set_dtype=False),
            fl.Sum(
                TokenEncoder(
                    vocabulary_size=vocabulary_size,
                    embedding_dim=embedding_dim,
                    device=device,
                    dtype=dtype,
                ),
                PositionalEncoder(
                    max_sequence_length=max_sequence_length,
                    embedding_dim=embedding_dim,
                    device=device,
                    dtype=dtype,
                ),
            ),
            *(
                TransformerLayer(
                    embedding_dim=embedding_dim,
                    num_attention_heads=num_attention_heads,
                    feedforward_dim=feedforward_dim,
                    layer_norm_eps=layer_norm_eps,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(num_layers)
            ),
            fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
        )
        if use_quick_gelu:
            for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
                parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())


class CLIPTextEncoderL(CLIPTextEncoder):
    """
    CLIPTextEncoderL is the CLIP text encoder with the following parameters:
    embedding_dim=768
    num_layers=12
    num_attention_heads=12
    feedforward_dim=3072
    use_quick_gelu=True

    We replace the GeLU activation function with an approximate GeLU to comply with the original CLIP implementation
    of OpenAI (https://github.com/openai/CLIP/blob/main/clip/model.py#L166)
    """

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__(
            embedding_dim=768,
            num_layers=12,
            num_attention_heads=12,
            feedforward_dim=3072,
            use_quick_gelu=True,
            device=device,
            dtype=dtype,
        )


class CLIPTextEncoderH(CLIPTextEncoder):
    """
    CLIPTextEncoderH is the CLIP text encoder with the following parameters:
    embedding_dim=1024
    num_layers=23
    num_attention_heads=16
    feedforward_dim=4096
    """

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__(
            embedding_dim=1024,
            num_layers=23,
            num_attention_heads=16,
            feedforward_dim=4096,
            device=device,
            dtype=dtype,
        )


class CLIPTextEncoderG(CLIPTextEncoder):
    """
    CLIPTextEncoderG is the CLIP text encoder with the following parameters:
    embedding_dim=1280
    num_layers=32
    num_attention_heads=20
    feedforward_dim=5120
    """

    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        tokenizer = CLIPTokenizer(pad_token_id=0)
        super().__init__(
            embedding_dim=1280,
            num_layers=32,
            num_attention_heads=20,
            feedforward_dim=5120,
            tokenizer=tokenizer,
            device=device,
            dtype=dtype,
        )
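A minimal sketch of the text encoder end to end. Since the chain begins with the tokenizer, it maps a raw prompt string straight to per-token embeddings; the encoder here is randomly initialized, so only the shape is meaningful.

# Sketch only: no pretrained weights loaded.
import torch

from imaginairy.vendored.refiners.foundationals.clip.text_encoder import CLIPTextEncoderL

encoder = CLIPTextEncoderL()
with torch.no_grad():
    embeddings = encoder("a photo of a cat")
assert embeddings.shape == (1, 77, 768)  # [batch, max_sequence_length, embedding_dim]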
121 imaginairy/vendored/refiners/foundationals/clip/tokenizer.py Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
import gzip
|
||||||
|
import re
|
||||||
|
from functools import lru_cache
|
||||||
|
from itertools import islice
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from torch import Tensor, tensor
|
||||||
|
|
||||||
|
import imaginairy.vendored.refiners.fluxion.layers as fl
|
||||||
|
from imaginairy.vendored.refiners.fluxion import pad
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTokenizer(fl.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocabulary_path: str | Path = Path(__file__).resolve().parent / "bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
sequence_length: int = 77,
|
||||||
|
start_of_text_token_id: int = 49406,
|
||||||
|
end_of_text_token_id: int = 49407,
|
||||||
|
pad_token_id: int = 49407,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.vocabulary_path = vocabulary_path
|
||||||
|
self.sequence_length = sequence_length
|
||||||
|
self.byte_to_unicode_mapping = self.get_bytes_to_unicode_mapping()
|
||||||
|
self.byte_decoder = {v: k for k, v in self.byte_to_unicode_mapping.items()}
|
||||||
|
merge_tuples = [
|
||||||
|
tuple(merge.split())
|
||||||
|
for merge in gzip.open(filename=vocabulary_path)
|
||||||
|
.read()
|
||||||
|
.decode(encoding="utf-8")
|
||||||
|
.split(sep="\n")[1 : 49152 - 256 - 2 + 1]
|
||||||
|
]
|
||||||
|
vocabulary = (
|
||||||
|
list(self.byte_to_unicode_mapping.values())
|
||||||
|
+ [v + "</w>" for v in self.byte_to_unicode_mapping.values()]
|
||||||
|
+ ["".join(merge) for merge in merge_tuples]
|
||||||
|
+ ["", ""]
|
||||||
|
)
|
||||||
|
        self.token_to_id_mapping = {token: i for i, token in enumerate(iterable=vocabulary)}
        self.byte_pair_encoding_ranks = {merge: i for i, merge in enumerate(iterable=merge_tuples)}
        self.byte_pair_encoding_cache = {"": ""}
        # Note: this regular expression does not support Unicode. It was changed so as
        # to get rid of the dependence on the `regex` module. Unicode support could
        # potentially be added back by leveraging the `\w` character class.
        self.token_pattern = re.compile(
            pattern=r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]|[^\s\w]+""",
            flags=re.IGNORECASE,
        )
        self.start_of_text_token_id: int = start_of_text_token_id
        self.end_of_text_token_id: int = end_of_text_token_id
        self.pad_token_id: int = pad_token_id

    def forward(self, text: str) -> Tensor:
        tokens = self.encode(text=text, max_length=self.sequence_length).unsqueeze(dim=0)
        assert (
            tokens.shape[1] <= self.sequence_length
        ), f"Text is too long: tokens.shape[1] > sequence_length: {tokens.shape[1]} > {self.sequence_length}"
        return pad(x=tokens, pad=(0, self.sequence_length - tokens.shape[1]), value=self.pad_token_id)

    @lru_cache()
    def get_bytes_to_unicode_mapping(self) -> dict[int, str]:
        initial_byte_values = (
            list(range(ord("!"), ord("~") + 1))
            + list(range(ord("¡"), ord("¬") + 1))
            + list(range(ord("®"), ord("ÿ") + 1))
        )
        extra_unicode_values = (byte for byte in range(2**8) if byte not in initial_byte_values)
        byte_values = initial_byte_values + list(extra_unicode_values)
        unicode_values = [chr(value) for value in byte_values]
        return dict(zip(byte_values, unicode_values))

    def byte_pair_encoding(self, token: str) -> str:
        if token in self.byte_pair_encoding_cache:
            return self.byte_pair_encoding_cache[token]

        def recursive_bpe(word: tuple[str, ...]) -> tuple[str, ...]:
            if len(word) < 2:
                return word
            pairs = {(i, (word[i], word[i + 1])) for i in range(len(word) - 1)}
            min_pair = min(
                pairs,
                key=lambda pair: self.byte_pair_encoding_ranks.get(pair[1], float("inf")),
            )
            if min_pair[1] not in self.byte_pair_encoding_ranks:
                return word
            new_word: list[str] = []
            i = 0
            while i < len(word):
                if i == min_pair[0]:
                    new_word.append(min_pair[1][0] + min_pair[1][1])
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            return recursive_bpe(tuple(new_word))

        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        result = " ".join(recursive_bpe(word=word))
        self.byte_pair_encoding_cache[token] = result
        return result

    def encode(self, text: str, max_length: int | None = None) -> Tensor:
        text = re.sub(pattern=r"\s+", repl=" ", string=text.lower())
        tokens = re.findall(pattern=self.token_pattern, string=text)
        upper_bound = None
        if max_length:
            assert max_length >= 2
            upper_bound = max_length - 2
        encoded_tokens = islice(
            (
                self.token_to_id_mapping[subtoken]
                for token in tokens
                for subtoken in self.byte_pair_encoding(
                    token="".join(self.byte_to_unicode_mapping[character] for character in token.encode("utf-8"))
                ).split(sep=" ")
            ),
            0,
            upper_bound,
        )
        return tensor(data=[self.start_of_text_token_id, *encoded_tokens, self.end_of_text_token_id])
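For orientation (this note is not part of the vendored file): a minimal sketch of how the tokenizer above is typically driven. It assumes `CLIPTokenizer` can be constructed with its bundled default vocabulary; the prompt string is arbitrary.

from imaginairy.vendored.refiners.foundationals.clip.tokenizer import CLIPTokenizer

tokenizer = CLIPTokenizer()  # assumption: the default constructor loads the bundled BPE vocabulary
token_ids = tokenizer("a photo of a cat")  # forward(): encode, then pad up to sequence_length
# start-of-text id, the BPE token ids, end-of-text id, then pad_token_id up to the full length
assert token_ids.shape == (1, tokenizer.sequence_length)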
29
imaginairy/vendored/refiners/foundationals/dinov2/__init__.py
Normal file
@ -0,0 +1,29 @@
from .dinov2 import (
    DINOv2_base,
    DINOv2_base_reg,
    DINOv2_large,
    DINOv2_large_reg,
    DINOv2_small,
    DINOv2_small_reg,
)
from .vit import (
    ViT,
    ViT_base,
    ViT_large,
    ViT_small,
    ViT_tiny,
)

__all__ = [
    "DINOv2_base",
    "DINOv2_base_reg",
    "DINOv2_large",
    "DINOv2_large_reg",
    "DINOv2_small",
    "DINOv2_small_reg",
    "ViT",
    "ViT_base",
    "ViT_large",
    "ViT_small",
    "ViT_tiny",
]
148
imaginairy/vendored/refiners/foundationals/dinov2/dinov2.py
Normal file
@ -0,0 +1,148 @@
import torch

from imaginairy.vendored.refiners.foundationals.dinov2.vit import ViT

# TODO: add preprocessing logic like
# https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/data/transforms.py#L77


class DINOv2_small(ViT):
    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__(
            embedding_dim=384,
            patch_size=14,
            image_size=518,
            num_layers=12,
            num_heads=6,
            device=device,
            dtype=dtype,
        )


class DINOv2_base(ViT):
    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__(
            embedding_dim=768,
            patch_size=14,
            image_size=518,
            num_layers=12,
            num_heads=12,
            device=device,
            dtype=dtype,
        )


class DINOv2_large(ViT):
    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__(
            embedding_dim=1024,
            patch_size=14,
            image_size=518,
            num_layers=24,
            num_heads=16,
            device=device,
            dtype=dtype,
        )


# TODO: implement SwiGLU layer
# class DINOv2_giant2(ViT):
#     def __init__(
#         self,
#         device: torch.device | str | None = None,
#         dtype: torch.dtype | None = None,
#     ) -> None:
#         super().__init__(
#             embedding_dim=1536,
#             patch_size=14,
#             image_size=518,
#             num_layers=40,
#             num_heads=24,
#             device=device,
#             dtype=dtype,
#         )


class DINOv2_small_reg(ViT):
    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__(
            embedding_dim=384,
            patch_size=14,
            image_size=518,
            num_layers=12,
            num_heads=6,
            num_registers=4,
            device=device,
            dtype=dtype,
        )


class DINOv2_base_reg(ViT):
    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__(
            embedding_dim=768,
            patch_size=14,
            image_size=518,
            num_layers=12,
            num_heads=12,
            num_registers=4,
            device=device,
            dtype=dtype,
        )


class DINOv2_large_reg(ViT):
    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__(
            embedding_dim=1024,
            patch_size=14,
            image_size=518,
            num_layers=24,
            num_heads=16,
            num_registers=4,
            device=device,
            dtype=dtype,
        )


# TODO: implement SwiGLU layer
# class DINOv2_giant2_reg(ViT):
#     def __init__(
#         self,
#         device: torch.device | str | None = None,
#         dtype: torch.dtype | None = None,
#     ) -> None:
#         super().__init__(
#             embedding_dim=1536,
#             patch_size=14,
#             image_size=518,
#             num_layers=40,
#             num_heads=24,
#             num_registers=4,
#             device=device,
#             dtype=dtype,
#         )
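As a quick orientation note (not part of the diff): the variants above differ only in width, depth, and register count. A hedged sketch, under the assumption that the `ViT` forward accepts a batch of 518×518 RGB tensors and that resizing/normalization happens upstream (see the TODO above):

import torch

from imaginairy.vendored.refiners.foundationals.dinov2 import DINOv2_small_reg

model = DINOv2_small_reg()  # 384-dim, 12 layers, 6 heads, 4 register tokens
image = torch.randn(1, 3, 518, 518)  # assumption: already resized and normalized
tokens = model(image)
# 518 // 14 = 37 patches per side: 37**2 patch tokens + 1 class token + 4 registers
assert tokens.shape == (1, 1 + 4 + 37**2, 384)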
373
imaginairy/vendored/refiners/foundationals/dinov2/vit.py
Normal file
@ -0,0 +1,373 @@
import torch
from torch import Tensor

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.layers.activations import Activation


class ClassToken(fl.Chain):
    """Learnable token representing the class of the input."""

    def __init__(
        self,
        embedding_dim: int,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim

        super().__init__(
            fl.Parameter(
                *(1, embedding_dim),
                device=device,
                dtype=dtype,
            ),
        )


class PositionalEncoder(fl.Residual):
    """Encode the position of each patch in the input."""

    def __init__(
        self,
        sequence_length: int,
        embedding_dim: int,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        self.num_patches = sequence_length
        self.embedding_dim = embedding_dim

        super().__init__(
            fl.Parameter(
                *(sequence_length, embedding_dim),
                device=device,
                dtype=dtype,
            ),
        )


class LayerScale(fl.WeightedModule):
    """Scale the input tensor by a learnable parameter."""

    def __init__(
        self,
        embedding_dim: int,
        init_value: float = 1.0,
        dtype: torch.dtype | None = None,
        device: torch.device | str | None = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim

        self.register_parameter(
            name="weight",
            param=torch.nn.Parameter(
                torch.full(
                    size=(embedding_dim,),
                    fill_value=init_value,
                    dtype=dtype,
                    device=device,
                ),
            ),
        )

    def forward(self, x: Tensor) -> Tensor:
        return x * self.weight


class FeedForward(fl.Chain):
    """Apply two linear transformations interleaved by an activation function."""

    def __init__(
        self,
        embedding_dim: int,
        feedforward_dim: int,
        activation: Activation = fl.GeLU,  # type: ignore
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.feedforward_dim = feedforward_dim

        super().__init__(
            fl.Linear(
                in_features=embedding_dim,
                out_features=feedforward_dim,
                device=device,
                dtype=dtype,
            ),
            activation(),
            fl.Linear(
                in_features=feedforward_dim,
                out_features=embedding_dim,
                device=device,
                dtype=dtype,
            ),
        )


class PatchEncoder(fl.Chain):
    """Encode an image into a sequence of patches."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size

        super().__init__(
            fl.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=patch_size,
                stride=patch_size,
                device=device,
                dtype=dtype,
            ),  # (N,3,H,W) -> (N,D,P,P)
            fl.Reshape(out_channels, -1),  # (N,D,P,P) -> (N,D,P²)
            fl.Transpose(1, 2),  # (N,D,P²) -> (N,P²,D)
        )


class TransformerLayer(fl.Chain):
    """Apply a multi-head self-attention mechanism to the input tensor."""

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        norm_eps: float,
        mlp_ratio: int,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.norm_eps = norm_eps
        self.mlp_ratio = mlp_ratio

        super().__init__(
            fl.Residual(
                fl.LayerNorm(
                    normalized_shape=embedding_dim,
                    eps=norm_eps,
                    device=device,
                    dtype=dtype,
                ),
                fl.SelfAttention(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    device=device,
                    dtype=dtype,
                ),
                LayerScale(
                    embedding_dim=embedding_dim,
                    device=device,
                    dtype=dtype,
                ),
            ),
            fl.Residual(
                fl.LayerNorm(
                    normalized_shape=embedding_dim,
                    eps=norm_eps,
                    device=device,
                    dtype=dtype,
                ),
                FeedForward(
                    embedding_dim=embedding_dim,
                    feedforward_dim=embedding_dim * mlp_ratio,
                    device=device,
                    dtype=dtype,
                ),
                LayerScale(
                    embedding_dim=embedding_dim,
                    device=device,
                    dtype=dtype,
                ),
            ),
        )


class Transformer(fl.Chain):
    """Alias for a Chain of TransformerLayer."""


class Registers(fl.Concatenate):
    """Insert register tokens between CLS token and patches."""

    def __init__(
        self,
        num_registers: int,
        embedding_dim: int,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        self.num_registers = num_registers
        self.embedding_dim = embedding_dim

        super().__init__(
            fl.Slicing(dim=1, end=1),
            fl.Parameter(
                *(num_registers, embedding_dim),
                device=device,
                dtype=dtype,
            ),
            fl.Slicing(dim=1, start=1),
            dim=1,
        )


class ViT(fl.Chain):
    """Vision Transformer (ViT).

    see https://arxiv.org/abs/2010.11929v2
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        patch_size: int = 16,
        image_size: int = 224,
        num_layers: int = 12,
        num_heads: int = 12,
        norm_eps: float = 1e-6,
        mlp_ratio: int = 4,
        num_registers: int = 0,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        num_patches = image_size // patch_size
        self.embedding_dim = embedding_dim
        self.patch_size = patch_size
        self.image_size = image_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.norm_eps = norm_eps
        self.mlp_ratio = mlp_ratio
        self.num_registers = num_registers

        super().__init__(
            fl.Concatenate(
                ClassToken(
                    embedding_dim=embedding_dim,
                    device=device,
                    dtype=dtype,
                ),
                PatchEncoder(
                    in_channels=3,
                    out_channels=embedding_dim,
                    patch_size=patch_size,
                    device=device,
                    dtype=dtype,
                ),
                dim=1,
            ),
            # TODO: support https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
            PositionalEncoder(
                sequence_length=num_patches**2 + 1,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            ),
            Transformer(
                TransformerLayer(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    norm_eps=norm_eps,
                    mlp_ratio=mlp_ratio,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(num_layers)
            ),
            fl.LayerNorm(
                normalized_shape=embedding_dim,
                eps=norm_eps,
                device=device,
                dtype=dtype,
            ),
        )

        if self.num_registers > 0:
            registers = Registers(
                num_registers=num_registers,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            )
            self.insert_before_type(Transformer, registers)


class ViT_tiny(ViT):
    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__(
            embedding_dim=192,
            patch_size=16,
            image_size=224,
            num_layers=12,
            num_heads=3,
            device=device,
            dtype=dtype,
        )


class ViT_small(ViT):
    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__(
            embedding_dim=384,
            patch_size=16,
            image_size=224,
            num_layers=12,
            num_heads=6,
            device=device,
            dtype=dtype,
        )


class ViT_base(ViT):
    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__(
            embedding_dim=768,
            patch_size=16,
            image_size=224,
            num_layers=12,
            num_heads=12,
            device=device,
            dtype=dtype,
        )


class ViT_large(ViT):
    def __init__(
        self,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__(
            embedding_dim=1024,
            patch_size=16,
            image_size=224,
            num_layers=24,
            num_heads=16,
            device=device,
            dtype=dtype,
        )
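A shape walkthrough of the `ViT` chain above, as a hedged sketch (not part of the diff); random weights are assumed to be fine for checking shapes:

import torch

from imaginairy.vendored.refiners.foundationals.dinov2.vit import ViT_base

vit = ViT_base()  # 768-dim, 16x16 patches on 224x224 inputs, no registers
x = torch.randn(2, 3, 224, 224)
out = vit(x)
# PatchEncoder: (2, 3, 224, 224) -> (2, 196, 768); ClassToken prepends one token
assert out.shape == (2, 14**2 + 1, 768)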
40
imaginairy/vendored/refiners/foundationals/latent_diffusion/__init__.py
Normal file
@ -0,0 +1,40 @@
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import (
    CLIPTextEncoderL,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.auto_encoder import (
    LatentDiffusionAutoencoder,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers import DPMSolver, Scheduler
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
    SD1ControlnetAdapter,
    SD1IPAdapter,
    SD1T2IAdapter,
    SD1UNet,
    StableDiffusion_1,
    StableDiffusion_1_Inpainting,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
    DoubleTextEncoder,
    SDXLIPAdapter,
    SDXLT2IAdapter,
    SDXLUNet,
)

__all__ = [
    "StableDiffusion_1",
    "StableDiffusion_1_Inpainting",
    "SD1UNet",
    "SD1ControlnetAdapter",
    "SD1IPAdapter",
    "SD1T2IAdapter",
    "SDXLUNet",
    "DoubleTextEncoder",
    "SDXLIPAdapter",
    "SDXLT2IAdapter",
    "DPMSolver",
    "Scheduler",
    "CLIPTextEncoderL",
    "LatentDiffusionAutoencoder",
    "SDFreeUAdapter",
]
221
imaginairy/vendored/refiners/foundationals/latent_diffusion/auto_encoder.py
Normal file
@ -0,0 +1,221 @@
from PIL import Image
from torch import Tensor, device as Device, dtype as DType

from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.layers import (
    Chain,
    Conv2d,
    Downsample,
    GroupNorm,
    Identity,
    Residual,
    SelfAttention2d,
    SiLU,
    Slicing,
    Sum,
    Upsample,
)
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, tensor_to_image


class Resnet(Sum):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_groups: int = 32,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ):
        self.in_channels = in_channels
        self.out_channels = out_channels
        shortcut = (
            Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
            if in_channels != out_channels
            else Identity()
        )
        super().__init__(
            shortcut,
            Chain(
                GroupNorm(channels=in_channels, num_groups=num_groups, device=device, dtype=dtype),
                SiLU(),
                Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    device=device,
                    dtype=dtype,
                ),
                GroupNorm(channels=out_channels, num_groups=num_groups, device=device, dtype=dtype),
                SiLU(),
                Conv2d(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    device=device,
                    dtype=dtype,
                ),
            ),
        )


class Encoder(Chain):
    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        resnet_sizes: list[int] = [128, 256, 512, 512, 512]
        input_channels: int = 3
        latent_dim: int = 8
        resnet_layers: list[Chain] = [
            Chain(
                [
                    Resnet(
                        in_channels=resnet_sizes[i - 1] if i > 0 else resnet_sizes[0],
                        out_channels=resnet_sizes[i],
                        device=device,
                        dtype=dtype,
                    ),
                    Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
                ]
            )
            for i in range(len(resnet_sizes))
        ]
        for _, layer in zip(range(3), resnet_layers):
            channels: int = layer[-1].out_channels  # type: ignore
            layer.append(Downsample(channels=channels, scale_factor=2, device=device, dtype=dtype))

        attention_layer = Residual(
            GroupNorm(channels=resnet_sizes[-1], num_groups=32, eps=1e-6, device=device, dtype=dtype),
            SelfAttention2d(channels=resnet_sizes[-1], device=device, dtype=dtype),
        )
        resnet_layers[-1].insert_after_type(Resnet, attention_layer)
        super().__init__(
            Conv2d(
                in_channels=input_channels,
                out_channels=resnet_sizes[0],
                kernel_size=3,
                padding=1,
                device=device,
                dtype=dtype,
            ),
            Chain(*resnet_layers),
            Chain(
                GroupNorm(channels=resnet_sizes[-1], num_groups=32, eps=1e-6, device=device, dtype=dtype),
                SiLU(),
                Conv2d(
                    in_channels=resnet_sizes[-1],
                    out_channels=latent_dim,
                    kernel_size=3,
                    padding=1,
                    device=device,
                    dtype=dtype,
                ),
            ),
            Chain(
                Conv2d(in_channels=8, out_channels=8, kernel_size=1, device=device, dtype=dtype),
                Slicing(dim=1, end=4),
            ),
        )

    def init_context(self) -> Contexts:
        return {"sampling": {"shapes": []}}


class Decoder(Chain):
    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.resnet_sizes: list[int] = [128, 256, 512, 512, 512]
        self.latent_dim: int = 4
        self.output_channels: int = 3
        resnet_sizes = self.resnet_sizes[::-1]
        resnet_layers: list[Chain] = [
            (
                Chain(
                    [
                        Resnet(
                            in_channels=resnet_sizes[i - 1] if i > 0 else resnet_sizes[0],
                            out_channels=resnet_sizes[i],
                            device=device,
                            dtype=dtype,
                        ),
                        Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
                        Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
                    ]
                )
                if i > 0
                else Chain(
                    [
                        Resnet(in_channels=resnet_sizes[0], out_channels=resnet_sizes[i], device=device, dtype=dtype),
                        Resnet(in_channels=resnet_sizes[i], out_channels=resnet_sizes[i], device=device, dtype=dtype),
                    ]
                )
            )
            for i in range(len(resnet_sizes))
        ]
        attention_layer = Residual(
            GroupNorm(channels=resnet_sizes[0], num_groups=32, eps=1e-6, device=device, dtype=dtype),
            SelfAttention2d(channels=resnet_sizes[0], device=device, dtype=dtype),
        )
        resnet_layers[0].insert(1, attention_layer)
        for _, layer in zip(range(3), resnet_layers[1:]):
            channels: int = layer[-1].out_channels
            layer.insert(-1, Upsample(channels=channels, upsample_factor=2, device=device, dtype=dtype))
        super().__init__(
            Conv2d(
                in_channels=self.latent_dim, out_channels=self.latent_dim, kernel_size=1, device=device, dtype=dtype
            ),
            Conv2d(
                in_channels=self.latent_dim,
                out_channels=resnet_sizes[0],
                kernel_size=3,
                padding=1,
                device=device,
                dtype=dtype,
            ),
            Chain(*resnet_layers),
            Chain(
                GroupNorm(channels=resnet_sizes[-1], num_groups=32, eps=1e-6, device=device, dtype=dtype),
                SiLU(),
                Conv2d(
                    in_channels=resnet_sizes[-1],
                    out_channels=self.output_channels,
                    kernel_size=3,
                    padding=1,
                    device=device,
                    dtype=dtype,
                ),
            ),
        )


class LatentDiffusionAutoencoder(Chain):
    encoder_scale = 0.18125

    def __init__(
        self,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(
            Encoder(device=device, dtype=dtype),
            Decoder(device=device, dtype=dtype),
        )

    def encode(self, x: Tensor) -> Tensor:
        encoder = self[0]
        x = self.encoder_scale * encoder(x)
        return x

    def decode(self, x: Tensor) -> Tensor:
        decoder = self[1]
        x = decoder(x / self.encoder_scale)
        return x

    def encode_image(self, image: Image.Image) -> Tensor:
        x = image_to_tensor(image, device=self.device, dtype=self.dtype)
        x = 2 * x - 1
        return self.encode(x)

    def decode_latents(self, x: Tensor) -> Image.Image:
        x = self.decode(x)
        x = (x + 1) / 2
        return tensor_to_image(x)
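A round-trip sketch for the autoencoder above (not part of the diff). With freshly initialized weights the output is noise; the point is only the shapes and the [0, 1] <-> [-1, 1] rescaling done by `encode_image`/`decode_latents`:

from PIL import Image

from imaginairy.vendored.refiners.foundationals.latent_diffusion import LatentDiffusionAutoencoder

lda = LatentDiffusionAutoencoder()  # assumption: pretrained weights would be loaded before real use
image = Image.new("RGB", (512, 512), "gray")
latents = lda.encode_image(image)  # pixels scaled to [-1, 1], encoded, multiplied by encoder_scale
# three 2x downsamples and 4 kept latent channels: a 512x512 input becomes a 4x64x64 latent
decoded = lda.decode_latents(latents)  # back through the decoder to a PIL image
assert decoded.size == image.size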
175
imaginairy/vendored/refiners/foundationals/latent_diffusion/cross_attention.py
Normal file
@ -0,0 +1,175 @@
from torch import Size, Tensor, device as Device, dtype as DType

from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.layers import (
    GLU,
    Attention,
    Chain,
    Conv2d,
    Flatten,
    GeLU,
    GroupNorm,
    Identity,
    LayerNorm,
    Linear,
    Parallel,
    Residual,
    SelfAttention,
    SetContext,
    Transpose,
    Unflatten,
    UseContext,
)


class CrossAttentionBlock(Chain):
    def __init__(
        self,
        embedding_dim: int,
        context_embedding_dim: int,
        context_key: str,
        num_heads: int = 1,
        use_bias: bool = True,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.context_embedding_dim = context_embedding_dim
        self.context = "cross_attention_block"
        self.context_key = context_key
        self.num_heads = num_heads
        self.use_bias = use_bias

        super().__init__(
            Residual(
                LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
                SelfAttention(
                    embedding_dim=embedding_dim, num_heads=num_heads, use_bias=use_bias, device=device, dtype=dtype
                ),
            ),
            Residual(
                LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
                Parallel(
                    Identity(),
                    UseContext(context=self.context, key=context_key),
                    UseContext(context=self.context, key=context_key),
                ),
                Attention(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    key_embedding_dim=context_embedding_dim,
                    value_embedding_dim=context_embedding_dim,
                    use_bias=use_bias,
                    device=device,
                    dtype=dtype,
                ),
            ),
            Residual(
                LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
                Linear(in_features=embedding_dim, out_features=2 * 4 * embedding_dim, device=device, dtype=dtype),
                GLU(GeLU()),
                Linear(in_features=4 * embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
            ),
        )


class StatefulFlatten(Chain):
    def __init__(self, context: str, key: str, start_dim: int = 0, end_dim: int = -1) -> None:
        self.start_dim = start_dim
        self.end_dim = end_dim

        super().__init__(
            SetContext(context=context, key=key, callback=self.push),
            Flatten(start_dim=start_dim, end_dim=end_dim),
        )

    def push(self, sizes: list[Size], x: Tensor) -> None:
        sizes.append(
            x.shape[slice(self.start_dim, self.end_dim + 1 if self.end_dim >= 0 else x.ndim + self.end_dim + 1)]
        )


class CrossAttentionBlock2d(Residual):
    def __init__(
        self,
        channels: int,
        context_embedding_dim: int,
        context_key: str,
        num_attention_heads: int = 1,
        num_attention_layers: int = 1,
        num_groups: int = 32,
        use_bias: bool = True,
        use_linear_projection: bool = False,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        assert channels % num_attention_heads == 0, "in_channels must be divisible by num_attention_heads"
        self.channels = channels
        self.in_channels = channels
        self.out_channels = channels
        self.context_embedding_dim = context_embedding_dim
        self.num_attention_heads = num_attention_heads
        self.num_attention_layers = num_attention_layers
        self.num_groups = num_groups
        self.use_bias = use_bias
        self.context_key = context_key
        self.use_linear_projection = use_linear_projection
        self.projection_type = "Linear" if use_linear_projection else "Conv2d"

        in_block = (
            Chain(
                GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, device=device, dtype=dtype),
                StatefulFlatten(context="flatten", key="sizes", start_dim=2),
                Transpose(1, 2),
                Linear(in_features=channels, out_features=channels, device=device, dtype=dtype),
            )
            if use_linear_projection
            else Chain(
                GroupNorm(channels=channels, num_groups=num_groups, eps=1e-6, device=device, dtype=dtype),
                Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
                StatefulFlatten(context="flatten", key="sizes", start_dim=2),
                Transpose(1, 2),
            )
        )

        out_block = (
            Chain(
                Linear(in_features=channels, out_features=channels, device=device, dtype=dtype),
                Transpose(1, 2),
                Parallel(
                    Identity(),
                    UseContext(context="flatten", key="sizes").compose(lambda x: x.pop()),
                ),
                Unflatten(dim=2),
            )
            if use_linear_projection
            else Chain(
                Transpose(1, 2),
                Parallel(
                    Identity(),
                    UseContext(context="flatten", key="sizes").compose(lambda x: x.pop()),
                ),
                Unflatten(dim=2),
                Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
            )
        )

        super().__init__(
            in_block,
            Chain(
                CrossAttentionBlock(
                    embedding_dim=channels,
                    context_embedding_dim=context_embedding_dim,
                    context_key=context_key,
                    num_heads=num_attention_heads,
                    use_bias=use_bias,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(num_attention_layers)
            ),
            out_block,
        )

    def init_context(self) -> Contexts:
        return {"flatten": {"sizes": []}}
94
imaginairy/vendored/refiners/foundationals/latent_diffusion/freeu.py
Normal file
@ -0,0 +1,94 @@
import math
from typing import Any, Callable, Generic, TypeVar

import torch
from torch import Tensor
from torch.fft import fftn, fftshift, ifftn, ifftshift  # type: ignore

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualConcatenator, SD1UNet
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet

T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TSDFreeUAdapter = TypeVar("TSDFreeUAdapter", bound="SDFreeUAdapter[Any]")  # Self (see PEP 673)


def fourier_filter(x: Tensor, scale: float = 1, threshold: int = 1) -> Tensor:
    """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).

    This version of the method comes from here:
    https://github.com/ChenyangSi/FreeU/blob/main/demo/free_lunch_utils.py#L23
    """
    batch, channels, height, width = x.shape
    dtype = x.dtype
    device = x.device

    if not (math.log2(height).is_integer() and math.log2(width).is_integer()):
        x = x.to(dtype=torch.float32)

    x_freq = fftn(x, dim=(-2, -1))  # type: ignore
    x_freq = fftshift(x_freq, dim=(-2, -1))  # type: ignore
    mask = torch.ones((batch, channels, height, width), device=device)  # type: ignore

    center_row, center_col = height // 2, width // 2  # type: ignore
    mask[..., center_row - threshold : center_row + threshold, center_col - threshold : center_col + threshold] = scale
    x_freq = x_freq * mask  # type: ignore

    x_freq = ifftshift(x_freq, dim=(-2, -1))  # type: ignore
    x_filtered = ifftn(x_freq, dim=(-2, -1)).real  # type: ignore

    return x_filtered.to(dtype=dtype)  # type: ignore


class FreeUBackboneFeatures(fl.Module):
    def __init__(self, backbone_scale: float) -> None:
        super().__init__()
        self.backbone_scale = backbone_scale

    def forward(self, x: Tensor) -> Tensor:
        num_half_channels = x.shape[1] // 2
        x[:, :num_half_channels] = x[:, :num_half_channels] * self.backbone_scale
        return x


class FreeUSkipFeatures(fl.Chain):
    def __init__(self, n: int, skip_scale: float) -> None:
        apply_filter: Callable[[Tensor], Tensor] = lambda x: fourier_filter(x, scale=skip_scale)
        super().__init__(
            fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[n]),
            fl.Lambda(apply_filter),
        )


class FreeUResidualConcatenator(fl.Concatenate):
    def __init__(self, n: int, backbone_scale: float, skip_scale: float) -> None:
        super().__init__(
            FreeUBackboneFeatures(backbone_scale),
            FreeUSkipFeatures(n, skip_scale),
            dim=1,
        )


class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]):
    def __init__(self, target: T, backbone_scales: list[float], skip_scales: list[float]) -> None:
        assert len(backbone_scales) == len(skip_scales)
        assert len(backbone_scales) <= len(target.UpBlocks)
        self.backbone_scales = backbone_scales
        self.skip_scales = skip_scales
        with self.setup_adapter(target):
            super().__init__(target)

    def inject(self: TSDFreeUAdapter, parent: fl.Chain | None = None) -> TSDFreeUAdapter:
        for n, (backbone_scale, skip_scale) in enumerate(zip(self.backbone_scales, self.skip_scales)):
            block = self.target.UpBlocks[n]
            concat = block.ensure_find(ResidualConcatenator)
            block.replace(concat, FreeUResidualConcatenator(-n - 2, backbone_scale, skip_scale))
        return super().inject(parent)

    def eject(self) -> None:
        for n in range(len(self.backbone_scales)):
            block = self.target.UpBlocks[n]
            concat = block.ensure_find(FreeUResidualConcatenator)
            block.replace(concat, ResidualConcatenator(-n - 2))
        super().eject()
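A quick check of `fourier_filter` (not part of the diff); it is runnable as-is and shows that only the low-frequency band is rescaled, so `scale=1` is the identity up to FFT round-off:

import torch

from imaginairy.vendored.refiners.foundationals.latent_diffusion.freeu import fourier_filter

feat = torch.randn(1, 4, 32, 32)  # stand-in for a UNet skip feature map
filtered = fourier_filter(feat, scale=0.9, threshold=1)
assert filtered.shape == feat.shape and filtered.dtype == feat.dtype
# with scale=1 the frequency-domain mask is all ones, so fft -> mask -> ifft is a no-op
assert torch.allclose(fourier_filter(feat, scale=1), feat, atol=1e-4)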
465
imaginairy/vendored/refiners/foundationals/latent_diffusion/image_prompt.py
Normal file
@ -0,0 +1,465 @@
import math
from enum import IntEnum
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar

from jaxtyping import Float
from PIL import Image
from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.layers.attentions import ScaledDotProductAttention
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, normalize
from imaginairy.vendored.refiners.foundationals.clip.image_encoder import CLIPImageEncoderH

if TYPE_CHECKING:
    from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
    from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet

T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TIPAdapter = TypeVar("TIPAdapter", bound="IPAdapter[Any]")  # Self (see PEP 673)


class ImageProjection(fl.Chain):
    def __init__(
        self,
        clip_image_embedding_dim: int = 1024,
        clip_text_embedding_dim: int = 768,
        num_tokens: int = 4,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.clip_image_embedding_dim = clip_image_embedding_dim
        self.clip_text_embedding_dim = clip_text_embedding_dim
        self.num_tokens = num_tokens
        super().__init__(
            fl.Linear(
                in_features=clip_image_embedding_dim,
                out_features=clip_text_embedding_dim * num_tokens,
                device=device,
                dtype=dtype,
            ),
            fl.Reshape(num_tokens, clip_text_embedding_dim),
            fl.LayerNorm(normalized_shape=clip_text_embedding_dim, device=device, dtype=dtype),
        )


class FeedForward(fl.Chain):
    def __init__(
        self,
        embedding_dim: int,
        feedforward_dim: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.feedforward_dim = feedforward_dim
        super().__init__(
            fl.Linear(
                in_features=self.embedding_dim,
                out_features=self.feedforward_dim,
                bias=False,
                device=device,
                dtype=dtype,
            ),
            fl.GeLU(),
            fl.Linear(
                in_features=self.feedforward_dim,
                out_features=self.embedding_dim,
                bias=False,
                device=device,
                dtype=dtype,
            ),
        )


# Adapted from https://github.com/tencent-ailab/IP-Adapter/blob/6212981/ip_adapter/resampler.py
# See also:
# - https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# - https://github.com/lucidrains/flamingo-pytorch
class PerceiverScaledDotProductAttention(fl.Module):
    def __init__(self, head_dim: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        # See https://github.com/tencent-ailab/IP-Adapter/blob/6212981/ip_adapter/resampler.py#L69
        # -> "More stable with f16 than dividing afterwards"
        self.scale = 1 / math.sqrt(math.sqrt(head_dim))

    def forward(
        self,
        key_value: Float[Tensor, "batch sequence_length 2*head_dim*num_heads"],
        query: Float[Tensor, "batch num_tokens head_dim*num_heads"],
    ) -> Float[Tensor, "batch num_tokens head_dim*num_heads"]:
        bs, length, _ = query.shape
        key, value = key_value.chunk(2, dim=-1)

        q = self.reshape_tensor(query)
        k = self.reshape_tensor(key)
        v = self.reshape_tensor(value)

        attention = (q * self.scale) @ (k * self.scale).transpose(-2, -1)
        attention = softmax(input=attention.float(), dim=-1).type(attention.dtype)
        attention = attention @ v

        return attention.permute(0, 2, 1, 3).reshape(bs, length, -1)

    def reshape_tensor(
        self, x: Float[Tensor, "batch length head_dim*num_heads"]
    ) -> Float[Tensor, "batch num_heads length head_dim"]:
        bs, length, _ = x.shape
        x = x.view(bs, length, self.num_heads, -1)
        x = x.transpose(1, 2)
        x = x.reshape(bs, self.num_heads, length, -1)
        return x


class PerceiverAttention(fl.Chain):
    def __init__(
        self,
        embedding_dim: int,
        head_dim: int = 64,
        num_heads: int = 8,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.head_dim = head_dim
        self.inner_dim = head_dim * num_heads
        super().__init__(
            fl.Distribute(
                fl.LayerNorm(normalized_shape=self.embedding_dim, device=device, dtype=dtype),
                fl.LayerNorm(normalized_shape=self.embedding_dim, device=device, dtype=dtype),
            ),
            fl.Parallel(
                fl.Chain(
                    fl.Lambda(func=self.to_kv),
                    fl.Linear(
                        in_features=self.embedding_dim,
                        out_features=2 * self.inner_dim,
                        bias=False,
                        device=device,
                        dtype=dtype,
                    ),  # Wkv
                ),
                fl.Chain(
                    fl.GetArg(index=1),
                    fl.Linear(
                        in_features=self.embedding_dim,
                        out_features=self.inner_dim,
                        bias=False,
                        device=device,
                        dtype=dtype,
                    ),  # Wq
                ),
            ),
            PerceiverScaledDotProductAttention(head_dim=head_dim, num_heads=num_heads),
            fl.Linear(
                in_features=self.inner_dim, out_features=self.embedding_dim, bias=False, device=device, dtype=dtype
            ),
        )

    def to_kv(self, x: Tensor, latents: Tensor) -> Tensor:
        return cat((x, latents), dim=-2)


class LatentsToken(fl.Chain):
    def __init__(
        self, num_tokens: int, latents_dim: int, device: Device | str | None = None, dtype: DType | None = None
    ) -> None:
        self.num_tokens = num_tokens
        self.latents_dim = latents_dim
        super().__init__(fl.Parameter(num_tokens, latents_dim, device=device, dtype=dtype))


class Transformer(fl.Chain):
    pass


class TransformerLayer(fl.Chain):
    pass


class PerceiverResampler(fl.Chain):
    def __init__(
        self,
        latents_dim: int = 1024,
        num_attention_layers: int = 8,
        num_attention_heads: int = 16,
        head_dim: int = 64,
        num_tokens: int = 8,
        input_dim: int = 768,
        output_dim: int = 1024,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.latents_dim = latents_dim
        self.num_attention_layers = num_attention_layers
        self.head_dim = head_dim
        self.num_attention_heads = num_attention_heads
        self.num_tokens = num_tokens
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.feedforward_dim = 4 * self.latents_dim
        super().__init__(
            fl.Linear(in_features=input_dim, out_features=latents_dim, device=device, dtype=dtype),
            fl.SetContext(context="perceiver_resampler", key="x"),
            LatentsToken(num_tokens, latents_dim, device=device, dtype=dtype),
            Transformer(
                TransformerLayer(
                    fl.Residual(
                        fl.Parallel(fl.UseContext(context="perceiver_resampler", key="x"), fl.Identity()),
                        PerceiverAttention(
                            embedding_dim=latents_dim,
                            head_dim=head_dim,
                            num_heads=num_attention_heads,
                            device=device,
                            dtype=dtype,
                        ),
                    ),
                    fl.Residual(
                        fl.LayerNorm(normalized_shape=latents_dim, device=device, dtype=dtype),
                        FeedForward(
                            embedding_dim=latents_dim, feedforward_dim=self.feedforward_dim, device=device, dtype=dtype
                        ),
                    ),
                )
                for _ in range(num_attention_layers)
            ),
            fl.Linear(in_features=latents_dim, out_features=output_dim, device=device, dtype=dtype),
            fl.LayerNorm(normalized_shape=output_dim, device=device, dtype=dtype),
        )

    def init_context(self) -> Contexts:
        return {"perceiver_resampler": {"x": None}}


class _CrossAttnIndex(IntEnum):
    TXT_CROSS_ATTN = 0  # text cross-attention
    IMG_CROSS_ATTN = 1  # image cross-attention


class InjectionPoint(fl.Chain):
    pass


class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
    def __init__(
        self,
        target: fl.Attention,
        text_sequence_length: int = 77,
        image_sequence_length: int = 4,
        scale: float = 1.0,
    ) -> None:
        self.text_sequence_length = text_sequence_length
        self.image_sequence_length = image_sequence_length
        self.scale = scale

        with self.setup_adapter(target):
            super().__init__(
                fl.Distribute(
                    # Note: the same query is used for image cross-attention as for text cross-attention
                    InjectionPoint(),  # Wq
                    fl.Parallel(
                        fl.Chain(
                            fl.Slicing(dim=1, end=text_sequence_length),
                            InjectionPoint(),  # Wk
                        ),
                        fl.Chain(
                            fl.Slicing(dim=1, start=text_sequence_length),
                            fl.Linear(
                                in_features=self.target.key_embedding_dim,
                                out_features=self.target.inner_dim,
                                bias=self.target.use_bias,
                                device=target.device,
                                dtype=target.dtype,
                            ),  # Wk'
                        ),
                    ),
                    fl.Parallel(
                        fl.Chain(
                            fl.Slicing(dim=1, end=text_sequence_length),
                            InjectionPoint(),  # Wv
                        ),
                        fl.Chain(
                            fl.Slicing(dim=1, start=text_sequence_length),
                            fl.Linear(
                                in_features=self.target.key_embedding_dim,
                                out_features=self.target.inner_dim,
                                bias=self.target.use_bias,
                                device=target.device,
                                dtype=target.dtype,
                            ),  # Wv'
                        ),
                    ),
                ),
                fl.Sum(
                    fl.Chain(
                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
                    ),
                    fl.Chain(
                        fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
                        ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
                        fl.Lambda(func=self.scale_outputs),
                    ),
                ),
                InjectionPoint(),  # proj
            )

    def select_qkv(
        self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
    ) -> tuple[Tensor, Tensor, Tensor]:
        return (query, keys[index.value], values[index.value])

    def scale_outputs(self, x: Tensor) -> Tensor:
        return x * self.scale

    def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
        def f(m: fl.Module, _: fl.Chain) -> bool:
            if isinstance(m, Lora):  # do not adapt LoRAs
                raise StopIteration
            return isinstance(m, k)

        return f

    def _target_linears(self) -> list[fl.Linear]:
        return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)]

    def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
        linears = self._target_linears()
        assert len(linears) == 4  # Wq, Wk, Wv and Proj

        injection_points = list(self.layers(InjectionPoint))
        assert len(injection_points) == 4

        for linear, ip in zip(linears, injection_points):
            ip.append(linear)
            assert len(ip) == 1

        return super().inject(parent)

    def eject(self) -> None:
        injection_points = list(self.layers(InjectionPoint))
        assert len(injection_points) == 4

        for ip in injection_points:
            ip.pop()
            assert len(ip) == 0

        super().eject()


class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
    # Prevent PyTorch module registration
    _clip_image_encoder: list[CLIPImageEncoderH]
    _grid_image_encoder: list[CLIPImageEncoderH]
    _image_proj: list[fl.Module]

    def __init__(
        self,
        target: T,
        clip_image_encoder: CLIPImageEncoderH,
        image_proj: fl.Module,
        scale: float = 1.0,
        fine_grained: bool = False,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        with self.setup_adapter(target):
            super().__init__(target)

        self.fine_grained = fine_grained
        self._clip_image_encoder = [clip_image_encoder]
        if fine_grained:
            self._grid_image_encoder = [self.convert_to_grid_features(clip_image_encoder)]
        self._image_proj = [image_proj]

        self.sub_adapters = [
            CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens)
            for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
        ]

        if weights is not None:
            image_proj_state_dict: dict[str, Tensor] = {
                k.removeprefix("image_proj."): v for k, v in weights.items() if k.startswith("image_proj.")
            }
            self.image_proj.load_state_dict(image_proj_state_dict)

            for i, cross_attn in enumerate(self.sub_adapters):
                cross_attn_state_dict: dict[str, Tensor] = {}
                for k, v in weights.items():
                    prefix = f"ip_adapter.{i:03d}."
                    if not k.startswith(prefix):
                        continue
                    cross_attn_state_dict[k.removeprefix(prefix)] = v

                cross_attn.load_state_dict(state_dict=cross_attn_state_dict)

    @property
    def clip_image_encoder(self) -> CLIPImageEncoderH:
        return self._clip_image_encoder[0]

    @property
    def grid_image_encoder(self) -> CLIPImageEncoderH:
        assert hasattr(self, "_grid_image_encoder")
        return self._grid_image_encoder[0]

    @property
    def image_proj(self) -> fl.Module:
        return self._image_proj[0]

    def inject(self: "TIPAdapter", parent: fl.Chain | None = None) -> "TIPAdapter":
        for adapter in self.sub_adapters:
            adapter.inject()
        return super().inject(parent)

    def eject(self) -> None:
        for adapter in self.sub_adapters:
            adapter.eject()
        super().eject()

    def set_scale(self, scale: float) -> None:
        for cross_attn in self.sub_adapters:
            cross_attn.scale = scale

    # These should be concatenated to the CLIP text embedding before setting the UNet context
    def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
        image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
        clip_embedding = image_encoder(image_prompt)
        conditional_embedding = self.image_proj(clip_embedding)
        if not self.fine_grained:
            negative_embedding = self.image_proj(zeros_like(clip_embedding))
        else:
            # See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
            clip_embedding = image_encoder(zeros_like(image_prompt))
            negative_embedding = self.image_proj(clip_embedding)
        return cat((negative_embedding, conditional_embedding))

    def preprocess_image(
        self,
        image: Image.Image,
        size: tuple[int, int] = (224, 224),
        mean: list[float] | None = None,
        std: list[float] | None = None,
    ) -> Tensor:
        # Default mean and std are parameters from https://github.com/openai/CLIP
        return normalize(
            image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype),
            mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
            std=[0.26862954, 0.26130258, 0.27577711] if std is None else std,
        )

    @staticmethod
    def convert_to_grid_features(clip_image_encoder: CLIPImageEncoderH) -> CLIPImageEncoderH:
        encoder_clone = clip_image_encoder.structural_copy()
        assert isinstance(encoder_clone[-1], fl.Linear)  # final proj
        assert isinstance(encoder_clone[-2], fl.LayerNorm)  # final normalization
        assert isinstance(encoder_clone[-3], fl.Lambda)  # pooling (classif token)
        for _ in range(3):
            encoder_clone.pop()
        transformer_layers = encoder_clone[-1]
        assert isinstance(transformer_layers, fl.Chain) and len(transformer_layers) == 32
        transformer_layers.pop()
        return encoder_clone
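Per the comment above `compute_clip_image_embedding`, the (negative, conditional) image tokens are concatenated onto the CLIP text embedding before it is set on the UNet context. A hedged sketch of that wiring (not part of the diff); constructing the adapter itself (UNet, CLIPImageEncoderH, weights) is elided here:

import torch
from PIL import Image


def condition_on_image(ip_adapter, text_embedding: torch.Tensor, image: Image.Image) -> torch.Tensor:
    """Append IP-Adapter image tokens to a classifier-free-guidance text embedding.

    `ip_adapter` is an already injected IPAdapter (e.g. SD1IPAdapter); `text_embedding`
    is assumed to be the (negative, conditional) pair, shape (2, 77, dim).
    """
    image_tensor = ip_adapter.preprocess_image(image)  # resize to 224x224 + CLIP mean/std normalize
    image_embedding = ip_adapter.compute_clip_image_embedding(image_tensor)  # (2, num_tokens, dim)
    # CrossAttentionAdapter slices dim=1 at text_sequence_length, so image tokens go after the text
    return torch.cat((text_embedding, image_embedding), dim=1)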
146
imaginairy/vendored/refiners/foundationals/latent_diffusion/lora.py
Normal file
@ -0,0 +1,146 @@
from enum import Enum
from pathlib import Path
from typing import Callable, Iterator

from torch import Tensor

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.adapters.lora import Lora, LoraAdapter
from imaginairy.vendored.refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
from imaginairy.vendored.refiners.foundationals.latent_diffusion import (
    CLIPTextEncoderL,
    LatentDiffusionAutoencoder,
    SD1UNet,
    StableDiffusion_1,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet

MODELS = ["unet", "text_encoder", "lda"]


class LoraTarget(str, Enum):
    Self = "self"
    Attention = "Attention"
    SelfAttention = "SelfAttention"
    CrossAttention = "CrossAttentionBlock2d"
    FeedForward = "FeedForward"
    TransformerLayer = "TransformerLayer"

    def get_class(self) -> type[fl.Chain]:
        match self:
            case LoraTarget.Self:
                return fl.Chain
            case LoraTarget.Attention:
                return fl.Attention
            case LoraTarget.SelfAttention:
                return fl.SelfAttention
            case LoraTarget.CrossAttention:
                return CrossAttentionBlock2d
            case LoraTarget.FeedForward:
                return FeedForward
            case LoraTarget.TransformerLayer:
                return TransformerLayer


def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
    def f(m: fl.Module, _: fl.Chain) -> bool:
        if isinstance(m, Lora):  # do not adapt other LoRAs
            raise StopIteration
        if isinstance(m, Controlnet):  # do not adapt Controlnet linears
            raise StopIteration
        return isinstance(m, k)

    return f


def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]:
    for m, p in module.walk(_predicate(fl.Linear)):
        assert isinstance(m, fl.Linear)
        yield (m, p)


def lora_targets(
    module: fl.Chain,
    target: LoraTarget | list[LoraTarget],
) -> Iterator[tuple[fl.Linear, fl.Chain]]:
    if isinstance(target, list):
        for t in target:
            yield from lora_targets(module, t)
        return

    if target == LoraTarget.Self:
        yield from _iter_linears(module)
        return

    for layer, _ in module.walk(_predicate(target.get_class())):
        assert isinstance(layer, fl.Chain)
        yield from _iter_linears(layer)


class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
    metadata: dict[str, str] | None
    tensors: dict[str, Tensor]

    def __init__(
        self,
        target: StableDiffusion_1,
        sub_targets: dict[str, list[LoraTarget]],
        scale: float = 1.0,
        weights: dict[str, Tensor] | None = None,
    ):
        with self.setup_adapter(target):
            super().__init__(target)

        self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = []

        for model_name in MODELS:
            if not (model_targets := sub_targets.get(model_name, [])):
                continue
            model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)

            lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
            self.sub_adapters.append(
                LoraAdapter[type(model)](
                    model,
                    sub_targets=lora_targets(model, model_targets),
                    scale=scale,
                    weights=lora_weights,
                )
            )

    @classmethod
    def from_safetensors(
        cls,
        target: StableDiffusion_1,
        checkpoint_path: Path | str,
        scale: float = 1.0,
    ):
        metadata = load_metadata_from_safetensors(checkpoint_path)
        assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
        tensors = load_from_safetensors(checkpoint_path, device=target.device)

        sub_targets: dict[str, list[LoraTarget]] = {}
|
||||||
|
for model_name in MODELS:
|
||||||
|
if not (v := metadata.get(f"{model_name}_targets", "")):
|
||||||
|
continue
|
||||||
|
sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")]
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
target,
|
||||||
|
sub_targets,
|
||||||
|
scale=scale,
|
||||||
|
weights=tensors,
|
||||||
|
)
|
||||||
|
|
||||||
|
def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter":
|
||||||
|
for adapter in self.sub_adapters:
|
||||||
|
adapter.inject()
|
||||||
|
return super().inject(parent)
|
||||||
|
|
||||||
|
def eject(self) -> None:
|
||||||
|
for adapter in self.sub_adapters:
|
||||||
|
adapter.eject()
|
||||||
|
super().eject()
|
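A sketch of driving the adapter from a metadata-carrying safetensors checkpoint. The module path (assuming this file lands at latent_diffusion/lora.py in the vendored tree), the StableDiffusion_1 instance sd15, and the checkpoint filename are assumptions for illustration:

# Hypothetical usage sketch; `sd15` is an already-constructed StableDiffusion_1.
from imaginairy.vendored.refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter

lora_adapter = SD1LoraAdapter.from_safetensors(
    sd15,
    checkpoint_path="some_lora.safetensors",  # must embed "<model>_targets" metadata
    scale=1.0,
)
lora_adapter.inject()  # patches matching Linear layers in unet / text_encoder / lda
# ... run inference ...
lora_adapter.eject()   # restores the original model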
@ -0,0 +1,115 @@
from abc import ABC, abstractmethod
from typing import TypeVar

import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler

T = TypeVar("T", bound="fl.Module")


TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")


class LatentDiffusionModel(fl.Module, ABC):
    def __init__(
        self,
        unet: fl.Module,
        lda: LatentDiffusionAutoencoder,
        clip_text_encoder: fl.Module,
        scheduler: Scheduler,
        device: Device | str = "cpu",
        dtype: DType = torch.float32,
    ) -> None:
        super().__init__()
        self.device: Device = device if isinstance(device, Device) else Device(device=device)
        self.dtype = dtype
        self.unet = unet.to(device=self.device, dtype=self.dtype)
        self.lda = lda.to(device=self.device, dtype=self.dtype)
        self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
        self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)

    def set_num_inference_steps(self, num_inference_steps: int) -> None:
        initial_diffusion_rate = self.scheduler.initial_diffusion_rate
        final_diffusion_rate = self.scheduler.final_diffusion_rate
        device, dtype = self.scheduler.device, self.scheduler.dtype
        self.scheduler = self.scheduler.__class__(
            num_inference_steps,
            initial_diffusion_rate=initial_diffusion_rate,
            final_diffusion_rate=final_diffusion_rate,
        ).to(device=device, dtype=dtype)

    def init_latents(
        self,
        size: tuple[int, int],
        init_image: Image.Image | None = None,
        first_step: int = 0,
        noise: Tensor | None = None,
    ) -> Tensor:
        height, width = size
        if noise is None:
            noise = torch.randn(1, 4, height // 8, width // 8, device=self.device)
        assert list(noise.shape[2:]) == [
            height // 8,
            width // 8,
        ], f"noise shape is not compatible: {noise.shape}, with size: {size}"
        if init_image is None:
            return noise
        encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
        return self.scheduler.add_noise(x=encoded_image, noise=noise, step=self.steps[first_step])

    @property
    def steps(self) -> list[int]:
        return self.scheduler.steps

    @abstractmethod
    def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
        ...

    @abstractmethod
    def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
        ...

    @abstractmethod
    def has_self_attention_guidance(self) -> bool:
        ...

    @abstractmethod
    def compute_self_attention_guidance(
        self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
    ) -> Tensor:
        ...

    def forward(
        self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
    ) -> Tensor:
        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
        self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)

        latents = torch.cat(tensors=(x, x))  # for classifier-free guidance
        unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)

        # classifier-free guidance
        noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
        x = x.narrow(dim=1, start=0, length=4)  # support > 4 channels for inpainting

        if self.has_self_attention_guidance():
            noise += self.compute_self_attention_guidance(
                x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
            )

        return self.scheduler(x, noise=noise, step=step)

    def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
        return self.__class__(
            unet=self.unet.structural_copy(),
            lda=self.lda.structural_copy(),
            clip_text_encoder=self.clip_text_encoder.structural_copy(),
            scheduler=self.scheduler,
            device=self.device,
            dtype=self.dtype,
        )
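The forward method duplicates the latents, runs one UNet batch containing the unconditional and conditional branches, and blends the two noise predictions with the standard classifier-free guidance formula. A minimal numeric sketch of that blending step, with illustrative values:

import torch

unconditional_prediction = torch.zeros(1, 4, 8, 8)
conditional_prediction = torch.ones(1, 4, 8, 8)
condition_scale = 7.5

# noise = e_uncond + s * (e_cond - e_uncond); s > 1 amplifies the prompt direction.
noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
assert torch.allclose(noise, torch.full((1, 4, 8, 8), 7.5))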
@ -0,0 +1,98 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar

import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType

from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import LatentDiffusionModel

MAX_STEPS = 1000


@dataclass
class DiffusionTarget:
    size: tuple[int, int]
    offset: tuple[int, int]
    clip_text_embedding: Tensor
    init_latents: Tensor | None = None
    mask_latent: Tensor | None = None
    weight: int = 1
    condition_scale: float = 7.5
    start_step: int = 0
    end_step: int = MAX_STEPS

    def crop(self, tensor: Tensor, /) -> Tensor:
        height, width = self.size
        top_offset, left_offset = self.offset
        return tensor[:, :, top_offset : top_offset + height, left_offset : left_offset + width]

    def paste(self, tensor: Tensor, /, crop: Tensor) -> Tensor:
        height, width = self.size
        top_offset, left_offset = self.offset
        tensor[:, :, top_offset : top_offset + height, left_offset : left_offset + width] = crop
        return tensor


T = TypeVar("T", bound=LatentDiffusionModel)
D = TypeVar("D", bound=DiffusionTarget)


@dataclass
class MultiDiffusion(Generic[T, D], ABC):
    ldm: T

    def __call__(self, x: Tensor, /, noise: Tensor, step: int, targets: list[D]) -> Tensor:
        num_updates = torch.zeros_like(input=x)
        cumulative_values = torch.zeros_like(input=x)

        for target in targets:
            match step:
                case step if step == target.start_step and target.init_latents is not None:
                    noise_view = target.crop(noise)
                    view = self.ldm.scheduler.add_noise(
                        x=target.init_latents,
                        noise=noise_view,
                        step=step,
                    )
                case step if target.start_step <= step <= target.end_step:
                    view = target.crop(x)
                case _:
                    continue
            view = self.diffuse_target(x=view, step=step, target=target)
            weight = target.weight * target.mask_latent if target.mask_latent is not None else target.weight
            num_updates = target.paste(num_updates, crop=target.crop(num_updates) + weight)
            cumulative_values = target.paste(cumulative_values, crop=target.crop(cumulative_values) + weight * view)

        return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x)

    @abstractmethod
    def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor:
        ...

    @property
    def steps(self) -> list[int]:
        return self.ldm.steps

    @property
    def device(self) -> Device:
        return self.ldm.device

    @property
    def dtype(self) -> DType:
        return self.ldm.dtype

    def decode_latents(self, x: Tensor) -> Image.Image:
        return self.ldm.lda.decode_latents(x=x)

    @staticmethod
    def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]:
        height, width = size

        return [
            (y, x)
            for y in range(0, height, stride)
            for x in range(0, width, stride)
            if y + 64 <= height and x + 64 <= width
        ]
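Each target contributes weight * view to the pixels it covers and weight to a per-pixel counter; the final latents are the weighted mean wherever at least one tile landed, and the original x elsewhere. A toy sketch of that accumulation (shapes, tile placement, and weights are illustrative assumptions):

import torch

x = torch.zeros(1, 1, 4, 4)
num_updates = torch.zeros_like(x)
cumulative_values = torch.zeros_like(x)

# Two overlapping 2x2 "tiles" with different denoised values.
for top, left, value, weight in [(0, 0, 1.0, 1), (1, 1, 3.0, 1)]:
    view = torch.full((1, 1, 2, 2), value)
    num_updates[:, :, top : top + 2, left : left + 2] += weight
    cumulative_values[:, :, top : top + 2, left : left + 2] += weight * view

result = torch.where(num_updates > 0, cumulative_values / num_updates, x)
# The overlapping pixel (1, 1) averages to (1.0 + 3.0) / 2 = 2.0.
assert result[0, 0, 1, 1] == 2.0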
@ -0,0 +1,107 @@
# Adapted from https://github.com/carolineec/informative-drawings, MIT License

from torch import device as Device, dtype as DType

import imaginairy.vendored.refiners.fluxion.layers as fl


class InformativeDrawings(fl.Chain):
    """Model typically used as the preprocessor for the Lineart ControlNet.

    Implements the paper "Learning to generate line drawings that convey
    geometry and semantics" published in 2022 by Caroline Chan, Frédo Durand
    and Phillip Isola - https://arxiv.org/abs/2203.12691

    For use as a preprocessor it is recommended to use the weights for "Style 2".
    """

    def __init__(
        self,
        in_channels: int = 3,  # RGB
        out_channels: int = 1,  # Grayscale
        n_residual_blocks: int = 3,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(
            fl.Chain(  # Initial convolution
                fl.ReflectionPad2d(3),
                fl.Conv2d(
                    in_channels=in_channels,
                    out_channels=64,
                    kernel_size=7,
                    device=device,
                    dtype=dtype,
                ),
                fl.InstanceNorm2d(64, device=device, dtype=dtype),
                fl.ReLU(),
            ),
            *(  # Downsampling
                fl.Chain(
                    fl.Conv2d(
                        in_channels=64 * (2**i),
                        out_channels=128 * (2**i),
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        device=device,
                        dtype=dtype,
                    ),
                    fl.InstanceNorm2d(128 * (2**i), device=device, dtype=dtype),
                    fl.ReLU(),
                )
                for i in range(2)
            ),
            *(  # Residual blocks
                fl.Residual(
                    fl.ReflectionPad2d(1),
                    fl.Conv2d(
                        in_channels=256,
                        out_channels=256,
                        kernel_size=3,
                        device=device,
                        dtype=dtype,
                    ),
                    fl.InstanceNorm2d(256, device=device, dtype=dtype),
                    fl.ReLU(),
                    fl.ReflectionPad2d(1),
                    fl.Conv2d(
                        in_channels=256,
                        out_channels=256,
                        kernel_size=3,
                        device=device,
                        dtype=dtype,
                    ),
                    fl.InstanceNorm2d(256, device=device, dtype=dtype),
                )
                for _ in range(n_residual_blocks)
            ),
            *(  # Upsampling
                fl.Chain(
                    fl.ConvTranspose2d(
                        in_channels=128 * (2**i),
                        out_channels=64 * (2**i),
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        output_padding=1,
                        device=device,
                        dtype=dtype,
                    ),
                    fl.InstanceNorm2d(64 * (2**i), device=device, dtype=dtype),
                    fl.ReLU(),
                )
                for i in reversed(range(2))
            ),
            fl.Chain(  # Output layer
                fl.ReflectionPad2d(3),
                fl.Conv2d(
                    in_channels=64,
                    out_channels=out_channels,
                    kernel_size=7,
                    device=device,
                    dtype=dtype,
                ),
                fl.Sigmoid(),
            ),
        )
@ -0,0 +1,68 @@
import math

from jaxtyping import Float, Int
from torch import Tensor, arange, cat, cos, device as Device, dtype as DType, exp, float32, sin

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter


def compute_sinusoidal_embedding(
    x: Int[Tensor, "*batch 1"],
    embedding_dim: int,
) -> Float[Tensor, "*batch 1 embedding_dim"]:
    half_dim = embedding_dim // 2
    # Note: it is important that this computation is done in float32.
    # The result can be cast to lower precision later if necessary.
    exponent = -math.log(10000) * arange(start=0, end=half_dim, dtype=float32, device=x.device)
    exponent /= half_dim
    embedding = x.unsqueeze(1).float() * exp(exponent).unsqueeze(0)
    embedding = cat([cos(embedding), sin(embedding)], dim=-1)
    return embedding


class RangeEncoder(fl.Chain):
    def __init__(
        self,
        sinuosidal_embedding_dim: int,
        embedding_dim: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.sinuosidal_embedding_dim = sinuosidal_embedding_dim
        self.embedding_dim = embedding_dim
        super().__init__(
            fl.Lambda(self.compute_sinuosoidal_embedding),
            fl.Converter(set_device=False, set_dtype=True),
            fl.Linear(in_features=sinuosidal_embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
            fl.SiLU(),
            fl.Linear(in_features=embedding_dim, out_features=embedding_dim, device=device, dtype=dtype),
        )

    def compute_sinuosoidal_embedding(self, x: Int[Tensor, "*batch 1"]) -> Float[Tensor, "*batch 1 embedding_dim"]:
        return compute_sinusoidal_embedding(x, embedding_dim=self.sinuosidal_embedding_dim)


class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
    def __init__(
        self,
        target: fl.Conv2d,
        channels: int,
        embedding_dim: int,
        context_key: str,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.channels = channels
        self.embedding_dim = embedding_dim
        self.context_key = context_key
        with self.setup_adapter(target):
            super().__init__(
                target,
                fl.Chain(
                    fl.UseContext("range_adapter", context_key),
                    fl.SiLU(),
                    fl.Linear(in_features=embedding_dim, out_features=channels, device=device, dtype=dtype),
                    fl.View(-1, channels, 1, 1),
                ),
            )
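The embedding concatenates cosines and sines of geometrically spaced frequencies, computed in float32 for numerical stability. A quick self-contained check of the computation and the resulting shape; the timestep values and batch size are arbitrary assumptions:

import math
import torch

x = torch.tensor([0, 500, 999])  # e.g. three diffusion timesteps
embedding_dim = 320
half_dim = embedding_dim // 2

exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32) / half_dim
embedding = x.unsqueeze(1).float() * torch.exp(exponent).unsqueeze(0)
embedding = torch.cat([torch.cos(embedding), torch.sin(embedding)], dim=-1)
assert embedding.shape == (3, 320)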
@ -0,0 +1,143 @@
from typing import Callable

from torch import Tensor

from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.layers import (
    Chain,
    Concatenate,
    Identity,
    Lambda,
    Parallel,
    Passthrough,
    SelfAttention,
    SetContext,
    UseContext,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet


class SaveLayerNormAdapter(Chain, Adapter[SelfAttention]):
    def __init__(self, target: SelfAttention, context: str) -> None:
        self.context = context
        with self.setup_adapter(target):
            super().__init__(SetContext(self.context, "norm"), target)


class SelfAttentionInjectionAdapter(Chain, Adapter[SelfAttention]):
    def __init__(
        self,
        target: SelfAttention,
        context: str,
        style_cfg: float = 0.5,
    ) -> None:
        self.context = context
        self.style_cfg = style_cfg

        sa_guided = target.structural_copy()
        assert isinstance(sa_guided[0], Parallel)
        sa_guided.replace(
            sa_guided[0],
            Parallel(
                Identity(),
                Concatenate(Identity(), UseContext(self.context, "norm"), dim=1),
                Concatenate(Identity(), UseContext(self.context, "norm"), dim=1),
            ),
        )

        with self.setup_adapter(target):
            slice_tensor: Callable[[Tensor], Tensor] = lambda x: x[:1]
            super().__init__(
                Parallel(sa_guided, Chain(Lambda(slice_tensor), target)),
                Lambda(self.compute_averaged_unconditioned_x),
            )

    def compute_averaged_unconditioned_x(self, x: Tensor, unguided_unconditioned_x: Tensor) -> Tensor:
        x[0] = self.style_cfg * x[0] + (1.0 - self.style_cfg) * unguided_unconditioned_x
        return x


class SelfAttentionInjectionPassthrough(Passthrough):
    def __init__(self, target: SD1UNet) -> None:
        guide_unet = target.structural_copy()
        for i, attention_block in enumerate(guide_unet.layers(CrossAttentionBlock)):
            sa = attention_block.ensure_find(SelfAttention)
            assert sa.parent is not None
            SaveLayerNormAdapter(sa, context=f"self_attention_context_{i}").inject()

        super().__init__(
            Lambda(self._copy_diffusion_context),
            UseContext("reference_only_control", "guide"),
            guide_unet,
            Lambda(self._restore_diffusion_context),
        )

    def _copy_diffusion_context(self, x: Tensor) -> Tensor:
        # This function avoids disrupting the accumulation of residuals in the UNet (when ControlNets are used)
        self.set_context(
            "self_attention_residuals_buffer",
            {"buffer": self.use_context("unet")["residuals"]},
        )
        self.set_context(
            "unet",
            {"residuals": [0.0] * 13},
        )
        return x

    def _restore_diffusion_context(self, x: Tensor) -> Tensor:
        self.set_context(
            "unet",
            {
                "residuals": self.use_context("self_attention_residuals_buffer")["buffer"],
            },
        )
        return x


class ReferenceOnlyControlAdapter(Chain, Adapter[SD1UNet]):
    # TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance

    def __init__(self, target: SD1UNet, style_cfg: float = 0.5) -> None:
        # style_cfg is the weight of the guide in unconditioned diffusion.
        # This value is recommended to be 0.5 on the sdwebui repo.

        self.sub_adapters: list[SelfAttentionInjectionAdapter] = []
        self._passthrough: list[SelfAttentionInjectionPassthrough] = [
            SelfAttentionInjectionPassthrough(target)
        ]  # not registered by PyTorch

        with self.setup_adapter(target):
            super().__init__(target)

        for i, attention_block in enumerate(target.layers(CrossAttentionBlock)):
            self.set_context(f"self_attention_context_{i}", {"norm": None})

            sa = attention_block.ensure_find(SelfAttention)
            assert sa.parent is not None

            self.sub_adapters.append(
                SelfAttentionInjectionAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)
            )

    def inject(self: "ReferenceOnlyControlAdapter", parent: Chain | None = None) -> "ReferenceOnlyControlAdapter":
        passthrough = self._passthrough[0]
        assert passthrough not in self.target, f"{passthrough} is already injected"
        for adapter in self.sub_adapters:
            adapter.inject()
        self.target.insert(0, passthrough)
        return super().inject(parent)

    def eject(self) -> None:
        passthrough = self._passthrough[0]
        assert self.target[0] == passthrough, f"{passthrough} is not the first element of target UNet"
        for adapter in self.sub_adapters:
            adapter.eject()
        self.target.pop(0)
        super().eject()

    def set_controlnet_condition(self, condition: Tensor) -> None:
        self.set_context("reference_only_control", {"guide": condition})

    def structural_copy(self: "ReferenceOnlyControlAdapter") -> "ReferenceOnlyControlAdapter":
        raise RuntimeError("ReferenceOnlyControlAdapter cannot be copied, eject it first.")
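A sketch of how the adapter might be driven during sampling. The unet (an SD1UNet) and guide_latents (noised latents encoded from the reference image for the current step) are assumptions, not part of this diff:

# Hypothetical usage sketch; `unet` and `guide_latents` are assumed to exist.
adapter = ReferenceOnlyControlAdapter(unet, style_cfg=0.5).inject()
adapter.set_controlnet_condition(guide_latents)
# ... run the denoising loop: the guide UNet pass stores each self-attention
# layer norm in context, and the injected adapters re-attend over it ...
adapter.eject()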
@ -0,0 +1,110 @@
from dataclasses import dataclass
from functools import cached_property
from typing import Generic, TypeVar

import torch

from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler

T = TypeVar("T", bound=LatentDiffusionModel)


def add_noise_interval(
    scheduler: Scheduler,
    /,
    x: torch.Tensor,
    noise: torch.Tensor,
    initial_timestep: torch.Tensor,
    target_timestep: torch.Tensor,
) -> torch.Tensor:
    initial_cumulative_scale_factors = scheduler.cumulative_scale_factors[initial_timestep]
    target_cumulative_scale_factors = scheduler.cumulative_scale_factors[target_timestep]

    factor = target_cumulative_scale_factors / initial_cumulative_scale_factors
    noised_x = factor * x + torch.sqrt(1 - factor**2) * noise
    return noised_x


@dataclass
class Restart(Generic[T]):
    """
    Implements the restart sampling strategy from the paper "Restart Sampling for Improving Generative Processes"
    (https://arxiv.org/pdf/2306.14878.pdf)

    Works only with the DDIM scheduler for now.
    """

    ldm: T
    num_steps: int = 10
    num_iterations: int = 2
    start_time: float = 0.1
    end_time: float = 2

    def __post_init__(self) -> None:
        assert isinstance(self.ldm.scheduler, DDIM), "Restart sampling only works with DDIM scheduler"

    def __call__(
        self,
        x: torch.Tensor,
        /,
        clip_text_embedding: torch.Tensor,
        condition_scale: float = 7.5,
        **kwargs: torch.Tensor,
    ) -> torch.Tensor:
        original_scheduler = self.ldm.scheduler
        new_scheduler = DDIM(self.ldm.scheduler.num_inference_steps, device=self.device, dtype=self.dtype)
        new_scheduler.timesteps = self.timesteps
        self.ldm.scheduler = new_scheduler

        for _ in range(self.num_iterations):
            noise = torch.randn_like(input=x, device=self.device, dtype=self.dtype)
            x = add_noise_interval(
                new_scheduler,
                x=x,
                noise=noise,
                initial_timestep=self.timesteps[-1],
                target_timestep=self.timesteps[0],
            )

            for step in range(len(self.timesteps) - 1):
                x = self.ldm(
                    x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=condition_scale, **kwargs
                )

        self.ldm.scheduler = original_scheduler

        return x

    @cached_property
    def start_step(self) -> int:
        sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
        return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.scheduler.timesteps] - self.start_time)))

    @cached_property
    def end_timestep(self) -> int:
        sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
        return int(torch.argmin(input=torch.abs(input=sigmas - self.end_time)))

    @cached_property
    def timesteps(self) -> torch.Tensor:
        return (
            torch.round(
                torch.linspace(
                    start=int(self.ldm.scheduler.timesteps[self.start_step]),
                    end=self.end_timestep,
                    steps=self.num_steps,
                )
            )
            .flip(0)
            .to(device=self.device, dtype=torch.int64)
        )

    @property
    def device(self) -> torch.device:
        return self.ldm.device

    @property
    def dtype(self) -> torch.dtype:
        return self.ldm.dtype
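A sketch of plugging Restart into a DDIM-based sampling loop: at the restart step it re-noises the latents over the restart interval and re-denoises them num_iterations times. The sd15 model, x, and clip_text_embedding are assumptions for illustration:

# Hypothetical usage sketch; `sd15` (a DDIM-scheduled latent diffusion model),
# `x` and `clip_text_embedding` are assumed to exist.
restart = Restart(ldm=sd15, num_steps=10, num_iterations=2, start_time=0.1, end_time=2)

for step in sd15.steps:
    x = sd15(x, step=step, clip_text_embedding=clip_text_embedding)
    if step == restart.start_step:  # re-noise and re-denoise over the restart interval
        x = restart(x, clip_text_embedding=clip_text_embedding)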
@ -0,0 +1,11 @@
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler

__all__ = [
    "Scheduler",
    "DPMSolver",
    "DDPM",
    "DDIM",
]
@ -0,0 +1,57 @@
from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor

from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler


class DDIM(Scheduler):
    def __init__(
        self,
        num_inference_steps: int,
        num_train_timesteps: int = 1_000,
        initial_diffusion_rate: float = 8.5e-4,
        final_diffusion_rate: float = 1.2e-2,
        noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
        device: Device | str = "cpu",
        dtype: Dtype = float32,
    ) -> None:
        super().__init__(
            num_inference_steps=num_inference_steps,
            num_train_timesteps=num_train_timesteps,
            initial_diffusion_rate=initial_diffusion_rate,
            final_diffusion_rate=final_diffusion_rate,
            noise_schedule=noise_schedule,
            device=device,
            dtype=dtype,
        )
        self.timesteps = self._generate_timesteps()

    def _generate_timesteps(self) -> Tensor:
        """
        Generates decreasing timesteps with 'leading' spacing and offset of 1
        similar to diffusers settings for the DDIM scheduler in Stable Diffusion 1.5
        """
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
        return timesteps.flip(0)

    def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
        timestep, previous_timestep = (
            self.timesteps[step],
            (
                self.timesteps[step + 1]
                if step < self.num_inference_steps - 1
                else tensor(data=[0], device=self.device, dtype=self.dtype)
            ),
        )
        current_scale_factor, previous_scale_factor = (
            self.cumulative_scale_factors[timestep],
            (
                self.cumulative_scale_factors[previous_timestep]
                if previous_timestep > 0
                else self.cumulative_scale_factors[0]
            ),
        )
        predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
        denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise

        return denoised_x
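The DDIM step first recovers the predicted clean latent, then re-noises it at the previous timestep's scale: with cumulative scale factor a_t, it computes x0_hat = (x - sqrt(1 - a_t^2) * eps) / a_t and then x_prev = a_prev * x0_hat + sqrt(1 - a_prev^2) * eps. A numeric sanity check that the clean latent is recovered exactly when the noise estimate matches the noise actually added (the scale factors below are illustrative, not schedule values):

import torch

current, previous = torch.tensor(0.8), torch.tensor(0.9)  # illustrative scale factors
x0 = torch.randn(1, 4, 8, 8)                              # hypothetical clean latent
noise = torch.randn_like(x0)

x_t = current * x0 + torch.sqrt(1 - current**2) * noise   # noised latent at step t
predicted_x = (x_t - torch.sqrt(1 - current**2) * noise) / current
denoised_x = previous * predicted_x + torch.sqrt(1 - previous**2) * noise

assert torch.allclose(predicted_x, x0, atol=1e-6)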
@ -0,0 +1,34 @@
from torch import Tensor, arange, device as Device

from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler


class DDPM(Scheduler):
    """
    Denoising Diffusion Probabilistic Model (DDPM): a diffusion model that uses
    its own strategy to generate timesteps and apply the diffusion process.
    """

    def __init__(
        self,
        num_inference_steps: int,
        num_train_timesteps: int = 1_000,
        initial_diffusion_rate: float = 8.5e-4,
        final_diffusion_rate: float = 1.2e-2,
        device: Device | str = "cpu",
    ) -> None:
        super().__init__(
            num_inference_steps=num_inference_steps,
            num_train_timesteps=num_train_timesteps,
            initial_diffusion_rate=initial_diffusion_rate,
            final_diffusion_rate=final_diffusion_rate,
            device=device,
        )

    def _generate_timesteps(self) -> Tensor:
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
        return timesteps.flip(0)

    def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
        raise NotImplementedError
@ -0,0 +1,117 @@
from collections import deque

import numpy as np
from torch import Tensor, device as Device, dtype as Dtype, exp, float32, tensor

from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler


class DPMSolver(Scheduler):
    """Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095

    We only support noise prediction for now.
    """

    def __init__(
        self,
        num_inference_steps: int,
        num_train_timesteps: int = 1_000,
        initial_diffusion_rate: float = 8.5e-4,
        final_diffusion_rate: float = 1.2e-2,
        noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
        device: Device | str = "cpu",
        dtype: Dtype = float32,
    ):
        super().__init__(
            num_inference_steps=num_inference_steps,
            num_train_timesteps=num_train_timesteps,
            initial_diffusion_rate=initial_diffusion_rate,
            final_diffusion_rate=final_diffusion_rate,
            noise_schedule=noise_schedule,
            device=device,
            dtype=dtype,
        )
        self.estimated_data = deque([tensor([])] * 2, maxlen=2)
        self.initial_steps = 0

    def _generate_timesteps(self) -> Tensor:
        # We need to use numpy here because:
        # numpy.linspace(0,999,31)[15] is 499.49999999999994
        # torch.linspace(0,999,31)[15] is 499.5
        # ...and we want the same result as the original codebase.
        return tensor(
            np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
            device=self.device,
        ).flip(0)

    def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
        timestep, previous_timestep = (
            self.timesteps[step],
            self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0],
        )
        previous_ratio, current_ratio = (
            self.signal_to_noise_ratios[previous_timestep],
            self.signal_to_noise_ratios[timestep],
        )
        previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
        previous_noise_std, current_noise_std = (
            self.noise_std[previous_timestep],
            self.noise_std[timestep],
        )
        factor = exp(-(previous_ratio - current_ratio)) - 1.0
        denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
        return denoised_x

    def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
        previous_timestep, current_timestep, next_timestep = (
            self.timesteps[step + 1] if step < len(self.timesteps) - 1 else tensor([0]),
            self.timesteps[step],
            self.timesteps[step - 1],
        )
        current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2]
        previous_ratio, current_ratio, next_ratio = (
            self.signal_to_noise_ratios[previous_timestep],
            self.signal_to_noise_ratios[current_timestep],
            self.signal_to_noise_ratios[next_timestep],
        )
        previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
        previous_std, current_std = (
            self.noise_std[previous_timestep],
            self.noise_std[current_timestep],
        )
        estimation_delta = (current_data_estimation - next_data_estimation) / (
            (current_ratio - next_ratio) / (previous_ratio - current_ratio)
        )
        factor = exp(-(previous_ratio - current_ratio)) - 1.0
        denoised_x = (
            (previous_std / current_std) * x
            - (factor * previous_scale_factor) * current_data_estimation
            - 0.5 * (factor * previous_scale_factor) * estimation_delta
        )
        return denoised_x

    def __call__(
        self,
        x: Tensor,
        noise: Tensor,
        step: int,
    ) -> Tensor:
        """
        Represents one step of the backward diffusion process that iteratively denoises the input data `x`.

        This method works by estimating the denoised version of `x` and applying either a first-order or second-order
        backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
        (ODEs).
        """
        current_timestep = self.timesteps[step]
        scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
        estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
        self.estimated_data.append(estimated_denoised_data)
        denoised_x = (
            self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
            if (self.initial_steps == 0)
            else self.multistep_dpm_solver_second_order_update(x=x, step=step)
        )
        if self.initial_steps < 2:
            self.initial_steps += 1
        return denoised_x
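The solver keeps the last two denoised-data estimates in a two-slot deque and switches from the first-order to the second-order multistep update after the first call. A minimal smoke-test sketch; the random tensors stand in for real UNet outputs and the step count is an arbitrary assumption:

import torch
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers import DPMSolver

scheduler = DPMSolver(num_inference_steps=30)
x = torch.randn(1, 4, 64, 64)

for step in scheduler.steps:
    noise_prediction = torch.randn_like(x)  # stand-in for the UNet output
    x = scheduler(x, noise=noise_prediction, step=step)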
@ -0,0 +1,128 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import TypeVar

from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt

T = TypeVar("T", bound="Scheduler")


class NoiseSchedule(str, Enum):
    UNIFORM = "uniform"
    QUADRATIC = "quadratic"
    KARRAS = "karras"


class Scheduler(ABC):
    """
    A base class for creating a diffusion model scheduler.

    The Scheduler creates a sequence of noise and scaling factors used in the diffusion process,
    which gradually transforms the original data distribution into a Gaussian one.

    This process is described using several parameters such as initial and final diffusion rates,
    and is encapsulated into a `__call__` method that applies a step of the diffusion process.
    """

    timesteps: Tensor

    def __init__(
        self,
        num_inference_steps: int,
        num_train_timesteps: int = 1_000,
        initial_diffusion_rate: float = 8.5e-4,
        final_diffusion_rate: float = 1.2e-2,
        noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
        device: Device | str = "cpu",
        dtype: DType = float32,
    ):
        self.device: Device = Device(device)
        self.dtype: DType = dtype
        self.num_inference_steps = num_inference_steps
        self.num_train_timesteps = num_train_timesteps
        self.initial_diffusion_rate = initial_diffusion_rate
        self.final_diffusion_rate = final_diffusion_rate
        self.noise_schedule = noise_schedule
        self.scale_factors = self.sample_noise_schedule()
        self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
        self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
        self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
        self.timesteps = self._generate_timesteps()

    @abstractmethod
    def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
        """
        Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.

        This method should be overridden by subclasses to implement the specific diffusion process.
        """
        ...

    @abstractmethod
    def _generate_timesteps(self) -> Tensor:
        """
        Generates a tensor of timesteps.

        This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
        """
        ...

    @property
    def steps(self) -> list[int]:
        return list(range(self.num_inference_steps))

    def sample_power_distribution(self, power: float = 2, /) -> Tensor:
        return (
            linspace(
                start=self.initial_diffusion_rate ** (1 / power),
                end=self.final_diffusion_rate ** (1 / power),
                steps=self.num_train_timesteps,
                device=self.device,
                dtype=self.dtype,
            )
            ** power
        )

    def sample_noise_schedule(self) -> Tensor:
        match self.noise_schedule:
            case "uniform":
                return 1 - self.sample_power_distribution(1)
            case "quadratic":
                return 1 - self.sample_power_distribution(2)
            case "karras":
                return 1 - self.sample_power_distribution(7)
            case _:
                raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")

    def add_noise(
        self,
        x: Tensor,
        noise: Tensor,
        step: int,
    ) -> Tensor:
        timestep = self.timesteps[step]
        cumulative_scale_factors = self.cumulative_scale_factors[timestep]
        noise_stds = self.noise_std[timestep]
        noised_x = cumulative_scale_factors * x + noise_stds * noise
        return noised_x

    def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
        timestep = self.timesteps[step]
        cumulative_scale_factors = self.cumulative_scale_factors[timestep]
        noise_stds = self.noise_std[timestep]
        # See equation (15) from https://arxiv.org/pdf/2006.11239.pdf. Useful to preview progress or for guidance like
        # in https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance)
        denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
        return denoised_x

    def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T:  # type: ignore
        if device is not None:
            self.device = Device(device)
            self.timesteps = self.timesteps.to(device)
        if dtype is not None:
            self.dtype = dtype
        self.scale_factors = self.scale_factors.to(device, dtype=dtype)
        self.cumulative_scale_factors = self.cumulative_scale_factors.to(device, dtype=dtype)
        self.noise_std = self.noise_std.to(device, dtype=dtype)
        self.signal_to_noise_ratios = self.signal_to_noise_ratios.to(device, dtype=dtype)
        return self
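add_noise and remove_noise are exact inverses for a given step, which is what makes denoised previews (and the self-attention guidance above) cheap to compute. A roundtrip sketch using the DDIM subclass; the tensor contents and step index are illustrative:

import torch
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers import DDIM

scheduler = DDIM(num_inference_steps=30)
x = torch.randn(1, 4, 64, 64)
noise = torch.randn_like(x)

noised = scheduler.add_noise(x, noise=noise, step=10)
recovered = scheduler.remove_noise(noised, noise=noise, step=10)
assert torch.allclose(recovered, x, atol=1e-5)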
@ -0,0 +1,101 @@
import math
from typing import TYPE_CHECKING, Any, Generic, TypeVar

import torch
from jaxtyping import Float
from torch import Size, Tensor

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.utils import gaussian_blur, interpolate
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler

if TYPE_CHECKING:
    from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
    from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet

T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TSAGAdapter = TypeVar("TSAGAdapter", bound="SAGAdapter[Any]")  # Self (see PEP 673)


class SelfAttentionMap(fl.Passthrough):
    def __init__(self, num_heads: int, context_key: str) -> None:
        self.num_heads = num_heads
        self.context_key = context_key
        super().__init__(
            fl.Lambda(func=self.compute_attention_scores),
            fl.SetContext(context="self_attention_map", key=context_key),
        )

    def split_to_multi_head(
        self, x: Float[Tensor, "batch_size sequence_length embedding_dim"]
    ) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]:
        assert (
            len(x.shape) == 3
        ), f"Expected tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}"
        assert (
            x.shape[-1] % self.num_heads == 0
        ), f"Embedding dim (x.shape[-1]={x.shape[-1]}) must be divisible by num heads"
        return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2)

    def compute_attention_scores(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        query, key = self.split_to_multi_head(query), self.split_to_multi_head(key)
        _, _, _, dim = query.shape
        attention = query @ key.permute(0, 1, 3, 2)
        attention = attention / math.sqrt(dim)
        return torch.softmax(input=attention, dim=-1)


class SelfAttentionShape(fl.Passthrough):
    def __init__(self, context_key: str) -> None:
        self.context_key = context_key
        super().__init__(
            fl.SetContext(context="self_attention_map", key=context_key, callback=self.register_shape),
        )

    def register_shape(self, shapes: list[Size], x: Tensor) -> None:
        assert x.ndim == 4, f"Expected 4D tensor, got {x.ndim}D with shape {x.shape}"
        shapes.append(x.shape[-2:])


class SAGAdapter(Generic[T], fl.Chain, Adapter[T]):
    def __init__(self, target: T, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
        self.scale = scale
        self.kernel_size = kernel_size
        self.sigma = sigma
        with self.setup_adapter(target):
            super().__init__(target)

    def inject(self: "TSAGAdapter", parent: fl.Chain | None = None) -> "TSAGAdapter":
        return super().inject(parent)

    def eject(self) -> None:
        super().eject()

    def compute_sag_mask(
        self, latents: Float[Tensor, "batch_size channels height width"], classifier_free_guidance: bool = True
    ) -> Float[Tensor, "batch_size channels height width"]:
        attn_map = self.use_context("self_attention_map")["middle_block_attn_map"]
        if classifier_free_guidance:
            unconditional_attn, _ = attn_map.chunk(2)
            attn_map = unconditional_attn
        attn_shape = self.use_context("self_attention_map")["middle_block_attn_shape"].pop()
        assert len(attn_shape) == 2
        b, c, h, w = latents.shape
        attn_h, attn_w = attn_shape
        attn_mask = attn_map.mean(dim=1, keepdim=False).sum(dim=1, keepdim=False) > 1.0
        attn_mask = attn_mask.reshape(b, attn_h, attn_w).unsqueeze(1).repeat(1, c, 1, 1).type(attn_map.dtype)
        return interpolate(attn_mask, Size((h, w)))

    def compute_degraded_latents(
        self, scheduler: Scheduler, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
    ) -> Tensor:
        sag_mask = self.compute_sag_mask(latents=latents, classifier_free_guidance=classifier_free_guidance)
        original_latents = scheduler.remove_noise(x=latents, noise=noise, step=step)
        degraded_latents = gaussian_blur(original_latents, kernel_size=self.kernel_size, sigma=self.sigma)
        degraded_latents = degraded_latents * sag_mask + original_latents * (1 - sag_mask)
        return scheduler.add_noise(degraded_latents, noise=noise, step=step)

    def init_context(self) -> Contexts:
        return {"self_attention_map": {"middle_block_attn_map": None, "middle_block_attn_shape": []}}
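The SAG mask selects latent positions whose averaged self-attention mass exceeds 1.0; only those positions receive the Gaussian-blurred (degraded) value, while the rest keep the original latents. The blend itself is a simple mask-weighted interpolation, sketched here with an illustrative random binary mask:

import torch

original = torch.randn(1, 4, 8, 8)
degraded = torch.zeros_like(original)        # stand-in for the blurred latents
sag_mask = (torch.rand(1, 4, 8, 8) > 0.5).float()

blended = degraded * sag_mask + original * (1 - sag_mask)
# Masked positions take the degraded value, unmasked positions keep the original.
assert torch.equal(blended[sag_mask.bool()], degraded[sag_mask.bool()])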
@ -0,0 +1,17 @@
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
    StableDiffusion_1,
    StableDiffusion_1_Inpainting,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.t2i_adapter import SD1T2IAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet

__all__ = [
    "StableDiffusion_1",
    "StableDiffusion_1_Inpainting",
    "SD1UNet",
    "SD1ControlnetAdapter",
    "SD1IPAdapter",
    "SD1T2IAdapter",
]
@ -0,0 +1,179 @@
|
|||||||
|
from typing import Iterable, cast
|
||||||
|
|
||||||
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
|
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
|
||||||
|
from imaginairy.vendored.refiners.fluxion.context import Contexts
|
||||||
|
from imaginairy.vendored.refiners.fluxion.layers import Chain, Conv2d, Lambda, Passthrough, Residual, SiLU, Slicing, UseContext
|
||||||
|
from imaginairy.vendored.refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d
|
||||||
|
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
|
||||||
|
DownBlocks,
|
||||||
|
MiddleBlock,
|
||||||
|
ResidualBlock,
|
||||||
|
SD1UNet,
|
||||||
|
TimestepEncoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionEncoder(Chain):
|
||||||
|
"""Encode an image to be used as a condition for Controlnet.
|
||||||
|
|
||||||
|
Input is a `batch 3 width height` tensor, output is a `batch 320 width//8 height//8` tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
self.out_channels = (16, 32, 96, 256)
|
||||||
|
super().__init__(
|
||||||
|
Chain(
|
||||||
|
Conv2d(
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=self.out_channels[0],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
SiLU(),
|
||||||
|
            ),
            *(
                Chain(
                    Conv2d(
                        in_channels=self.out_channels[i],
                        out_channels=self.out_channels[i],
                        kernel_size=3,
                        padding=1,
                        device=device,
                        dtype=dtype,
                    ),
                    SiLU(),
                    Conv2d(
                        in_channels=self.out_channels[i],
                        out_channels=self.out_channels[i + 1],
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        device=device,
                        dtype=dtype,
                    ),
                    SiLU(),
                )
                for i in range(len(self.out_channels) - 1)
            ),
            Conv2d(
                in_channels=self.out_channels[-1],
                out_channels=320,
                kernel_size=3,
                padding=1,
                device=device,
                dtype=dtype,
            ),
        )


class Controlnet(Passthrough):
    def __init__(
        self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
    ) -> None:
        """Controlnet is a Half-UNet that collects residuals from the UNet and uses them to condition the UNet.

        Input is a `batch 3 width height` tensor, output is a `batch 1280 width//8 height//8` tensor with residuals
        stored in the context.

        It has to use the same context as the UNet: `unet` and `sampling`.
        """
        self.name = name
        self.scale = scale
        super().__init__(
            TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
            Slicing(dim=1, end=4),  # support inpainting
            DownBlocks(in_channels=4, device=device, dtype=dtype),
            MiddleBlock(device=device, dtype=dtype),
        )

        # We run the condition encoder at each step. Caching the result
        # is not worth it as subsequent runs take virtually no time (FG-374).
        self.DownBlocks[0].append(
            Residual(
                UseContext("controlnet", f"condition_{name}"),
                ConditionEncoder(device=device, dtype=dtype),
            ),
        )
        for residual_block in self.layers(ResidualBlock):
            chain = residual_block.Chain
            RangeAdapter2d(
                target=chain.Conv2d_1,
                channels=residual_block.out_channels,
                embedding_dim=1280,
                context_key=f"timestep_embedding_{name}",
                device=device,
                dtype=dtype,
            ).inject(chain)
        for n, block in enumerate(cast(Iterable[Chain], self.DownBlocks)):
            assert hasattr(block[0], "out_channels"), (
                "The first block of every subchain in DownBlocks is expected to respond to `out_channels`,"
                f" {block[0]} does not."
            )
            out_channels: int = block[0].out_channels
            block.append(
                Passthrough(
                    Conv2d(
                        in_channels=out_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype
                    ),
                    Lambda(self._store_nth_residual(n)),
                )
            )
        self.MiddleBlock.append(
            Passthrough(
                Conv2d(in_channels=1280, out_channels=1280, kernel_size=1, device=device, dtype=dtype),
                Lambda(self._store_nth_residual(12)),
            )
        )

    def _store_nth_residual(self, n: int):
        def _store_residual(x: Tensor):
            residuals = self.use_context("unet")["residuals"]
            residuals[n] = residuals[n] + x * self.scale
            return x

        return _store_residual


class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
    def __init__(
        self, target: SD1UNet, name: str, scale: float = 1.0, weights: dict[str, Tensor] | None = None
    ) -> None:
        self.name = name

        controlnet = Controlnet(name=name, scale=scale, device=target.device, dtype=target.dtype)
        if weights is not None:
            controlnet.load_state_dict(weights)
        self._controlnet: list[Controlnet] = [controlnet]  # not registered by PyTorch

        with self.setup_adapter(target):
            super().__init__(target)

    def inject(self: "SD1ControlnetAdapter", parent: Chain | None = None) -> "SD1ControlnetAdapter":
        controlnet = self._controlnet[0]
        target_controlnets = [x for x in self.target if isinstance(x, Controlnet)]
        assert controlnet not in target_controlnets, f"{controlnet} is already injected"
        for cn in target_controlnets:
            assert cn.name != self.name, f"Controlnet named {self.name} is already injected"
        self.target.insert(0, controlnet)
        return super().inject(parent)

    def eject(self) -> None:
        self.target.remove(self._controlnet[0])
        super().eject()

    def init_context(self) -> Contexts:
        return {"controlnet": {f"condition_{self.name}": None}}

    def set_scale(self, scale: float) -> None:
        self._controlnet[0].scale = scale

    def set_controlnet_condition(self, condition: Tensor) -> None:
        self.set_context("controlnet", {f"condition_{self.name}": condition})

    def structural_copy(self: "SD1ControlnetAdapter") -> "SD1ControlnetAdapter":
        raise RuntimeError("Controlnet cannot be copied, eject it first.")
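For orientation, a minimal usage sketch of the adapter defined above (not part of the vendored code; the module path, the adapter `name`, and the random condition tensor are illustrative assumptions):

import torch

from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
    SD1ControlnetAdapter,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet

unet = SD1UNet(in_channels=4)
adapter = SD1ControlnetAdapter(target=unet, name="canny", scale=0.8).inject()
adapter.set_controlnet_condition(torch.randn(1, 3, 512, 512))  # stand-in for a preprocessed control image
# ... run the diffusion loop as usual: the Controlnet adds its residuals,
# scaled by `scale`, into the UNet's "unet" context ...
adapter.eject()  # restore the plain UNet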
@ -0,0 +1,53 @@
from torch import Tensor

from imaginairy.vendored.refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.image_prompt import ImageProjection, IPAdapter, PerceiverResampler
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet


class SD1IPAdapter(IPAdapter[SD1UNet]):
    def __init__(
        self,
        target: SD1UNet,
        clip_image_encoder: CLIPImageEncoderH | None = None,
        image_proj: ImageProjection | PerceiverResampler | None = None,
        scale: float = 1.0,
        fine_grained: bool = False,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)

        if image_proj is None:
            cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
            image_proj = (
                ImageProjection(
                    clip_image_embedding_dim=clip_image_encoder.output_dim,
                    clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
                    device=target.device,
                    dtype=target.dtype,
                )
                if not fine_grained
                else PerceiverResampler(
                    latents_dim=cross_attn_2d.context_embedding_dim,
                    num_attention_layers=4,
                    num_attention_heads=12,
                    head_dim=64,
                    num_tokens=16,
                    input_dim=clip_image_encoder.embedding_dim,  # = dim before final projection
                    output_dim=cross_attn_2d.context_embedding_dim,
                    device=target.device,
                    dtype=target.dtype,
                )
            )
        elif fine_grained:
            assert isinstance(image_proj, PerceiverResampler)

        super().__init__(
            target=target,
            clip_image_encoder=clip_image_encoder,
            image_proj=image_proj,
            scale=scale,
            fine_grained=fine_grained,
            weights=weights,
        )
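A short construction sketch for the class above; encoding and setting the image prompt live on the IPAdapter base class, which this diff does not show, so only construction and injection are illustrated:

unet = SD1UNet(in_channels=4)
ip_adapter = SD1IPAdapter(target=unet, fine_grained=False).inject()
# The CLIP image embedding of the prompt image is computed and pushed into the
# UNet's cross-attention context by the IPAdapter base class (not shown here).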
@ -0,0 +1,173 @@
import numpy as np
import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType

from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpolate
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from imaginairy.vendored.refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.self_attention_guidance import SD1SAGAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet


class SD1Autoencoder(LatentDiffusionAutoencoder):
    encoder_scale: float = 0.18215


class StableDiffusion_1(LatentDiffusionModel):
    unet: SD1UNet
    clip_text_encoder: CLIPTextEncoderL

    def __init__(
        self,
        unet: SD1UNet | None = None,
        lda: SD1Autoencoder | None = None,
        clip_text_encoder: CLIPTextEncoderL | None = None,
        scheduler: Scheduler | None = None,
        device: Device | str = "cpu",
        dtype: DType = torch.float32,
    ) -> None:
        unet = unet or SD1UNet(in_channels=4)
        lda = lda or SD1Autoencoder()
        clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
        scheduler = scheduler or DPMSolver(num_inference_steps=30)

        super().__init__(
            unet=unet,
            lda=lda,
            clip_text_encoder=clip_text_encoder,
            scheduler=scheduler,
            device=device,
            dtype=dtype,
        )

    def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
        conditional_embedding = self.clip_text_encoder(text)
        if text == negative_text:
            return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0)

        negative_embedding = self.clip_text_encoder(negative_text or "")
        return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)

    def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
        self.unet.set_timestep(timestep=timestep)
        self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)

    def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
        if enable:
            if sag := self._find_sag_adapter():
                sag.scale = scale
            else:
                SD1SAGAdapter(target=self.unet, scale=scale).inject()
        else:
            if sag := self._find_sag_adapter():
                sag.eject()

    def has_self_attention_guidance(self) -> bool:
        return self._find_sag_adapter() is not None

    def _find_sag_adapter(self) -> SD1SAGAdapter | None:
        for p in self.unet.get_parents():
            if isinstance(p, SD1SAGAdapter):
                return p
        return None

    def compute_self_attention_guidance(
        self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
    ) -> Tensor:
        sag = self._find_sag_adapter()
        assert sag is not None

        degraded_latents = sag.compute_degraded_latents(
            scheduler=self.scheduler,
            latents=x,
            noise=noise,
            step=step,
            classifier_free_guidance=True,
        )

        negative_embedding, _ = clip_text_embedding.chunk(2)
        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
        self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
        degraded_noise = self.unet(degraded_latents)

        return sag.scale * (noise - degraded_noise)


class StableDiffusion_1_Inpainting(StableDiffusion_1):
    def __init__(
        self,
        unet: SD1UNet | None = None,
        lda: SD1Autoencoder | None = None,
        clip_text_encoder: CLIPTextEncoderL | None = None,
        scheduler: Scheduler | None = None,
        device: Device | str = "cpu",
        dtype: DType = torch.float32,
    ) -> None:
        self.mask_latents: Tensor | None = None
        self.target_image_latents: Tensor | None = None
        super().__init__(
            unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler, device=device, dtype=dtype
        )

    def forward(
        self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **_: Tensor
    ) -> Tensor:
        assert self.mask_latents is not None
        assert self.target_image_latents is not None
        x = torch.cat(tensors=(x, self.mask_latents, self.target_image_latents), dim=1)
        return super().forward(
            x=x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=condition_scale,
        )

    def set_inpainting_conditions(
        self,
        target_image: Image.Image,
        mask: Image.Image,
        latents_size: tuple[int, int] = (64, 64),
    ) -> tuple[Tensor, Tensor]:
        target_image = target_image.convert(mode="RGB")
        mask = mask.convert(mode="L")

        mask_tensor = torch.tensor(data=np.array(object=mask).astype(dtype=np.float32) / 255.0).to(device=self.device)
        mask_tensor = (mask_tensor > 0.5).unsqueeze(dim=0).unsqueeze(dim=0).to(dtype=self.dtype)
        self.mask_latents = interpolate(x=mask_tensor, factor=torch.Size(latents_size))

        init_image_tensor = image_to_tensor(image=target_image, device=self.device, dtype=self.dtype) * 2 - 1
        masked_init_image = init_image_tensor * (1 - mask_tensor)
        self.target_image_latents = self.lda.encode(x=masked_init_image)

        return self.mask_latents, self.target_image_latents

    def compute_self_attention_guidance(
        self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
    ) -> Tensor:
        sag = self._find_sag_adapter()
        assert sag is not None
        assert self.mask_latents is not None
        assert self.target_image_latents is not None

        degraded_latents = sag.compute_degraded_latents(
            scheduler=self.scheduler,
            latents=x,
            noise=noise,
            step=step,
            classifier_free_guidance=True,
        )

        negative_embedding, _ = clip_text_embedding.chunk(2)
        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
        self.set_unet_context(timestep=timestep, clip_text_embedding=negative_embedding, **kwargs)
        x = torch.cat(
            tensors=(degraded_latents, self.mask_latents, self.target_image_latents),
            dim=1,
        )
        degraded_noise = self.unet(x)

        return sag.scale * (noise - degraded_noise)
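A rough sketch of the denoising loop these classes support (not part of the vendored code; the latent shape, the scheduler's `num_inference_steps` attribute, and `decode_latents` are assumptions about the surrounding refiners API):

import torch

sd = StableDiffusion_1()  # cpu / float32 defaults from the constructor above
clip_text_embedding = sd.compute_clip_text_embedding("a cute cat", negative_text="")
x = torch.randn(1, 4, 64, 64)  # 64x64 latents decode to a 512x512 image
for step in range(sd.scheduler.num_inference_steps):
    x = sd(x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5)
predicted_image = sd.lda.decode_latents(x)  # autoencoder API assumed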
@ -0,0 +1,41 @@
from dataclasses import dataclass, field

from PIL import Image
from torch import Tensor

from imaginairy.vendored.refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
    StableDiffusion_1,
    StableDiffusion_1_Inpainting,
)


class SD1MultiDiffusion(MultiDiffusion[StableDiffusion_1, DiffusionTarget]):
    def diffuse_target(self, x: Tensor, step: int, target: DiffusionTarget) -> Tensor:
        return self.ldm(
            x=x,
            step=step,
            clip_text_embedding=target.clip_text_embedding,
            scale=target.condition_scale,
        )


@dataclass
class InpaintingDiffusionTarget(DiffusionTarget):
    target_image: Image.Image = field(default_factory=lambda: Image.new(mode="RGB", size=(512, 512), color=255))
    mask: Image.Image = field(default_factory=lambda: Image.new(mode="L", size=(512, 512), color=255))


class SD1InpaintingMultiDiffusion(MultiDiffusion[StableDiffusion_1_Inpainting, InpaintingDiffusionTarget]):
    def diffuse_target(self, x: Tensor, step: int, target: InpaintingDiffusionTarget) -> Tensor:
        self.ldm.set_inpainting_conditions(
            target_image=target.target_image,
            mask=target.mask,
        )

        return self.ldm(
            x=x,
            step=step,
            clip_text_embedding=target.clip_text_embedding,
            scale=target.condition_scale,
        )
@ -0,0 +1,41 @@
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.layers.attentions import ScaledDotProductAttention
from imaginairy.vendored.refiners.foundationals.latent_diffusion.self_attention_guidance import (
    SAGAdapter,
    SelfAttentionMap,
    SelfAttentionShape,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import MiddleBlock, ResidualBlock, SD1UNet


class SD1SAGAdapter(SAGAdapter[SD1UNet]):
    def __init__(self, target: SD1UNet, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
        super().__init__(
            target=target,
            scale=scale,
            kernel_size=kernel_size,
            sigma=sigma,
        )

    def inject(self: "SD1SAGAdapter", parent: fl.Chain | None = None) -> "SD1SAGAdapter":
        middle_block = self.target.ensure_find(MiddleBlock)
        middle_block.insert_after_type(ResidualBlock, SelfAttentionShape(context_key="middle_block_attn_shape"))

        # An alternative would be to replace the ScaledDotProductAttention with a version which records the attention
        # scores to avoid computing these scores twice
        self_attn = middle_block.ensure_find(fl.SelfAttention)
        self_attn.insert_before_type(
            ScaledDotProductAttention,
            SelfAttentionMap(num_heads=self_attn.num_heads, context_key="middle_block_attn_map"),
        )

        return super().inject(parent)

    def eject(self) -> None:
        middle_block = self.target.ensure_find(MiddleBlock)
        middle_block.remove(middle_block.ensure_find(SelfAttentionShape))

        self_attn = middle_block.ensure_find(fl.SelfAttention)
        self_attn.remove(self_attn.ensure_find(SelfAttentionMap))

        super().eject()
@ -0,0 +1,37 @@
from torch import Tensor

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator, SD1UNet
from imaginairy.vendored.refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoder, T2IAdapter, T2IFeatures


class SD1T2IAdapter(T2IAdapter[SD1UNet]):
    def __init__(
        self,
        target: SD1UNet,
        name: str,
        condition_encoder: ConditionEncoder | None = None,
        scale: float = 1.0,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        self.residual_indices = (2, 5, 8, 11)
        self._features = [T2IFeatures(name=name, index=i, scale=scale) for i in range(4)]
        super().__init__(
            target=target,
            name=name,
            condition_encoder=condition_encoder or ConditionEncoder(device=target.device, dtype=target.dtype),
            weights=weights,
        )

    def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
        for n, feat in zip(self.residual_indices, self._features, strict=True):
            block = self.target.DownBlocks[n]
            for t2i_layer in block.layers(layer_type=T2IFeatures):
                assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
            block.insert_before_type(ResidualAccumulator, feat)
        return super().inject(parent)

    def eject(self: "SD1T2IAdapter") -> None:
        for n, feat in zip(self.residual_indices, self._features, strict=True):
            self.target.DownBlocks[n].remove(feat)
        super().eject()
@ -0,0 +1,288 @@
from typing import Iterable, cast

from torch import Tensor, device as Device, dtype as DType

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.range_adapter import RangeAdapter2d, RangeEncoder


class TimestepEncoder(fl.Passthrough):
    def __init__(
        self,
        context_key: str = "timestep_embedding",
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(
            fl.UseContext("diffusion", "timestep"),
            RangeEncoder(320, 1280, device=device, dtype=dtype),
            fl.SetContext("range_adapter", context_key),
        )


class ResidualBlock(fl.Sum):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_groups: int = 32,
        eps: float = 1e-5,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        if in_channels % num_groups != 0 or out_channels % num_groups != 0:
            raise ValueError("Number of input and output channels must be divisible by num_groups.")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_groups = num_groups
        self.eps = eps
        shortcut = (
            fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
            if in_channels != out_channels
            else fl.Identity()
        )
        super().__init__(
            fl.Chain(
                fl.GroupNorm(channels=in_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
                fl.SiLU(),
                fl.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    device=device,
                    dtype=dtype,
                ),
                fl.GroupNorm(channels=out_channels, num_groups=num_groups, eps=eps, device=device, dtype=dtype),
                fl.SiLU(),
                fl.Conv2d(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=1,
                    device=device,
                    dtype=dtype,
                ),
            ),
            shortcut,
        )


class CLIPLCrossAttention(CrossAttentionBlock2d):
    def __init__(
        self,
        channels: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(
            channels=channels,
            context_embedding_dim=768,
            context_key="clip_text_embedding",
            num_attention_heads=8,
            use_bias=False,
            device=device,
            dtype=dtype,
        )


class DownBlocks(fl.Chain):
    def __init__(
        self,
        in_channels: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ):
        self.in_channels = in_channels
        super().__init__(
            fl.Chain(
                fl.Conv2d(
                    in_channels=in_channels, out_channels=320, kernel_size=3, padding=1, device=device, dtype=dtype
                )
            ),
            fl.Chain(
                ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
            ),
            fl.Chain(fl.Downsample(channels=320, scale_factor=2, padding=1, device=device, dtype=dtype)),
            fl.Chain(
                ResidualBlock(in_channels=320, out_channels=640, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=640, out_channels=640, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
            ),
            fl.Chain(fl.Downsample(channels=640, scale_factor=2, padding=1, device=device, dtype=dtype)),
            fl.Chain(
                ResidualBlock(in_channels=640, out_channels=1280, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(fl.Downsample(channels=1280, scale_factor=2, padding=1, device=device, dtype=dtype)),
            fl.Chain(
                ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
            ),
        )


class UpBlocks(fl.Chain):
    def __init__(
        self,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(
            fl.Chain(
                ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
                fl.Upsample(channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=1920, out_channels=1280, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
                fl.Upsample(channels=1280, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=1920, out_channels=640, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=1280, out_channels=640, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=960, out_channels=640, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=640, device=device, dtype=dtype),
                fl.Upsample(channels=640, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=960, out_channels=320, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
                CLIPLCrossAttention(channels=320, device=device, dtype=dtype),
            ),
        )


class MiddleBlock(fl.Chain):
    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__(
            ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
            CLIPLCrossAttention(channels=1280, device=device, dtype=dtype),
            ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
        )


class ResidualAccumulator(fl.Passthrough):
    def __init__(self, n: int) -> None:
        self.n = n

        super().__init__(
            fl.Residual(
                fl.UseContext(context="unet", key="residuals").compose(func=lambda residuals: residuals[self.n])
            ),
            fl.SetContext(context="unet", key="residuals", callback=self.update),
        )

    def update(self, residuals: list[Tensor | float], x: Tensor) -> None:
        residuals[self.n] = x


class ResidualConcatenator(fl.Chain):
    def __init__(self, n: int) -> None:
        self.n = n

        super().__init__(
            fl.Concatenate(
                fl.Identity(),
                fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[self.n]),
                dim=1,
            ),
        )


class SD1UNet(fl.Chain):
    def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.in_channels = in_channels
        super().__init__(
            TimestepEncoder(device=device, dtype=dtype),
            DownBlocks(in_channels=in_channels, device=device, dtype=dtype),
            fl.Sum(
                fl.UseContext(context="unet", key="residuals").compose(lambda x: x[-1]),
                MiddleBlock(device=device, dtype=dtype),
            ),
            UpBlocks(),
            fl.Chain(
                fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
                fl.SiLU(),
                fl.Conv2d(
                    in_channels=320,
                    out_channels=4,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    device=device,
                    dtype=dtype,
                ),
            ),
        )
        for residual_block in self.layers(ResidualBlock):
            chain = residual_block.Chain
            RangeAdapter2d(
                target=chain.Conv2d_1,
                channels=residual_block.out_channels,
                embedding_dim=1280,
                context_key="timestep_embedding",
                device=device,
                dtype=dtype,
            ).inject(chain)
        for n, block in enumerate(cast(Iterable[fl.Chain], self.DownBlocks)):
            block.append(ResidualAccumulator(n))
        for n, block in enumerate(cast(Iterable[fl.Chain], self.UpBlocks)):
            block.insert(0, ResidualConcatenator(-n - 2))

    def init_context(self) -> Contexts:
        return {
            "unet": {"residuals": [0.0] * 13},
            "diffusion": {"timestep": None},
            "range_adapter": {"timestep_embedding": None},
            "sampling": {"shapes": []},
        }

    def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
        self.set_context("cross_attention_block", {"clip_text_embedding": clip_text_embedding})

    def set_timestep(self, timestep: Tensor) -> None:
        self.set_context("diffusion", {"timestep": timestep})
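Note that SD1UNet takes only the noisy latents as a positional input; all conditioning flows through Fluxion contexts via the setters above. A minimal sketch (not part of the vendored code; tensor shapes are assumptions):

import torch

unet = SD1UNet(in_channels=4)
unet.set_timestep(torch.tensor([0]))  # shape (1,), as in set_unet_context earlier
unet.set_clip_text_embedding(torch.randn(1, 77, 768))  # CLIP-L embedding width
noise_prediction = unet(torch.randn(1, 4, 64, 64))  # same shape as the input latents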
@ -0,0 +1,13 @@
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet

__all__ = [
    "SDXLUNet",
    "DoubleTextEncoder",
    "StableDiffusion_XL",
    "SDXLIPAdapter",
    "SDXLT2IAdapter",
]
@ -0,0 +1,53 @@
from torch import Tensor

from imaginairy.vendored.refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from imaginairy.vendored.refiners.foundationals.latent_diffusion.image_prompt import ImageProjection, IPAdapter, PerceiverResampler
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet


class SDXLIPAdapter(IPAdapter[SDXLUNet]):
    def __init__(
        self,
        target: SDXLUNet,
        clip_image_encoder: CLIPImageEncoderH | None = None,
        image_proj: ImageProjection | PerceiverResampler | None = None,
        scale: float = 1.0,
        fine_grained: bool = False,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)

        if image_proj is None:
            cross_attn_2d = target.ensure_find(CrossAttentionBlock2d)
            image_proj = (
                ImageProjection(
                    clip_image_embedding_dim=clip_image_encoder.output_dim,
                    clip_text_embedding_dim=cross_attn_2d.context_embedding_dim,
                    device=target.device,
                    dtype=target.dtype,
                )
                if not fine_grained
                else PerceiverResampler(
                    latents_dim=1280,  # not `cross_attn_2d.context_embedding_dim` in this case
                    num_attention_layers=4,
                    num_attention_heads=20,
                    head_dim=64,
                    num_tokens=16,
                    input_dim=clip_image_encoder.embedding_dim,  # = dim before final projection
                    output_dim=cross_attn_2d.context_embedding_dim,
                    device=target.device,
                    dtype=target.dtype,
                )
            )
        elif fine_grained:
            assert isinstance(image_proj, PerceiverResampler)

        super().__init__(
            target=target,
            clip_image_encoder=clip_image_encoder,
            image_proj=image_proj,
            scale=scale,
            fine_grained=fine_grained,
            weights=weights,
        )
@ -0,0 +1,154 @@
import torch
from torch import Tensor, device as Device, dtype as DType

from imaginairy.vendored.refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from imaginairy.vendored.refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.self_attention_guidance import SDXLSAGAdapter
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet


class SDXLAutoencoder(LatentDiffusionAutoencoder):
    encoder_scale: float = 0.13025


class StableDiffusion_XL(LatentDiffusionModel):
    unet: SDXLUNet
    clip_text_encoder: DoubleTextEncoder

    def __init__(
        self,
        unet: SDXLUNet | None = None,
        lda: SDXLAutoencoder | None = None,
        clip_text_encoder: DoubleTextEncoder | None = None,
        scheduler: Scheduler | None = None,
        device: Device | str = "cpu",
        dtype: DType = torch.float32,
    ) -> None:
        unet = unet or SDXLUNet(in_channels=4)
        lda = lda or SDXLAutoencoder()
        clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
        scheduler = scheduler or DDIM(num_inference_steps=30)

        super().__init__(
            unet=unet,
            lda=lda,
            clip_text_encoder=clip_text_encoder,
            scheduler=scheduler,
            device=device,
            dtype=dtype,
        )

    def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]:
        conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
        if text == negative_text:
            return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0), torch.cat(
                tensors=(conditional_pooled_embedding, conditional_pooled_embedding), dim=0
            )

        # TODO: when negative_text is None, use zero tensor?
        negative_embedding, negative_pooled_embedding = self.clip_text_encoder(negative_text or "")

        return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0), torch.cat(
            tensors=(negative_pooled_embedding, conditional_pooled_embedding), dim=0
        )

    @property
    def default_time_ids(self) -> Tensor:
        # [original_height, original_width, crop_top, crop_left, target_height, target_width]
        # See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning
        time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device)
        return time_ids.repeat(2, 1)

    def set_unet_context(
        self,
        *,
        timestep: Tensor,
        clip_text_embedding: Tensor,
        pooled_text_embedding: Tensor,
        time_ids: Tensor,
        **_: Tensor,
    ) -> None:
        self.unet.set_timestep(timestep=timestep)
        self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
        self.unet.set_pooled_text_embedding(pooled_text_embedding=pooled_text_embedding)
        self.unet.set_time_ids(time_ids=time_ids)

    def forward(
        self,
        x: Tensor,
        step: int,
        *,
        clip_text_embedding: Tensor,
        pooled_text_embedding: Tensor,
        time_ids: Tensor,
        condition_scale: float = 5.0,
        **kwargs: Tensor,
    ) -> Tensor:
        return super().forward(
            x=x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
            condition_scale=condition_scale,
            **kwargs,
        )

    def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
        if enable:
            if sag := self._find_sag_adapter():
                sag.scale = scale
            else:
                SDXLSAGAdapter(target=self.unet, scale=scale).inject()
        else:
            if sag := self._find_sag_adapter():
                sag.eject()

    def has_self_attention_guidance(self) -> bool:
        return self._find_sag_adapter() is not None

    def _find_sag_adapter(self) -> SDXLSAGAdapter | None:
        for p in self.unet.get_parents():
            if isinstance(p, SDXLSAGAdapter):
                return p
        return None

    def compute_self_attention_guidance(
        self,
        x: Tensor,
        noise: Tensor,
        step: int,
        *,
        clip_text_embedding: Tensor,
        pooled_text_embedding: Tensor,
        time_ids: Tensor,
        **kwargs: Tensor,
    ) -> Tensor:
        sag = self._find_sag_adapter()
        assert sag is not None

        degraded_latents = sag.compute_degraded_latents(
            scheduler=self.scheduler,
            latents=x,
            noise=noise,
            step=step,
            classifier_free_guidance=True,
        )

        negative_embedding, _ = clip_text_embedding.chunk(2)
        negative_pooled_embedding, _ = pooled_text_embedding.chunk(2)
        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
        time_ids, _ = time_ids.chunk(2)
        self.set_unet_context(
            timestep=timestep,
            clip_text_embedding=negative_embedding,
            pooled_text_embedding=negative_pooled_embedding,
            time_ids=time_ids,
            **kwargs,
        )
        degraded_noise = self.unet(degraded_latents)

        return sag.scale * (noise - degraded_noise)
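As with SD1 above, a rough loop sketch for the SDXL model (not part of the vendored code; the latent shape and the scheduler's `num_inference_steps` attribute are assumptions; `default_time_ids` is already batched for classifier-free guidance):

import torch

sdxl = StableDiffusion_XL()
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding("a cute cat", negative_text="")
x = torch.randn(1, 4, 128, 128)  # 128x128 latents decode to a 1024x1024 image
for step in range(sdxl.scheduler.num_inference_steps):
    x = sdxl(
        x,
        step=step,
        clip_text_embedding=clip_text_embedding,
        pooled_text_embedding=pooled_text_embedding,
        time_ids=sdxl.default_time_ids,
    )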
@ -0,0 +1,21 @@
from torch import Tensor

from imaginairy.vendored.refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL


class SDXLDiffusionTarget(DiffusionTarget):
    pooled_text_embedding: Tensor
    time_ids: Tensor


class SDXLMultiDiffusion(MultiDiffusion[StableDiffusion_XL, SDXLDiffusionTarget]):
    def diffuse_target(self, x: Tensor, step: int, target: SDXLDiffusionTarget) -> Tensor:
        return self.ldm(
            x=x,
            step=step,
            clip_text_embedding=target.clip_text_embedding,
            pooled_text_embedding=target.pooled_text_embedding,
            time_ids=target.time_ids,
            condition_scale=target.condition_scale,
        )
@ -0,0 +1,41 @@
import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.layers.attentions import ScaledDotProductAttention
from imaginairy.vendored.refiners.foundationals.latent_diffusion.self_attention_guidance import (
    SAGAdapter,
    SelfAttentionMap,
    SelfAttentionShape,
)
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import MiddleBlock, ResidualBlock, SDXLUNet


class SDXLSAGAdapter(SAGAdapter[SDXLUNet]):
    def __init__(self, target: SDXLUNet, scale: float = 1.0, kernel_size: int = 9, sigma: float = 1.0) -> None:
        super().__init__(
            target=target,
            scale=scale,
            kernel_size=kernel_size,
            sigma=sigma,
        )

    def inject(self: "SDXLSAGAdapter", parent: fl.Chain | None = None) -> "SDXLSAGAdapter":
        middle_block = self.target.ensure_find(MiddleBlock)
        middle_block.insert_after_type(ResidualBlock, SelfAttentionShape(context_key="middle_block_attn_shape"))

        # An alternative would be to replace the ScaledDotProductAttention with a version which records the attention
        # scores to avoid computing these scores twice
        self_attn = middle_block.ensure_find(fl.SelfAttention)
        self_attn.insert_before_type(
            ScaledDotProductAttention,
            SelfAttentionMap(num_heads=self_attn.num_heads, context_key="middle_block_attn_map"),
        )

        return super().inject(parent)

    def eject(self) -> None:
        middle_block = self.target.ensure_find(MiddleBlock)
        middle_block.remove(middle_block.ensure_find(SelfAttentionShape))

        self_attn = middle_block.ensure_find(fl.SelfAttention)
        self_attn.remove(self_attn.ensure_find(SelfAttentionMap))

        super().eject()
@ -0,0 +1,48 @@
from torch import Tensor

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from imaginairy.vendored.refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoderXL, T2IAdapter, T2IFeatures


class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
    def __init__(
        self,
        target: SDXLUNet,
        name: str,
        condition_encoder: ConditionEncoderXL | None = None,
        scale: float = 1.0,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        self.residual_indices = (3, 5, 8)  # the UNet's middle block is handled separately (see `inject` and `eject`)
        self._features = [T2IFeatures(name=name, index=i, scale=scale) for i in range(4)]
        super().__init__(
            target=target,
            name=name,
            condition_encoder=condition_encoder or ConditionEncoderXL(device=target.device, dtype=target.dtype),
            weights=weights,
        )

    def inject(self: "SDXLT2IAdapter", parent: fl.Chain | None = None) -> "SDXLT2IAdapter":
        def sanity_check_t2i(block: fl.Chain) -> None:
            for t2i_layer in block.layers(layer_type=T2IFeatures):
                assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"

        # Note: `strict=False` because `residual_indices` is shorter than `_features` due to MiddleBlock (see below)
        for n, feat in zip(self.residual_indices, self._features, strict=False):
            block = self.target.DownBlocks[n]
            sanity_check_t2i(block)
            block.insert_before_type(ResidualAccumulator, feat)

        # Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
        sanity_check_t2i(self.target.MiddleBlock)
        self.target.MiddleBlock.append(self._features[-1])
        return super().inject(parent)

    def eject(self: "SDXLT2IAdapter") -> None:
        # See `inject` re: `strict=False`
        for n, feat in zip(self.residual_indices, self._features, strict=False):
            self.target.DownBlocks[n].remove(feat)
        self.target.MiddleBlock.remove(self._features[-1])
        super().eject()
@ -0,0 +1,85 @@
from typing import cast

from jaxtyping import Float
from torch import Tensor, cat, device as Device, dtype as DType

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
from imaginairy.vendored.refiners.foundationals.clip.tokenizer import CLIPTokenizer


class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
    def __init__(
        self,
        target: CLIPTextEncoderG,
        projection: fl.Linear | None = None,
    ) -> None:
        with self.setup_adapter(target=target):
            tokenizer = target.ensure_find(CLIPTokenizer)
            super().__init__(
                tokenizer,
                fl.SetContext(
                    context="text_encoder_pooling", key="end_of_text_index", callback=self.set_end_of_text_index
                ),
                target[1:-2],
                fl.Parallel(
                    fl.Identity(),
                    fl.Chain(
                        target[-2:],
                        projection
                        or fl.Linear(
                            in_features=1280, out_features=1280, bias=False, device=target.device, dtype=target.dtype
                        ),
                        fl.Lambda(func=self.pool),
                    ),
                ),
            )

    def init_context(self) -> Contexts:
        return {"text_encoder_pooling": {"end_of_text_index": []}}

    def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 1280"], Float[Tensor, "1 1280"]]:
        return super().__call__(text)

    @property
    def tokenizer(self) -> CLIPTokenizer:
        return self.ensure_find(CLIPTokenizer)

    def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
        position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()
        end_of_text_index.append(cast(int, position))

    def pool(self, x: Float[Tensor, "1 77 1280"]) -> Float[Tensor, "1 1280"]:
        end_of_text_index = self.use_context(context_name="text_encoder_pooling").get("end_of_text_index", [])
        assert len(end_of_text_index) == 1, "End of text index not found."
        return x[:, end_of_text_index[0], :]


class DoubleTextEncoder(fl.Chain):
    def __init__(
        self,
        text_encoder_l: CLIPTextEncoderL | None = None,
        text_encoder_g: CLIPTextEncoderG | None = None,
        projection: fl.Linear | None = None,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        text_encoder_l = text_encoder_l or CLIPTextEncoderL(device=device, dtype=dtype)
        text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype)
        super().__init__(
            fl.Parallel(text_encoder_l[:-2], text_encoder_g),
            fl.Lambda(func=self.concatenate_embeddings),
        )
        TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(parent=self.Parallel)

    def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
        return super().__call__(text)

    def concatenate_embeddings(
        self, text_embedding_l: Tensor, text_embedding_with_pooling: tuple[Tensor, Tensor]
    ) -> tuple[Tensor, Tensor]:
        text_embedding_g, pooled_text_embedding = text_embedding_with_pooling
        text_embedding = cat(tensors=[text_embedding_l, text_embedding_g], dim=-1)
        return text_embedding, pooled_text_embedding
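The jaxtyping annotations above make the contract concrete; a minimal call sketch (not part of the vendored code):

double_encoder = DoubleTextEncoder()
text_embedding, pooled_text_embedding = double_encoder("a cute cat")
# text_embedding: 1x77x2048 (CLIP-L's 768 and CLIP-G's 1280 widths concatenated)
# pooled_text_embedding: 1x1280, the projected CLIP-G output at the end-of-text token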
@ -0,0 +1,285 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
|
||||||
|
import imaginairy.vendored.refiners.fluxion.layers as fl
|
||||||
|
from imaginairy.vendored.refiners.fluxion.context import Contexts
|
||||||
|
from imaginairy.vendored.refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||||
|
from imaginairy.vendored.refiners.foundationals.latent_diffusion.range_adapter import (
|
||||||
|
RangeAdapter2d,
|
||||||
|
RangeEncoder,
|
||||||
|
compute_sinusoidal_embedding,
|
||||||
|
)
|
||||||
|
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
|
||||||
|
ResidualAccumulator,
|
||||||
|
ResidualBlock,
|
||||||
|
ResidualConcatenator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TextTimeEmbedding(fl.Chain):
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
self.timestep_embedding_dim = 1280
|
||||||
|
self.time_ids_embedding_dim = 256
|
||||||
|
self.text_time_embedding_dim = 2816
|
||||||
|
super().__init__(
|
||||||
|
fl.Concatenate(
|
||||||
|
fl.UseContext(context="diffusion", key="pooled_text_embedding"),
|
||||||
|
fl.Chain(
|
||||||
|
fl.UseContext(context="diffusion", key="time_ids"),
|
||||||
|
fl.Unsqueeze(dim=-1),
|
||||||
|
fl.Lambda(func=self.compute_sinuosoidal_embedding),
|
||||||
|
fl.Reshape(-1),
|
||||||
|
),
|
||||||
|
dim=1,
|
||||||
|
),
|
||||||
|
fl.Converter(set_device=False, set_dtype=True),
|
||||||
|
fl.Linear(
|
||||||
|
in_features=self.text_time_embedding_dim,
|
||||||
|
out_features=self.timestep_embedding_dim,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
fl.SiLU(),
|
||||||
|
fl.Linear(
|
||||||
|
in_features=self.timestep_embedding_dim,
|
||||||
|
out_features=self.timestep_embedding_dim,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_sinuosoidal_embedding(self, x: Tensor) -> Tensor:
|
||||||
|
return compute_sinusoidal_embedding(x=x, embedding_dim=self.time_ids_embedding_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEncoder(fl.Passthrough):
|
||||||
|
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
self.timestep_embedding_dim = 1280
|
||||||
|
super().__init__(
|
||||||
|
fl.Sum(
|
||||||
|
fl.Chain(
|
||||||
|
fl.UseContext(context="diffusion", key="timestep"),
|
||||||
|
RangeEncoder(
|
||||||
|
sinuosidal_embedding_dim=320,
|
||||||
|
embedding_dim=self.timestep_embedding_dim,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
TextTimeEmbedding(device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.SetContext(context="range_adapter", key="timestep_embedding"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLCrossAttention(CrossAttentionBlock2d):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
num_attention_layers: int = 1,
|
||||||
|
num_attention_heads: int = 10,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
channels=channels,
|
||||||
|
context_embedding_dim=2048,
|
||||||
|
context_key="clip_text_embedding",
|
||||||
|
num_attention_layers=num_attention_layers,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
use_bias=False,
|
||||||
|
use_linear_projection=True,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DownBlocks(fl.Chain):
|
||||||
|
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
in_block = fl.Chain(
|
||||||
|
fl.Conv2d(in_channels=in_channels, out_channels=320, kernel_size=3, padding=1, device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
first_blocks = [
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=320, out_channels=320, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
fl.Chain(
|
||||||
|
fl.Downsample(channels=320, scale_factor=2, padding=1, device=device, dtype=dtype),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
second_blocks = [
|
||||||
|
fl.Chain(
|
||||||
|
ResidualBlock(in_channels=320, out_channels=640, device=device, dtype=dtype),
|
||||||
|
SDXLCrossAttention(
|
||||||
|
channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
),
|
||||||
            fl.Chain(
                ResidualBlock(in_channels=640, out_channels=640, device=device, dtype=dtype),
                SDXLCrossAttention(
                    channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
                ),
            ),
            fl.Chain(
                fl.Downsample(channels=640, scale_factor=2, padding=1, device=device, dtype=dtype),
            ),
        ]
        third_blocks = [
            fl.Chain(
                ResidualBlock(in_channels=640, out_channels=1280, device=device, dtype=dtype),
                SDXLCrossAttention(
                    channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
                ),
            ),
            fl.Chain(
                ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
                SDXLCrossAttention(
                    channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
                ),
            ),
        ]

        super().__init__(
            in_block,
            *first_blocks,
            *second_blocks,
            *third_blocks,
        )


class UpBlocks(fl.Chain):
    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        first_blocks = [
            fl.Chain(
                ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
                SDXLCrossAttention(
                    channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
                ),
            ),
            fl.Chain(
                ResidualBlock(in_channels=2560, out_channels=1280, device=device, dtype=dtype),
                SDXLCrossAttention(
                    channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
                ),
            ),
            fl.Chain(
                ResidualBlock(in_channels=1920, out_channels=1280, device=device, dtype=dtype),
                SDXLCrossAttention(
                    channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
                ),
                fl.Upsample(channels=1280, device=device, dtype=dtype),
            ),
        ]

        second_blocks = [
            fl.Chain(
                ResidualBlock(in_channels=1920, out_channels=640, device=device, dtype=dtype),
                SDXLCrossAttention(
                    channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
                ),
            ),
            fl.Chain(
                ResidualBlock(in_channels=1280, out_channels=640, device=device, dtype=dtype),
                SDXLCrossAttention(
                    channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
                ),
            ),
            fl.Chain(
                ResidualBlock(in_channels=960, out_channels=640, device=device, dtype=dtype),
                SDXLCrossAttention(
                    channels=640, num_attention_layers=2, num_attention_heads=10, device=device, dtype=dtype
                ),
                fl.Upsample(channels=640, device=device, dtype=dtype),
            ),
        ]

        third_blocks = [
            fl.Chain(
                ResidualBlock(in_channels=960, out_channels=320, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
            ),
            fl.Chain(
                ResidualBlock(in_channels=640, out_channels=320, device=device, dtype=dtype),
            ),
        ]

        super().__init__(
            *first_blocks,
            *second_blocks,
            *third_blocks,
        )


class MiddleBlock(fl.Chain):
    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__(
            ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
            SDXLCrossAttention(
                channels=1280, num_attention_layers=10, num_attention_heads=20, device=device, dtype=dtype
            ),
            ResidualBlock(in_channels=1280, out_channels=1280, device=device, dtype=dtype),
        )


class OutputBlock(fl.Chain):
    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__(
            fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
            fl.SiLU(),
            fl.Conv2d(in_channels=320, out_channels=4, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
        )


class SDXLUNet(fl.Chain):
    def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.in_channels = in_channels
        super().__init__(
            TimestepEncoder(device=device, dtype=dtype),
            DownBlocks(in_channels=in_channels, device=device, dtype=dtype),
            MiddleBlock(device=device, dtype=dtype),
            fl.Residual(fl.UseContext(context="unet", key="residuals").compose(lambda x: x[-1])),
            UpBlocks(device=device, dtype=dtype),
            OutputBlock(device=device, dtype=dtype),
        )
        for residual_block in self.layers(ResidualBlock):
            chain = residual_block.Chain
            RangeAdapter2d(
                target=chain.Conv2d_1,
                channels=residual_block.out_channels,
                embedding_dim=1280,
                context_key="timestep_embedding",
                device=device,
                dtype=dtype,
            ).inject(chain)
        for n, block in enumerate(iterable=cast(list[fl.Chain], self.DownBlocks)):
            block.append(module=ResidualAccumulator(n=n))
        for n, block in enumerate(iterable=cast(list[fl.Chain], self.UpBlocks)):
            block.insert(index=0, module=ResidualConcatenator(n=-n - 2))

    def init_context(self) -> Contexts:
        return {
            "unet": {"residuals": [0.0] * 10},
            "diffusion": {"timestep": None, "time_ids": None, "pooled_text_embedding": None},
            "range_adapter": {"timestep_embedding": None},
            "sampling": {"shapes": []},
        }

    def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
        self.set_context(context="cross_attention_block", value={"clip_text_embedding": clip_text_embedding})

    def set_timestep(self, timestep: Tensor) -> None:
        self.set_context(context="diffusion", value={"timestep": timestep})

    def set_time_ids(self, time_ids: Tensor) -> None:
        self.set_context(context="diffusion", value={"time_ids": time_ids})

    def set_pooled_text_embedding(self, pooled_text_embedding: Tensor) -> None:
        self.set_context(context="diffusion", value={"pooled_text_embedding": pooled_text_embedding})
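For orientation, a minimal smoke test of the UNet assembled above (a hedged sketch: the tensor shapes follow standard SDXL conventions, the import path assumes the vendored layout used throughout this commit, and weights are randomly initialized):

import torch
from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet

unet = SDXLUNet(in_channels=4)
# All conditioning flows through contexts, not through forward() arguments.
unet.set_timestep(timestep=torch.randint(low=0, high=1000, size=(1, 1)))
unet.set_time_ids(time_ids=torch.zeros(1, 6))
unet.set_pooled_text_embedding(pooled_text_embedding=torch.randn(1, 1280))
unet.set_clip_text_embedding(clip_text_embedding=torch.randn(1, 77, 2048))

latents = torch.randn(1, 4, 128, 128)  # a 1024x1024 image in 8x-downsampled latent space
prediction = unet(latents)             # noise prediction, same shape as the input latents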
@ -0,0 +1,215 @@
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from torch import Tensor, device as Device, dtype as DType
from torch.nn import AvgPool2d as _AvgPool2d

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.layers.module import Module

if TYPE_CHECKING:
    from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
    from imaginairy.vendored.refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet

T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TT2IAdapter = TypeVar("TT2IAdapter", bound="T2IAdapter[Any]")  # Self (see PEP 673)


class Downsample2d(_AvgPool2d, Module):
    def __init__(self, scale_factor: int) -> None:
        _AvgPool2d.__init__(self, kernel_size=scale_factor, stride=scale_factor)


class ResidualBlock(fl.Residual):
    def __init__(
        self,
        channels: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(
            fl.Conv2d(
                in_channels=channels, out_channels=channels, kernel_size=3, padding=1, device=device, dtype=dtype
            ),
            fl.ReLU(),
            fl.Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
        )


class ResidualBlocks(fl.Chain):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_residual_blocks: int = 2,
        downsample: bool = False,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        preproc = Downsample2d(scale_factor=2) if downsample else fl.Identity()
        shortcut = (
            fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
            if in_channels != out_channels
            else fl.Identity()
        )
        super().__init__(
            preproc,
            shortcut,
            fl.Chain(
                ResidualBlock(channels=out_channels, device=device, dtype=dtype) for _ in range(num_residual_blocks)
            ),
        )


class StatefulResidualBlocks(fl.Chain):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_residual_blocks: int = 2,
        downsample: bool = False,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(
            ResidualBlocks(
                in_channels=in_channels,
                out_channels=out_channels,
                num_residual_blocks=num_residual_blocks,
                downsample=downsample,
                device=device,
                dtype=dtype,
            ),
            fl.SetContext(context="t2iadapter", key="features", callback=self.push),
        )

    def push(self, features: list[Tensor], x: Tensor) -> None:
        features.append(x)


class ConditionEncoder(fl.Chain):
    def __init__(
        self,
        in_channels: int = 3,
        channels: tuple[int, int, int, int] = (320, 640, 1280, 1280),
        num_residual_blocks: int = 2,
        downscale_factor: int = 8,
        scale: float = 1.0,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.scale = scale
        super().__init__(
            fl.PixelUnshuffle(downscale_factor=downscale_factor),
            fl.Conv2d(
                in_channels=in_channels * downscale_factor**2,
                out_channels=channels[0],
                kernel_size=3,
                padding=1,
                device=device,
                dtype=dtype,
            ),
            StatefulResidualBlocks(channels[0], channels[0], num_residual_blocks, device=device, dtype=dtype),
            *(
                StatefulResidualBlocks(
                    channels[i - 1], channels[i], num_residual_blocks, downsample=True, device=device, dtype=dtype
                )
                for i in range(1, len(channels))
            ),
            fl.UseContext(context="t2iadapter", key="features"),
        )

    def init_context(self) -> Contexts:
        return {"t2iadapter": {"features": []}}


class ConditionEncoderXL(ConditionEncoder, fl.Chain):
    def __init__(
        self,
        in_channels: int = 3,
        channels: tuple[int, int, int, int] = (320, 640, 1280, 1280),
        num_residual_blocks: int = 2,
        downscale_factor: int = 16,
        scale: float = 1.0,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.scale = scale
        fl.Chain.__init__(
            self,
            fl.PixelUnshuffle(downscale_factor=downscale_factor),
            fl.Conv2d(
                in_channels=in_channels * downscale_factor**2,
                out_channels=channels[0],
                kernel_size=3,
                padding=1,
                device=device,
                dtype=dtype,
            ),
            StatefulResidualBlocks(channels[0], channels[0], num_residual_blocks, device=device, dtype=dtype),
            StatefulResidualBlocks(channels[0], channels[1], num_residual_blocks, device=device, dtype=dtype),
            StatefulResidualBlocks(
                channels[1], channels[2], num_residual_blocks, downsample=True, device=device, dtype=dtype
            ),
            StatefulResidualBlocks(channels[2], channels[3], num_residual_blocks, device=device, dtype=dtype),
            fl.UseContext(context="t2iadapter", key="features"),
        )


class T2IFeatures(fl.Residual):
    def __init__(self, name: str, index: int, scale: float = 1.0) -> None:
        self.name = name
        self.index = index
        self.scale = scale
        super().__init__(
            fl.UseContext(context="t2iadapter", key=f"condition_features_{self.name}").compose(
                func=lambda features: self.scale * features[self.index]
            )
        )


class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
    _condition_encoder: list[ConditionEncoder]  # prevent PyTorch module registration
    _features: list[T2IFeatures] = []

    def __init__(
        self,
        target: T,
        name: str,
        condition_encoder: ConditionEncoder,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        self.name = name
        if weights is not None:
            condition_encoder.load_state_dict(weights)
        self._condition_encoder = [condition_encoder]

        with self.setup_adapter(target):
            super().__init__(target)

    def inject(self: TT2IAdapter, parent: fl.Chain | None = None) -> TT2IAdapter:
        return super().inject(parent)

    def eject(self) -> None:
        super().eject()

    @property
    def condition_encoder(self) -> ConditionEncoder:
        return self._condition_encoder[0]

    def compute_condition_features(self, condition: Tensor) -> tuple[Tensor, ...]:
        return self.condition_encoder(condition)

    def set_condition_features(self, features: tuple[Tensor, ...]) -> None:
        self.set_context("t2iadapter", {f"condition_features_{self.name}": features})

    def set_scale(self, scale: float) -> None:
        for f in self._features:
            f.scale = scale

    def init_context(self) -> Contexts:
        return {"t2iadapter": {f"condition_features_{self.name}": None}}

    def structural_copy(self: "TT2IAdapter") -> "TT2IAdapter":
        raise RuntimeError("T2I-Adapter cannot be copied, eject it first.")
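As a rough illustration of the feature pyramid this encoder produces (a sketch; ConditionEncoder is the class defined in the hunk above, and the 512x512 input size is an arbitrary example):

import torch

encoder = ConditionEncoder()  # downscale_factor=8, channels (320, 640, 1280, 1280)
condition = torch.randn(1, 3, 512, 512)  # e.g. a sketch or depth map
features = encoder(condition)  # the list accumulated in the "t2iadapter" context
for f in features:
    print(tuple(f.shape))
# (1, 320, 64, 64), (1, 640, 32, 32), (1, 1280, 16, 16), (1, 1280, 8, 8)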
@ -0,0 +1,368 @@
import torch
from torch import Tensor, device as Device, dtype as DType, nn

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.fluxion.utils import pad


class PatchEncoder(fl.Chain):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int = 16,
        use_bias: bool = True,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size
        self.use_bias = use_bias
        super().__init__(
            fl.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=(self.patch_size, self.patch_size),
                stride=(self.patch_size, self.patch_size),
                use_bias=self.use_bias,
                device=device,
                dtype=dtype,
            ),
            fl.Permute(0, 2, 3, 1),
        )


class PositionalEncoder(fl.Residual):
    def __init__(
        self,
        embedding_dim: int,
        image_embedding_size: tuple[int, int],
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.image_embedding_size = image_embedding_size
        super().__init__(
            fl.Parameter(
                image_embedding_size[0],
                image_embedding_size[1],
                embedding_dim,
                device=device,
                dtype=dtype,
            ),
        )


class RelativePositionAttention(fl.WeightedModule):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        spatial_size: tuple[int, int],
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
        self.spatial_size = spatial_size
        self.horizontal_embedding = nn.Parameter(
            data=torch.zeros(2 * spatial_size[0] - 1, self.head_dim, device=device, dtype=dtype)
        )
        self.vertical_embedding = nn.Parameter(
            data=torch.zeros(2 * spatial_size[1] - 1, self.head_dim, device=device, dtype=dtype)
        )

    @property
    def device(self) -> Device:
        return self.horizontal_embedding.device

    @property
    def dtype(self) -> DType:
        return self.horizontal_embedding.dtype

    def forward(self, x: Tensor) -> Tensor:
        batch, height, width, _ = x.shape
        x = (
            x.reshape(batch, width * height, 3, self.num_heads, -1)
            .permute(2, 0, 3, 1, 4)
            .reshape(3, batch * self.num_heads, width * height, -1)
        )
        query, key, value = x.unbind(dim=0)
        horizontal_relative_embedding, vertical_relative_embedding = self.compute_relative_embedding(x=query)
        attention = (query * self.head_dim**-0.5) @ key.transpose(dim0=-2, dim1=-1)
        # Order of operations is important here
        attention = (
            (attention.reshape(-1, height, width, height, width) + vertical_relative_embedding)
            + horizontal_relative_embedding
        ).reshape(attention.shape)
        attention = attention.softmax(dim=-1)
        attention = attention @ value
        attention = (
            attention.reshape(batch, self.num_heads, height, width, -1)
            .permute(0, 2, 3, 1, 4)
            .reshape(batch, height, width, -1)
        )
        return attention

    def compute_relative_coords(self, size: int) -> Tensor:
        x, y = torch.meshgrid(torch.arange(end=size), torch.arange(end=size), indexing="ij")
        return x - y + size - 1

    def compute_relative_embedding(self, x: Tensor) -> tuple[Tensor, Tensor]:
        width, height = self.spatial_size
        horizontal_coords = self.compute_relative_coords(size=width)
        vertical_coords = self.compute_relative_coords(size=height)
        horizontal_positional_embedding = self.horizontal_embedding[horizontal_coords]
        vertical_positional_embedding = self.vertical_embedding[vertical_coords]
        x = x.reshape(x.shape[0], width, height, -1)
        horizontal_relative_embedding = torch.einsum("bhwc,wkc->bhwk", x, horizontal_positional_embedding).unsqueeze(
            dim=-2
        )
        vertical_relative_embedding = torch.einsum("bhwc,hkc->bhwk", x, vertical_positional_embedding).unsqueeze(dim=-1)

        return horizontal_relative_embedding, vertical_relative_embedding


class FusedSelfAttention(fl.Chain):
    def __init__(
        self,
        embedding_dim: int = 768,
        spatial_size: tuple[int, int] = (64, 64),
        num_heads: int = 1,
        use_bias: bool = True,
        is_causal: bool = False,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        assert (
            embedding_dim % num_heads == 0
        ), f"Embedding dim (embedding_dim={embedding_dim}) must be divisible by num heads (num_heads={num_heads})"
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.use_bias = use_bias
        self.is_causal = is_causal
        super().__init__(
            fl.Linear(
                in_features=self.embedding_dim,
                out_features=3 * self.embedding_dim,
                bias=self.use_bias,
                device=device,
                dtype=dtype,
            ),
            RelativePositionAttention(
                embedding_dim=self.embedding_dim,
                num_heads=self.num_heads,
                spatial_size=spatial_size,
                device=device,
                dtype=dtype,
            ),
            fl.Linear(
                in_features=self.embedding_dim,
                out_features=self.embedding_dim,
                bias=True,
                device=device,
                dtype=dtype,
            ),
        )


class FeedForward(fl.Chain):
    def __init__(
        self,
        embedding_dim: int,
        feedforward_dim: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.feedforward_dim = feedforward_dim
        super().__init__(
            fl.Linear(
                in_features=self.embedding_dim,
                out_features=self.feedforward_dim,
                bias=True,
                device=device,
                dtype=dtype,
            ),
            fl.GeLU(),
            fl.Linear(
                in_features=self.feedforward_dim,
                out_features=self.embedding_dim,
                bias=True,
                device=device,
                dtype=dtype,
            ),
        )


class WindowPartition(fl.ContextModule):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        batch, height, width, channels = x.shape
        context = self.use_context(context_name="window_partition")
        context.update({"original_height": height, "original_width": width})
        window_size = context["window_size"]
        padding_height = (window_size - height % window_size) % window_size
        padding_width = (window_size - width % window_size) % window_size
        if padding_height > 0 or padding_width > 0:
            x = pad(x=x, pad=(0, 0, 0, padding_width, 0, padding_height))
        padded_height, padded_width = height + padding_height, width + padding_width
        context.update({"padded_height": padded_height, "padded_width": padded_width})
        x = x.view(batch, padded_height // window_size, window_size, padded_width // window_size, window_size, channels)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, channels)
        return windows


class WindowMerge(fl.ContextModule):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        context = self.use_context(context_name="window_partition")
        window_size = context["window_size"]
        padded_height, padded_width = context["padded_height"], context["padded_width"]
        original_height, original_width = context["original_height"], context["original_width"]
        batch_size = x.shape[0] // (padded_height * padded_width // window_size // window_size)
        x = x.view(batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, padded_height, padded_width, -1)
        if padded_height > original_height or padded_width > original_width:
            x = x[:, :original_height, :original_width, :].contiguous()
        return x


class TransformerLayer(fl.Chain):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        feedforward_dim: int,
        image_embedding_size: tuple[int, int],
        window_size: int | None = None,
        layer_norm_eps: float = 1e-6,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.feedforward_dim = feedforward_dim
        self.window_size = window_size
        self.layer_norm_eps = layer_norm_eps
        self.image_embedding_size = image_embedding_size
        attention_spatial_size = (window_size, window_size) if window_size is not None else image_embedding_size
        reshape_or_merge = (
            WindowMerge()
            if self.window_size is not None
            else fl.Reshape(self.image_embedding_size[0], self.image_embedding_size[1], embedding_dim)
        )
        super().__init__(
            fl.Residual(
                fl.LayerNorm(normalized_shape=embedding_dim, eps=self.layer_norm_eps, device=device, dtype=dtype),
                WindowPartition() if self.window_size is not None else fl.Identity(),
                FusedSelfAttention(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    spatial_size=attention_spatial_size,
                    device=device,
                    dtype=dtype,
                ),
                reshape_or_merge,
            ),
            fl.Residual(
                fl.LayerNorm(normalized_shape=embedding_dim, eps=self.layer_norm_eps, device=device, dtype=dtype),
                FeedForward(embedding_dim=embedding_dim, feedforward_dim=feedforward_dim, device=device, dtype=dtype),
            ),
        )

    def init_context(self) -> Contexts:
        return {"window_partition": {"window_size": self.window_size}}


class Neck(fl.Chain):
    def __init__(self, in_channels: int = 768, device: Device | str | None = None, dtype: DType | None = None) -> None:
        self.in_channels = in_channels
        super().__init__(
            fl.Permute(0, 3, 1, 2),
            fl.Conv2d(
                in_channels=self.in_channels,
                out_channels=256,
                kernel_size=1,
                use_bias=False,
                device=device,
                dtype=dtype,
            ),
            fl.LayerNorm2d(channels=256, device=device, dtype=dtype),
            fl.Conv2d(
                in_channels=256,
                out_channels=256,
                kernel_size=3,
                padding=1,
                use_bias=False,
                device=device,
                dtype=dtype,
            ),
            fl.LayerNorm2d(channels=256, device=device, dtype=dtype),
        )


class Transformer(fl.Chain):
    pass


class SAMViT(fl.Chain):
    def __init__(
        self,
        embedding_dim: int,
        num_layers: int,
        num_heads: int,
        global_attention_indices: tuple[int, ...] | None = None,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.image_size = (1024, 1024)
        self.patch_size = 16
        self.window_size = 14
        self.image_embedding_size = (self.image_size[0] // self.patch_size, self.image_size[1] // self.patch_size)
        self.feed_forward_dim = 4 * self.embedding_dim
        self.global_attention_indices = global_attention_indices or tuple()
        super().__init__(
            PatchEncoder(
                in_channels=3, out_channels=embedding_dim, patch_size=self.patch_size, device=device, dtype=dtype
            ),
            PositionalEncoder(
                embedding_dim=embedding_dim, image_embedding_size=self.image_embedding_size, device=device, dtype=dtype
            ),
            Transformer(
                TransformerLayer(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    feedforward_dim=self.feed_forward_dim,
                    window_size=self.window_size if i not in self.global_attention_indices else None,
                    image_embedding_size=self.image_embedding_size,
                    device=device,
                    dtype=dtype,
                )
                for i in range(num_layers)
            ),
            Neck(in_channels=embedding_dim, device=device, dtype=dtype),
        )


class SAMViTH(SAMViT):
    def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__(
            embedding_dim=1280,
            num_layers=32,
            num_heads=16,
            global_attention_indices=(7, 15, 23, 31),
            device=device,
            dtype=dtype,
        )
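A hedged shape walk-through of the encoder above (randomly initialized; real use loads SAM ViT-H weights, and instantiating the model is memory-heavy):

import torch

vit = SAMViTH()  # embedding_dim=1280, 32 layers; roughly 0.6B parameters
image = torch.randn(1, 3, 1024, 1024)  # SAM's fixed 1024x1024 input, patch_size=16
embedding = vit(image)
print(tuple(embedding.shape))  # (1, 256, 64, 64): 64x64 patches, Neck projects 1280 -> 256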
@ -0,0 +1,264 @@
import torch
from torch import Tensor, device as Device, dtype as DType, nn

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.foundationals.segment_anything.transformer import (
    SparseCrossDenseAttention,
    TwoWayTranformerLayer,
)


class EmbeddingsAggregator(fl.ContextModule):
    def __init__(self, num_output_mask: int = 3) -> None:
        super().__init__()
        self.num_mask_tokens = num_output_mask

    def forward(self, iou_mask_tokens: Tensor) -> Tensor:
        mask_decoder = self.ensure_parent
        mask_decoder_context = mask_decoder.use_context(context_name="mask_decoder")
        image_embedding = mask_decoder_context["image_embedding"]
        point_embedding = mask_decoder_context["point_embedding"]
        mask_embedding = mask_decoder_context["mask_embedding"]
        dense_positional_embedding = mask_decoder_context["dense_positional_embedding"]

        sparse_embedding = torch.cat(tensors=(iou_mask_tokens, point_embedding), dim=1)
        dense_embedding = (image_embedding + mask_embedding).flatten(start_dim=2).transpose(1, 2)
        if dense_positional_embedding.shape != dense_embedding.shape:
            dense_positional_embedding = dense_positional_embedding.flatten(start_dim=2).transpose(1, 2)

        mask_decoder_context.update(
            {
                "dense_embedding": dense_embedding,
                "dense_positional_embedding": dense_positional_embedding,
                "sparse_embedding": sparse_embedding,
            }
        )
        mask_decoder.set_context(context="mask_decoder", value=mask_decoder_context)

        return sparse_embedding


class Transformer(fl.Chain):
    pass


class Hypernetworks(fl.Concatenate):
    def __init__(
        self,
        embedding_dim: int = 256,
        num_layers: int = 3,
        num_mask_tokens: int = 3,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.num_mask_tokens = num_mask_tokens

        super().__init__(
            *[
                fl.Chain(
                    fl.Slicing(dim=1, start=i + 1, end=i + 2),
                    fl.MultiLinear(
                        input_dim=embedding_dim,
                        output_dim=embedding_dim // 8,
                        inner_dim=embedding_dim,
                        num_layers=num_layers,
                        device=device,
                        dtype=dtype,
                    ),
                )
                for i in range(num_mask_tokens + 1)
            ],
            dim=1,
        )


class DenseEmbeddingUpscaling(fl.Chain):
    def __init__(
        self,
        embedding_dim: int = 256,
        dense_embedding_side_dim: int = 64,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dense_embedding_side_dim = dense_embedding_side_dim

        super().__init__(
            fl.UseContext(context="mask_decoder", key="dense_embedding"),
            fl.Transpose(dim0=1, dim1=2),
            fl.Reshape(embedding_dim, dense_embedding_side_dim, dense_embedding_side_dim),
            fl.ConvTranspose2d(
                in_channels=embedding_dim,
                out_channels=embedding_dim // 4,
                kernel_size=2,
                stride=2,
                device=device,
                dtype=dtype,
            ),
            fl.LayerNorm2d(channels=embedding_dim // 4, device=device, dtype=dtype),
            fl.GeLU(),
            fl.ConvTranspose2d(
                in_channels=embedding_dim // 4,
                out_channels=embedding_dim // 8,
                kernel_size=2,
                stride=2,
                device=device,
                dtype=dtype,
            ),
            fl.GeLU(),
            fl.Flatten(start_dim=2),
        )


class IOUMaskEncoder(fl.WeightedModule):
    def __init__(
        self,
        embedding_dim: int = 256,
        num_mask_tokens: int = 4,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_mask_tokens = num_mask_tokens
        # aka prompt tokens + output token (for IoU scores prediction)
        self.weight = nn.Parameter(data=torch.randn(num_mask_tokens + 1, embedding_dim, device=device, dtype=dtype))

    def forward(self) -> Tensor:
        return self.weight.unsqueeze(dim=0)


class MaskPrediction(fl.Chain):
    def __init__(
        self,
        embedding_dim: int,
        num_mask_tokens: int,
        num_layers: int = 3,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.num_mask_tokens = num_mask_tokens
        self.num_layers = num_layers
        super().__init__(
            fl.Matmul(
                input=Hypernetworks(
                    embedding_dim=embedding_dim,
                    num_layers=num_layers,
                    num_mask_tokens=num_mask_tokens,
                    device=device,
                    dtype=dtype,
                ),
                other=DenseEmbeddingUpscaling(embedding_dim=embedding_dim, device=device, dtype=dtype),
            ),
            fl.Slicing(dim=1, start=1),
            fl.Reshape(num_mask_tokens, embedding_dim, embedding_dim),
        )


class IOUPrediction(fl.Chain):
    def __init__(
        self,
        embedding_dim: int,
        num_layers: int,
        num_mask_tokens: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        super().__init__(
            fl.Slicing(dim=1, start=0, end=1),
            fl.Squeeze(dim=0),
            fl.MultiLinear(
                input_dim=embedding_dim,
                output_dim=num_mask_tokens + 1,
                inner_dim=embedding_dim,
                num_layers=num_layers,
                device=device,
                dtype=dtype,
            ),
            fl.Slicing(dim=-1, start=1),
        )


class MaskDecoder(fl.Chain):
    def __init__(
        self,
        embedding_dim: int = 256,
        feed_forward_dim: int = 2048,
        num_layers: int = 2,
        num_output_mask: int = 3,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_mask_tokens = num_output_mask
        self.feed_forward_dim = feed_forward_dim
        self.num_layers = num_layers

        super().__init__(
            IOUMaskEncoder(
                embedding_dim=embedding_dim, num_mask_tokens=num_output_mask + 1, device=device, dtype=dtype
            ),
            EmbeddingsAggregator(num_output_mask=num_output_mask),
            Transformer(
                *(
                    TwoWayTranformerLayer(
                        embedding_dim=embedding_dim,
                        num_heads=8,
                        feed_forward_dim=feed_forward_dim,
                        use_residual_self_attention=i > 0,
                        device=device,
                        dtype=dtype,
                    )
                    for i in range(num_layers)
                ),
                SparseCrossDenseAttention(embedding_dim=embedding_dim, device=device, dtype=dtype),
                fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
            ),
            fl.Parallel(
                MaskPrediction(
                    embedding_dim=embedding_dim, num_mask_tokens=num_output_mask, device=device, dtype=dtype
                ),
                IOUPrediction(
                    embedding_dim=embedding_dim,
                    num_layers=3,
                    num_mask_tokens=num_output_mask,
                    device=device,
                    dtype=dtype,
                ),
            ),
        )

    def init_context(self) -> Contexts:
        return {
            "mask_decoder": {
                "image_embedding": None,
                "point_embedding": None,
                "mask_embedding": None,
                "dense_positional_embedding": None,
            }
        }

    def set_image_embedding(self, image_embedding: Tensor) -> None:
        mask_decoder_context = self.use_context(context_name="mask_decoder")
        mask_decoder_context["image_embedding"] = image_embedding

    def set_point_embedding(self, point_embedding: Tensor) -> None:
        mask_decoder_context = self.use_context(context_name="mask_decoder")
        mask_decoder_context["point_embedding"] = point_embedding

    def set_mask_embedding(self, mask_embedding: Tensor) -> None:
        mask_decoder_context = self.use_context(context_name="mask_decoder")
        mask_decoder_context["mask_embedding"] = mask_embedding

    def set_dense_positional_embedding(self, dense_positional_embedding: Tensor) -> None:
        mask_decoder_context = self.use_context(context_name="mask_decoder")
        mask_decoder_context["dense_positional_embedding"] = dense_positional_embedding
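A sketch of how the decoder is driven purely through its "mask_decoder" context (shapes follow the defaults above; the random tensors are stand-ins for real encoder outputs):

import torch

decoder = MaskDecoder()  # embedding_dim=256, three output masks
decoder.set_image_embedding(image_embedding=torch.randn(1, 256, 64, 64))
decoder.set_mask_embedding(mask_embedding=torch.randn(1, 256, 64, 64))
decoder.set_point_embedding(point_embedding=torch.randn(1, 2, 256))
decoder.set_dense_positional_embedding(dense_positional_embedding=torch.randn(1, 256, 64, 64))

low_res_masks, iou_predictions = decoder()  # no forward arguments: everything comes from context
print(tuple(low_res_masks.shape), tuple(iou_predictions.shape))  # (1, 3, 256, 256) (1, 3)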
@ -0,0 +1,170 @@
from dataclasses import dataclass
from typing import Sequence

import numpy as np
import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad
from imaginairy.vendored.refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
from imaginairy.vendored.refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from imaginairy.vendored.refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder


@dataclass
class ImageEmbedding:
    features: Tensor
    original_image_size: tuple[int, int]  # (height, width)


class SegmentAnything(fl.Module):
    mask_threshold: float = 0.0

    def __init__(
        self,
        image_encoder: SAMViT,
        point_encoder: PointEncoder,
        mask_encoder: MaskEncoder,
        mask_decoder: MaskDecoder,
        device: Device | str = "cpu",
        dtype: DType = torch.float32,
    ) -> None:
        super().__init__()
        self.device: Device = device if isinstance(device, Device) else Device(device=device)
        self.dtype = dtype
        self.image_encoder = image_encoder.to(device=self.device, dtype=self.dtype)
        self.point_encoder = point_encoder.to(device=self.device, dtype=self.dtype)
        self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
        self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)

    @no_grad()
    def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
        original_size = (image.height, image.width)
        target_size = self.compute_target_size(original_size)
        return ImageEmbedding(
            features=self.image_encoder(self.preprocess_image(image=image, target_size=target_size)),
            original_image_size=original_size,
        )

    @no_grad()
    def predict(
        self,
        input: Image.Image | ImageEmbedding,
        foreground_points: Sequence[tuple[float, float]] | None = None,
        background_points: Sequence[tuple[float, float]] | None = None,
        box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
        masks: Sequence[Image.Image] | None = None,
        binarize: bool = True,
    ) -> tuple[Tensor, Tensor, Tensor]:
        if isinstance(input, ImageEmbedding):
            original_size = input.original_image_size
            target_size = self.compute_target_size(original_size)
            image_embedding = input.features
        else:
            original_size = (input.height, input.width)
            target_size = self.compute_target_size(original_size)
            image_embedding = self.image_encoder(self.preprocess_image(image=input, target_size=target_size))

        coordinates, type_mask = self.point_encoder.points_to_tensor(
            foreground_points=foreground_points,
            background_points=background_points,
            box_points=box_points,
        )
        self.point_encoder.set_type_mask(type_mask=type_mask)

        if masks is not None:
            mask_tensor = torch.stack(
                tensors=[image_to_tensor(image=mask, device=self.device, dtype=self.dtype) for mask in masks]
            )
            mask_embedding = self.mask_encoder(mask_tensor)
        else:
            mask_embedding = self.mask_encoder.get_no_mask_dense_embedding(
                image_embedding_size=self.image_encoder.image_embedding_size
            )
        point_embedding = self.point_encoder(
            self.normalize(coordinates, target_size=target_size, original_size=original_size)
        )
        dense_positional_embedding = self.point_encoder.get_dense_positional_embedding(
            image_embedding_size=self.image_encoder.image_embedding_size
        )

        self.mask_decoder.set_image_embedding(image_embedding=image_embedding)
        self.mask_decoder.set_mask_embedding(mask_embedding=mask_embedding)
        self.mask_decoder.set_point_embedding(point_embedding=point_embedding)
        self.mask_decoder.set_dense_positional_embedding(dense_positional_embedding=dense_positional_embedding)

        low_res_masks, iou_predictions = self.mask_decoder()

        high_res_masks = self.postprocess_masks(
            masks=low_res_masks, target_size=target_size, original_size=original_size
        )

        if binarize:
            high_res_masks = high_res_masks > self.mask_threshold

        return high_res_masks, iou_predictions, low_res_masks

    @property
    def image_size(self) -> int:
        w, h = self.image_encoder.image_size
        assert w == h
        return w

    def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
        oldh, oldw = size
        scale = self.image_size * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
        h, w = target_size
        padh = self.image_size - h
        padw = self.image_size - w
        image_tensor = torch.tensor(
            np.array(image.resize((w, h), resample=Image.Resampling.BILINEAR)).astype(np.float32).transpose(2, 0, 1),
            device=self.device,
            dtype=self.dtype,
        ).unsqueeze(0)
        return pad(
            normalize(image_tensor, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]), (0, padw, 0, padh)
        )

    def normalize(self, coordinates: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
        coordinates[:, :, 0] = ((coordinates[:, :, 0] * (target_size[1] / original_size[1])) + 0.5) / self.image_size
        coordinates[:, :, 1] = ((coordinates[:, :, 1] * (target_size[0] / original_size[0])) + 0.5) / self.image_size
        return coordinates

    def postprocess_masks(self, masks: Tensor, target_size: tuple[int, int], original_size: tuple[int, int]) -> Tensor:
        masks = interpolate(masks, factor=torch.Size((self.image_size, self.image_size)), mode="bilinear")
        masks = masks[..., : target_size[0], : target_size[1]]  # remove padding added at `preprocess_image` time
        masks = interpolate(masks, factor=torch.Size(original_size), mode="bilinear")
        return masks


class SegmentAnythingH(SegmentAnything):
    def __init__(
        self,
        image_encoder: SAMViTH | None = None,
        point_encoder: PointEncoder | None = None,
        mask_encoder: MaskEncoder | None = None,
        mask_decoder: MaskDecoder | None = None,
        device: Device | str = "cpu",
        dtype: DType = torch.float32,
    ) -> None:
        image_encoder = image_encoder or SAMViTH()
        point_encoder = point_encoder or PointEncoder()
        mask_encoder = mask_encoder or MaskEncoder()
        mask_decoder = mask_decoder or MaskDecoder()

        super().__init__(
            image_encoder=image_encoder,
            point_encoder=point_encoder,
            mask_encoder=mask_encoder,
            mask_decoder=mask_decoder,
            device=device,
            dtype=dtype,
        )
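End-to-end, usage looks roughly like this (hedged: weights are random here, so the masks are meaningless until real SAM weights are loaded into the sub-modules):

from PIL import Image

sam = SegmentAnythingH()  # builds the default ViT-H encoder, prompt encoders and mask decoder
image = Image.new(mode="RGB", size=(640, 480))
high_res_masks, iou_predictions, low_res_masks = sam.predict(
    image,
    foreground_points=[(320.0, 240.0)],  # (x, y) in original pixel coordinates
)
print(tuple(high_res_masks.shape))  # (1, 3, 480, 640), boolean after binarize=True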
@ -0,0 +1,192 @@
from collections.abc import Sequence
from enum import Enum, auto

import torch
from jaxtyping import Float, Int
from torch import Tensor, device as Device, dtype as DType, nn

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.context import Contexts


class CoordinateEncoder(fl.Chain):
    def __init__(
        self,
        num_positional_features: int = 64,
        scale: float = 1,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.num_positional_features = num_positional_features
        self.scale = scale

        super().__init__(
            fl.Multiply(scale=2, bias=-1),
            fl.Linear(in_features=2, out_features=num_positional_features, bias=False, device=device, dtype=dtype),
            fl.Multiply(scale=2 * torch.pi * self.scale),
            fl.Concatenate(fl.Sin(), fl.Cos(), dim=-1),
        )


class PointType(Enum):
    BACKGROUND = auto()
    FOREGROUND = auto()
    BOX_TOP_LEFT = auto()
    BOX_BOTTOM_RIGHT = auto()
    NOT_A_POINT = auto()


class PointTypeEmbedding(fl.WeightedModule, fl.ContextModule):
    def __init__(self, embedding_dim: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.weight = nn.Parameter(data=torch.randn(len(PointType), self.embedding_dim, device=device, dtype=dtype))

    def forward(self, type_mask: Int[Tensor, "1 num_points"]) -> Float[Tensor, "1 num_points embedding_dim"]:
        assert isinstance(type_mask, Tensor), "type_mask must be a Tensor."

        embeddings = torch.zeros(*type_mask.shape, self.embedding_dim).to(device=type_mask.device)
        for type_id in PointType:
            mask = type_mask == type_id.value
            embeddings[mask] = self.weight[type_id.value - 1]

        return embeddings


class PointEncoder(fl.Chain):
    def __init__(
        self, embedding_dim: int = 256, scale: float = 1, device: Device | str | None = None, dtype: DType | None = None
    ) -> None:
        assert embedding_dim % 2 == 0, "embedding_dim must be divisible by 2."
        self.embedding_dim = embedding_dim
        self.scale = scale

        super().__init__(
            CoordinateEncoder(num_positional_features=embedding_dim // 2, scale=scale, device=device, dtype=dtype),
            fl.Lambda(func=self.pad),
            fl.Residual(
                fl.UseContext(context="point_encoder", key="type_mask"),
                PointTypeEmbedding(embedding_dim=embedding_dim, device=device, dtype=dtype),
            ),
        )

    def pad(self, x: Tensor) -> Tensor:
        type_mask: Tensor = self.use_context("point_encoder")["type_mask"]
        if torch.any((type_mask == PointType.BOX_TOP_LEFT.value) | (type_mask == PointType.BOX_BOTTOM_RIGHT.value)):
            # Some boxes have been passed: no need to pad in this case
            return x
        type_mask = torch.cat(
            [type_mask, torch.full((type_mask.shape[0], 1), PointType.NOT_A_POINT.value, device=type_mask.device)],
            dim=1,
        )
        self.set_context(context="point_encoder", value={"type_mask": type_mask})
        return torch.cat([x, torch.zeros((x.shape[0], 1, x.shape[-1]), device=x.device)], dim=1)

    def init_context(self) -> Contexts:
        return {
            "point_encoder": {
                "type_mask": None,
            }
        }

    def set_type_mask(self, type_mask: Int[Tensor, "1 num_points"]) -> None:
        self.set_context(context="point_encoder", value={"type_mask": type_mask})

    def get_dense_positional_embedding(
        self, image_embedding_size: tuple[int, int]
    ) -> Float[Tensor, "num_positional_features height width"]:
        coordinate_encoder = self.ensure_find(layer_type=CoordinateEncoder)
        height, width = image_embedding_size
        grid = torch.ones((height, width), device=self.device, dtype=torch.float32)
        y_embedding = grid.cumsum(dim=0) - 0.5
        x_embedding = grid.cumsum(dim=1) - 0.5
        y_embedding = y_embedding / height
        x_embedding = x_embedding / width
        positional_embedding = (
            coordinate_encoder(torch.stack(tensors=[x_embedding, y_embedding], dim=-1))
            .permute(2, 0, 1)
            .unsqueeze(dim=0)
        )
        return positional_embedding

    def points_to_tensor(
        self,
        foreground_points: Sequence[tuple[float, float]] | None = None,
        background_points: Sequence[tuple[float, float]] | None = None,
        not_a_points: Sequence[tuple[float, float]] | None = None,
        box_points: Sequence[Sequence[tuple[float, float]]] | None = None,
    ) -> tuple[Float[Tensor, "1 num_points 2"], Int[Tensor, "1 num_points"]]:
        foreground_points = foreground_points or []
        background_points = background_points or []
        not_a_points = not_a_points or []
        box_points = box_points or []
        top_left_points = [box[0] for box in box_points]
        bottom_right_points = [box[1] for box in box_points]
        coordinates: list[Tensor] = []
        type_ids: list[Tensor] = []

        # Must be in sync with PointType enum
        for type_id, coords_seq in zip(
            PointType, [background_points, foreground_points, top_left_points, bottom_right_points, not_a_points]
        ):
            if len(coords_seq) > 0:
                coords_tensor = torch.tensor(data=list(coords_seq), dtype=torch.float, device=self.device)
                coordinates.append(coords_tensor)
                point_ids = torch.tensor(data=[type_id.value] * len(coords_seq), dtype=torch.int, device=self.device)
                type_ids.append(point_ids)

        all_coordinates = torch.cat(tensors=coordinates, dim=0).unsqueeze(dim=0)
        type_mask = torch.cat(tensors=type_ids, dim=0).unsqueeze(dim=0)

        return all_coordinates, type_mask


class MaskEncoder(fl.Chain):
    def __init__(
        self,
        embedding_dim: int = 256,
        intermediate_channels: int = 16,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.intermediate_channels = intermediate_channels
        super().__init__(
            fl.Conv2d(
                in_channels=1,
                out_channels=self.intermediate_channels // 4,
                kernel_size=2,
                stride=2,
                device=device,
                dtype=dtype,
            ),
            fl.LayerNorm2d(channels=self.intermediate_channels // 4, device=device, dtype=dtype),
            fl.GeLU(),
            fl.Conv2d(
                in_channels=self.intermediate_channels // 4,
                out_channels=self.intermediate_channels,
                kernel_size=2,
                stride=2,
                device=device,
                dtype=dtype,
            ),
            fl.LayerNorm2d(channels=self.intermediate_channels, device=device, dtype=dtype),
            fl.GeLU(),
            fl.Conv2d(
                in_channels=self.intermediate_channels,
                out_channels=self.embedding_dim,
                kernel_size=1,
                device=device,
                dtype=dtype,
            ),
        )
        self.register_parameter(
            "no_mask_embedding", nn.Parameter(torch.randn(1, embedding_dim, device=device, dtype=dtype))
        )

    def get_no_mask_dense_embedding(
        self, image_embedding_size: tuple[int, int], batch_size: int = 1
    ) -> Float[Tensor, "batch embedding_dim image_embedding_height image_embedding_width"]:
        return self.no_mask_embedding.reshape(1, -1, 1, 1).expand(
            batch_size, -1, image_embedding_size[0], image_embedding_size[1]
        )
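For intuition, how prompt points become embeddings (a sketch; in the full model the coordinates are rescaled by SegmentAnything.normalize before this call, so here they are divided by 1024 by hand):

import torch

encoder = PointEncoder()  # embedding_dim=256
coordinates, type_mask = encoder.points_to_tensor(
    foreground_points=[(100.0, 200.0)],
    background_points=[(5.0, 5.0)],
)
print(tuple(coordinates.shape), tuple(type_mask.shape))  # (1, 2, 2) (1, 2)

encoder.set_type_mask(type_mask=type_mask)
point_embedding = encoder(coordinates / 1024)  # expects coordinates normalized to [0, 1]
print(tuple(point_embedding.shape))  # (1, 3, 256): a "not a point" pad entry was appended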
@ -0,0 +1,158 @@
|
|||||||
|
from torch import device as Device, dtype as DType
|
||||||
|
|
||||||
|
import imaginairy.vendored.refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttention(fl.Attention):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
cross_embedding_dim: int | None = None,
|
||||||
|
num_heads: int = 1,
|
||||||
|
inner_dim: int | None = None,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
    ) -> None:
        super().__init__(
            embedding_dim=embedding_dim,
            key_embedding_dim=cross_embedding_dim,
            num_heads=num_heads,
            inner_dim=inner_dim,
            is_optimized=False,
            device=device,
            dtype=dtype,
        )
        self.cross_embedding_dim = cross_embedding_dim or embedding_dim
        self.insert(index=0, module=fl.Parallel(fl.GetArg(index=0), fl.GetArg(index=1), fl.GetArg(index=1)))


class FeedForward(fl.Residual):
    def __init__(
        self, embedding_dim: int, feed_forward_dim: int, device: Device | str | None = None, dtype: DType | None = None
    ) -> None:
        self.embedding_dim = embedding_dim
        self.feed_forward_dim = feed_forward_dim
        super().__init__(
            fl.Linear(in_features=embedding_dim, out_features=feed_forward_dim, device=device, dtype=dtype),
            fl.ReLU(),
            fl.Linear(in_features=feed_forward_dim, out_features=embedding_dim, device=device, dtype=dtype),
        )


class SparseSelfAttention(fl.Residual):
    def __init__(
        self,
        embedding_dim: int,
        inner_dim: int | None = None,
        num_heads: int = 1,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        add_sparse_embedding = fl.Residual(fl.UseContext(context="mask_decoder", key="sparse_embedding"))
        super().__init__(
            fl.Parallel(add_sparse_embedding, add_sparse_embedding, fl.Identity()),
            fl.Attention(
                embedding_dim=embedding_dim,
                inner_dim=inner_dim,
                num_heads=num_heads,
                is_optimized=False,
                device=device,
                dtype=dtype,
            ),
        )


class SparseCrossDenseAttention(fl.Residual):
    def __init__(
        self, embedding_dim: int, num_heads: int = 8, device: Device | str | None = None, dtype: DType | None = None
    ) -> None:
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        super().__init__(
            fl.Parallel(
                fl.Residual(
                    fl.UseContext(context="mask_decoder", key="sparse_embedding"),
                ),
                fl.Sum(
                    fl.UseContext(context="mask_decoder", key="dense_embedding"),
                    fl.UseContext(context="mask_decoder", key="dense_positional_embedding"),
                ),
                fl.UseContext(context="mask_decoder", key="dense_embedding"),
            ),
            fl.Attention(
                embedding_dim=embedding_dim,
                inner_dim=embedding_dim // 2,
                num_heads=num_heads,
                is_optimized=False,
                device=device,
                dtype=dtype,
            ),
        )


class DenseCrossSparseAttention(fl.Chain):
    def __init__(
        self, embedding_dim: int, num_heads: int = 8, device: Device | str | None = None, dtype: DType | None = None
    ) -> None:
        super().__init__(
            fl.Parallel(
                fl.Sum(
                    fl.UseContext(context="mask_decoder", key="dense_embedding"),
                    fl.UseContext(context="mask_decoder", key="dense_positional_embedding"),
                ),
                fl.Residual(
                    fl.UseContext(context="mask_decoder", key="sparse_embedding"),
                ),
                fl.Identity(),
            ),
            fl.Attention(
                embedding_dim=embedding_dim,
                inner_dim=embedding_dim // 2,
                num_heads=num_heads,
                is_optimized=False,
                device=device,
                dtype=dtype,
            ),
        )


class TwoWayTranformerLayer(fl.Chain):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int = 8,
        feed_forward_dim: int = 2048,
        use_residual_self_attention: bool = True,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.feed_forward_dim = feed_forward_dim

        self_attention = (
            SparseSelfAttention(embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype)
            if use_residual_self_attention
            else fl.SelfAttention(
                embedding_dim=embedding_dim, num_heads=num_heads, is_optimized=False, device=device, dtype=dtype
            )
        )

        super().__init__(
            self_attention,
            fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
            SparseCrossDenseAttention(embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype),
            fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
            FeedForward(embedding_dim=embedding_dim, feed_forward_dim=feed_forward_dim, device=device, dtype=dtype),
            fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
            fl.Passthrough(
                fl.Sum(
                    fl.UseContext(context="mask_decoder", key="dense_embedding"),
                    DenseCrossSparseAttention(
                        embedding_dim=embedding_dim, num_heads=num_heads, device=device, dtype=dtype
                    ),
                ),
                fl.LayerNorm(normalized_shape=embedding_dim, device=device, dtype=dtype),
                fl.SetContext(context="mask_decoder", key="dense_embedding"),
            ),
        )
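The layer above reads the image-side tensors from a named "mask_decoder" context rather than taking them as call arguments. A minimal standalone sketch follows (not part of this commit): the vendored module path and the Chain.set_context call are assumptions based on refiners' fluxion API, and the tensor shapes are illustrative.

# Sketch: exercising the vendored layer directly (assumed module path and context API).
import torch

from imaginairy.vendored.refiners.foundationals.segment_anything.transformer import (
    TwoWayTranformerLayer,  # class name spelled as in the vendored source
)

layer = TwoWayTranformerLayer(embedding_dim=256, num_heads=8)
sparse = torch.randn(1, 8, 256)  # prompt/point tokens
dense = torch.randn(1, 64 * 64, 256)  # flattened image embedding

# The UseContext/SetContext layers expect these keys in the "mask_decoder" context.
layer.set_context(
    "mask_decoder",
    {
        "sparse_embedding": sparse,
        "dense_embedding": dense,
        "dense_positional_embedding": torch.randn(1, 64 * 64, 256),
    },
)
updated_sparse = layer(sparse)  # the dense side is updated in-context via the Passthrough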
0	imaginairy/vendored/refiners/py.typed	Normal file
1	imaginairy/vendored/refiners/readme.txt	Normal file
@ -0,0 +1 @@
+vendored from git@github.com:finegrain-ai/refiners.git @ 20c229903f53d05dc1c44659ec97603660ef964c
requirements-dev.txt
@ -94,7 +94,7 @@ importlib-metadata==7.0.1
 iniconfig==2.0.0
     # via pytest
 jaxtyping==0.2.25
-    # via refiners
+    # via imaginAIry (setup.py)
 jinja2==3.1.2
     # via torch
 kiwisolver==1.4.5
@ -133,7 +133,6 @@ numpy==1.24.4
     # matplotlib
     # numba
     # opencv-python
-    # refiners
     # scipy
     # torchvision
     # transformers
@ -160,7 +159,6 @@ pillow==10.2.0
     # imageio
     # imaginAIry (setup.py)
     # matplotlib
-    # refiners
     # torchvision
 pluggy==1.3.0
     # via pytest
@ -199,8 +197,6 @@ pyyaml==6.0.1
     # responses
     # timm
     # transformers
-refiners==0.2.0
-    # via imaginAIry (setup.py)
 regex==2023.12.25
     # via
     # diffusers
@ -216,13 +212,12 @@ requests==2.31.0
     # transformers
 responses==0.24.1
     # via -r requirements-dev.in
-ruff==0.1.9
+ruff==0.1.11
     # via -r requirements-dev.in
-safetensors==0.3.3
+safetensors==0.4.1
     # via
     # diffusers
     # imaginAIry (setup.py)
-    # refiners
     # timm
     # transformers
 scipy==1.10.1
@ -262,7 +257,6 @@ torch==2.1.2
     # imaginAIry (setup.py)
     # kornia
     # open-clip-torch
-    # refiners
     # timm
     # torchdiffeq
     # torchvision
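To sanity-check the bumped pins in an installed environment, a small script works (a sketch; packaging is assumed importable, as it usually arrives alongside pip/setuptools):

# Sketch: verify the versions this commit pins (ruff 0.1.11, safetensors 0.4.1).
from importlib.metadata import version

from packaging.version import Version  # assumed available in the env

assert Version(version("ruff")) >= Version("0.1.11")
assert Version(version("safetensors")) >= Version("0.4.1")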
5	setup.py
@ -95,9 +95,10 @@ setup(
         # need to migration to 2.0
         "pydantic>=2.3.0",
         "requests>=2.28.1",
-        "refiners>=0.2.0",
+        # "refiners>=0.2.0",
+        "jaxtyping>=0.2.23",  # refiners dependency
         "einops>=0.3.0",
-        "safetensors>=0.2.1",
+        "safetensors>=0.4.0",
         # scipy is a sub dependency but v1.11 doesn't support python 3.8. https://docs.scipy.org/doc/scipy/dev/toolchain.html#numpy
         "scipy<1.11",
         "timm>=0.4.12,!=0.9.0,!=0.9.1", # for vendored blip
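With refiners commented out of install_requires, the vendored copy must be importable on its own. A quick check, as a sketch (module paths assumed from the vendoring layout in this commit):

# Sketch: confirm the vendored package works without the PyPI refiners distribution.
import importlib.util

assert importlib.util.find_spec("refiners") is None  # optional: prove PyPI refiners is absent
import imaginairy.vendored.refiners.fluxion.layers as fl  # noqa: E402

print(fl.Chain)  # the vendored fluxion API is reachable under imaginairy.vendored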