@@ -4,26 +4,28 @@ from typing import Any, Optional
 
 import numpy as np
 import torch
 import torch.nn as nn
 from einops import rearrange
 from torch import nn
 
 from imaginairy.modules.attention import MemoryEfficientCrossAttention
 
 try:
-    import xformers
-    import xformers.ops
+    import xformers  # noqa
+    import xformers.ops  # noqa
 
-    XFORMERS_IS_AVAILBLE = True
-except:
-    XFORMERS_IS_AVAILBLE = False
+    XFORMERS_IS_AVAILABLE = True
+except ImportError:
+    XFORMERS_IS_AVAILABLE = False
+    # print("No module 'xformers'. Proceeding without it.")
 
 
 def get_timestep_embedding(timesteps, embedding_dim):
     """
+    Build sinusoidal embeddings.
     This matches the implementation in Denoising Diffusion Probabilistic Models:
     From Fairseq.
     Build sinusoidal embeddings.
     This matches the implementation in tensor2tensor, but differs slightly
     from the description in Section 3.5 of "Attention Is All You Need".
     """
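For readers skimming the hunk above: the docstring it touches describes sinusoidal timestep embeddings. Below is a minimal, hedged sketch of that idea, not the file's actual `get_timestep_embedding` body; the exact frequency scaling in the real implementation may differ slightly.

```python
# Hedged sketch of DDPM/tensor2tensor-style sinusoidal timestep embeddings.
# Not copied from the diffed file; frequency-scaling details may differ.
import math

import torch


def sinusoidal_embedding_sketch(timesteps, embedding_dim):
    half_dim = embedding_dim // 2
    # geometric progression of frequencies from 1 down to ~1/10000
    freqs = torch.exp(
        -math.log(10000) * torch.arange(half_dim, dtype=torch.float32) / half_dim
    )
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad when embedding_dim is odd
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb


print(sinusoidal_embedding_sketch(torch.tensor([0, 10, 500]), 64).shape)  # (3, 64)
```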
@@ -271,22 +273,22 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
         "linear",
         "none",
     ], f"attn_type {attn_type} unknown"
-    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+    if XFORMERS_IS_AVAILABLE and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
     # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
         assert attn_kwargs is None
         return AttnBlock(in_channels)
-    elif attn_type == "vanilla-xformers":
+    if attn_type == "vanilla-xformers":
         # print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
         return MemoryEfficientAttnBlock(in_channels)
-    elif type == "memory-efficient-cross-attn":
+    if type == "memory-efficient-cross-attn":
         attn_kwargs["query_dim"] = in_channels
         return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
-    elif attn_type == "none":
+    if attn_type == "none":
         return nn.Identity(in_channels)
-    else:
-        raise NotImplementedError()
+
+    raise NotImplementedError()
 
 
 class Model(nn.Module):
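A short usage sketch of the dispatcher changed above; the import path is an assumption (adjust it to wherever `make_attn` lives in this repo):

```python
# Hedged usage sketch; the module path below is assumed, not confirmed by the diff.
from imaginairy.modules.diffusion.model import make_attn  # assumed location

attn = make_attn(512, attn_type="vanilla")  # AttnBlock, or MemoryEfficientAttnBlock when xformers is importable
noop = make_attn(512, attn_type="none")     # nn.Identity
```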
@@ -599,7 +601,7 @@ class Decoder(nn.Module):
         tanh_out=False,
         use_linear_attn=False,
         attn_type="vanilla",
-        **ignorekwargs,
+        **ignore_kwargs,
     ):
         super().__init__()
         if use_linear_attn:
@@ -677,6 +679,8 @@ class Decoder(nn.Module):
             block_in, out_ch, kernel_size=3, stride=1, padding=1
         )
 
+        self.last_z_shape = None
+
     def forward(self, z):
         # assert z.shape[1:] == self.z_shape[1:]
         self.last_z_shape = z.shape
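The two added lines above pre-declare `last_z_shape` in `__init__` so the attribute exists before `forward` first assigns it. A self-contained sketch of that pattern (not the Decoder itself):

```python
# Hedged sketch of the attribute-predeclaration pattern used in the hunk above.
import torch
from torch import nn


class ShapeRecorder(nn.Module):
    def __init__(self):
        super().__init__()
        self.last_z_shape = None  # declared up front; filled in on each forward pass

    def forward(self, z):
        self.last_z_shape = z.shape
        return z


m = ShapeRecorder()
m(torch.zeros(1, 4, 8, 8))
print(m.last_z_shape)  # torch.Size([1, 4, 8, 8])
```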
@@ -1003,17 +1007,16 @@ class Resize(nn.Module):
                 # f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
                 # )
                 raise NotImplementedError()
-            assert in_channels is not None
-            # no asymmetric padding in torch conv, must do it ourselves
-            self.conv = torch.nn.Conv2d(
-                in_channels, in_channels, kernel_size=4, stride=2, padding=1
-            )
+            # assert in_channels is not None
+            # # no asymmetric padding in torch conv, must do it ourselves
+            # self.conv = torch.nn.Conv2d(
+            #     in_channels, in_channels, kernel_size=4, stride=2, padding=1
+            # )
 
     def forward(self, x, scale_factor=1.0):
         if scale_factor == 1.0:
             return x
-        else:
-            x = torch.nn.functional.interpolate(
-                x, mode=self.mode, align_corners=False, scale_factor=scale_factor
-            )
+        x = torch.nn.functional.interpolate(
+            x, mode=self.mode, align_corners=False, scale_factor=scale_factor
+        )
         return x
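The final hunk flattens `Resize.forward` into an early return plus a single `interpolate` call. A standalone sketch of the same control flow, with `mode="bilinear"` assumed purely for illustration:

```python
# Hedged sketch of the flattened Resize.forward control flow; not the class itself.
import torch


def resize_sketch(x, scale_factor=1.0, mode="bilinear"):
    if scale_factor == 1.0:
        return x
    return torch.nn.functional.interpolate(
        x, mode=mode, align_corners=False, scale_factor=scale_factor
    )


print(resize_sketch(torch.zeros(1, 3, 16, 16), scale_factor=0.5).shape)  # (1, 3, 8, 8)
```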