refactor: cleanup ddim

- delete more unused code
- fix some lints
pull/9/head
Bryce 2 years ago
parent 6307a0daf5
commit d7cbf6e416

@@ -263,7 +263,7 @@ def imagine(
upscaled_img = None
is_nsfw_img = None
if IMAGINAIRY_SAFETY_MODE != SafetyMode.DISABLED:
is_nsfw_img = is_nsfw(img, x_sample, half_mode=half_mode)
is_nsfw_img = is_nsfw(img, x_sample)
if is_nsfw_img and IMAGINAIRY_SAFETY_MODE == SafetyMode.FILTER:
logger.info(" ⚠️ Filtering NSFW image")
img = img.filter(ImageFilter.GaussianBlur(radius=40))

@@ -72,7 +72,7 @@ def enhance_faces(img, fidelity=0):
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
except Exception as error:
except Exception as error: # noqa
logger.error(f"\tFailed inference for CodeFormer: {error}")
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

@@ -157,78 +157,6 @@ class VectorQuantizer(nn.Module):
return z_q
class VQModel(pl.LightningModule):
def __init__(
self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
):
super().__init__()
self.embed_dim = embed_dim
self.n_embed = n_embed
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(
n_embed,
embed_dim,
beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape,
)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
logger.info(
f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}."
)
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, h, force_not_quantize=False):
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
else:
quant = h
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
class AutoencoderKL(pl.LightningModule):
def __init__(
self,

@@ -16,7 +16,6 @@ from einops import rearrange
from torchvision.utils import make_grid
from tqdm import tqdm
from imaginairy.modules.autoencoder import VQModelInterface
from imaginairy.modules.diffusion.util import make_beta_schedule, noise_like
from imaginairy.modules.distributions import DiagonalGaussianDistribution
from imaginairy.utils import instantiate_from_config, log_params
@@ -570,69 +569,7 @@ class LatentDiffusion(DDPM):
z = 1.0 / self.scale_factor * z
if hasattr(self, "split_input_params"):
if self.split_input_params["patch_distributed_vq"]:
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
uf = self.split_input_params["vqf"]
bs, nc, h, w = z.shape
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
logger.info("reducing Kernel")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
logger.info("reducing stride")
fold, unfold, normalization, weighting = self.get_fold_unfold(
z, ks, stride, uf=uf
)
z = unfold(z) # (bn, nc * prod(**ks), L)
# 1. Reshape to img shape
z = z.view(
(z.shape[0], -1, ks[0], ks[1], z.shape[-1])
) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim
if isinstance(self.first_stage_model, VQModelInterface):
output_list = [
self.first_stage_model.decode(
z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize,
)
for i in range(z.shape[-1])
]
else:
output_list = [
self.first_stage_model.decode(z[:, :, :, :, i])
for i in range(z.shape[-1])
]
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
o = o * weighting
# Reverse 1. reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization # norm is shape (1, 1, h, w)
return decoded
else:
if isinstance(self.first_stage_model, VQModelInterface):
return self.first_stage_model.decode(
z, force_not_quantize=predict_cids or force_not_quantize
)
else:
return self.first_stage_model.decode(z)
else:
if isinstance(self.first_stage_model, VQModelInterface):
return self.first_stage_model.decode(
z, force_not_quantize=predict_cids or force_not_quantize
)
else:
return self.first_stage_model.decode(z)
return self.first_stage_model.decode(z)
@torch.no_grad()
def encode_first_stage(self, x):
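
For context: once the patch-wise split_input_params path and the VQModelInterface special cases are deleted, decode_first_stage collapses to a rescale plus a single decode call. A minimal sketch of the resulting shape (the signature is assumed from the flags the removed branches referenced; it is not shown in this hunk):

    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):  # assumed signature
        z = 1.0 / self.scale_factor * z
        return self.first_stage_model.decode(z)
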
@@ -770,7 +707,9 @@ class LatentDiffusion(DDPM):
# tokenize crop coordinates for the bounding boxes of the respective patches
patch_limits_tknzd = [
torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(
torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[
None
].to( # noqa
self.device
)
for bbox in patch_limits

@@ -3,15 +3,14 @@ from abc import abstractmethod
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from imaginairy.modules.attention import SpatialTransformer
from imaginairy.modules.diffusion.util import (
avg_pool_nd,
checkpoint,
conv_nd,
linear,
normalization,
timestep_embedding,
zero_module,
@@ -19,15 +18,14 @@ from imaginairy.modules.diffusion.util import (
# dummy replace
def convert_module_to_f16(x):
def convert_module_to_f16(_):
pass
def convert_module_to_f32(x):
def convert_module_to_f32(_):
pass
## go
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
@@ -123,7 +121,7 @@ class Upsample(nn.Module):
class TransposedUpsample(nn.Module):
"Learned 2x upsampling without padding"
"""Learned 2x upsampling without padding"""
def __init__(self, channels, out_channels=None, ks=5):
super().__init__()
@@ -229,7 +227,7 @@ class ResBlock(TimestepBlock):
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
nn.Linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
@@ -492,7 +490,7 @@ class UNetModel(nn.Module):
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
if isinstance(context_dim, ListConfig):
context_dim = list(context_dim)
if num_heads_upsample == -1:
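
Swapping type(context_dim) == ListConfig for isinstance is the usual fix for the unidiomatic-typecheck lint: isinstance also accepts subclasses, while an exact type comparison does not. A tiny illustrative snippet (names are hypothetical):

    class Base: ...
    class Child(Base): ...

    obj = Child()
    print(type(obj) == Base)      # False: exact-type comparison ignores inheritance
    print(isinstance(obj, Base))  # True: matches subclasses too
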
@@ -527,9 +525,9 @@ class UNetModel(nn.Module):
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.Linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
nn.Linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
@@ -786,223 +784,5 @@ class UNetModel(nn.Module):
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)
class EncoderUNetModel(nn.Module):
"""
The half UNet model with attention and timestep embedding.
For usage, see UNet.
"""
def __init__(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
pool="adaptive",
*args,
**kwargs,
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
self.pool = pool
if pool == "adaptive":
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
nn.AdaptiveAvgPool2d((1, 1)),
zero_module(conv_nd(dims, ch, out_channels, 1)),
nn.Flatten(),
)
elif pool == "attention":
assert num_head_channels != -1
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
AttentionPool2d(
(image_size // ds), ch, num_head_channels, out_channels
),
)
elif pool == "spatial":
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
nn.ReLU(),
nn.Linear(2048, self.out_channels),
)
elif pool == "spatial_v2":
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
normalization(2048),
nn.SiLU(),
nn.Linear(2048, self.out_channels),
)
else:
raise NotImplementedError(f"Unexpected {pool} pooling")
def convert_to_fp16(self):
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
def convert_to_fp32(self):
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
def forward(self, x, timesteps):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
results = []
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
if self.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb)
if self.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = th.cat(results, axis=-1)
return self.out(h)
else:
h = h.type(x.dtype)
return self.out(h)
return self.out(h)

@@ -14,7 +14,7 @@ import math
import numpy as np
import torch
import torch.nn as nn
from einops import repeat
from einops import repeat as e_repeat
from imaginairy.utils import instantiate_from_config
@@ -207,7 +207,7 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(timesteps, "b -> b d", d=dim)
embedding = e_repeat(timesteps, "b -> b d", d=dim)
return embedding
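
Only the tail of timestep_embedding appears in this hunk. For orientation, a sketch of the full sinusoidal embedding as it exists in the upstream latent-diffusion/guided-diffusion code this module derives from (a reference sketch under that assumption, not necessarily the exact file contents):

    def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
        if not repeat_only:
            half = dim // 2
            freqs = torch.exp(
                -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
            ).to(device=timesteps.device)
            args = timesteps[:, None].float() * freqs[None]
            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
            if dim % 2:  # pad odd dims with a zero column
                embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        else:
            embedding = e_repeat(timesteps, "b -> b d", d=dim)
        return embedding
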
@@ -269,22 +269,13 @@ def conv_nd(dims, *args, **kwargs):
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
"""Create a 1D, 2D, or 3D average pooling module."""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
if dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
if dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
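
conv_nd and avg_pool_nd stay because they actually dispatch on dims; the deleted linear() wrapper simply forwarded to nn.Linear, matching the nn.Linear swaps in openaimodel.py earlier in this commit. A hypothetical usage sketch of the surviving helpers:

    conv = conv_nd(2, 320, 320, 3, padding=1)  # returns nn.Conv2d(320, 320, 3, padding=1)
    pool = avg_pool_nd(2, kernel_size=2)       # returns nn.AvgPool2d(kernel_size=2)
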

@@ -2,26 +2,7 @@ import numpy as np
import torch
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
class DiagonalGaussianDistribution:
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
@@ -70,33 +51,3 @@ class DiagonalGaussianDistribution(object):
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
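
The deleted normal_kl helper is the standard closed-form KL divergence between two diagonal Gaussians; with logvar = log(sigma^2), the removed expression is equivalent to

    KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) )
        = 0.5 * ( log(sigma2^2 / sigma1^2) + sigma1^2 / sigma2^2 + (mu1 - mu2)^2 / sigma2^2 - 1 )
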

@@ -14,7 +14,7 @@ def safety_models():
return safety_feature_extractor, safety_checker
def is_nsfw(img, x_sample, half_mode=False):
def is_nsfw(img, x_sample):
safety_feature_extractor, safety_checker = safety_models()
safety_checker_input = safety_feature_extractor([img], return_tensors="pt")
clip_input = safety_checker_input.pixel_values

@@ -56,7 +56,7 @@ def experiment_step_repeats():
sampler.make_schedule(1000)
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
init_image, w, h = pillow_img_to_torch_image(
init_image, _, h = pillow_img_to_torch_image(
img,
max_height=512,
max_width=512,
@@ -64,7 +64,7 @@ def experiment_step_repeats():
init_image = init_image.to(get_device())
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))
log_latent(init_latent, "init_latent")
noise = torch.randn_like(init_latent)
base_count = 1
neutral_embedding = embedder.encode([""])
outdir = f"{TESTS_FOLDER}/test_output"

@@ -12,7 +12,7 @@ skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads
linters = pylint,pycodestyle,pydocstyle,pyflakes,mypy
ignore =
Z999,C0103,C0301,C0114,C0115,C0116,
Z999,D100,D101,D102,D103,D105,D107,D202,D203,D205,D212,D400,D401,D415,
Z999,D100,D101,D102,D103,D105,D107,D200,D202,D203,D205,D212,D400,D401,D415,
Z999,E501,E1101,
Z999,R0901,R0902,R0903,R0193,R0912,R0913,R0914,R0915,
Z999,W0221,W0511,W1203
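
The functional change here is adding D200 to the pydocstyle ignore list; D200 reports a one-line docstring that is spread across multiple lines, the same pattern collapsed in the avg_pool_nd hunk above. Illustrative only:

    def avg_pool_nd(dims, *args, **kwargs):
        """
        Create a 1D, 2D, or 3D average pooling module.
        """  # pydocstyle D200: a one-line docstring should fit on one line
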
