Mirror of https://github.com/brycedrennan/imaginAIry, synced 2024-11-05 12:00:15 +00:00
feature: xformers support
add more upscaling code (that doesn't yet work)
This commit is contained in: parent 8e9e119052, commit 4610d7f01d
@@ -231,6 +231,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
 ## ChangeLog
 
+- feature: xformers support
+
 **6.1.0**
 
 - feature: use different default steps and image sizes depending on sampler and model selected
 - fix: #110 use proper version in image metadata
@@ -49,6 +49,12 @@ MODEL_CONFIGS = [
         weights_url="https://huggingface.co/stabilityai/stable-diffusion-2/resolve/main/768-v-ema.ckpt",
         default_image_size=768,
     ),
+    ModelConfig(
+        short_name="SD-2.0-upscale",
+        config_path="configs/stable-diffusion-v2-upscaling.yaml",
+        weights_url="https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/resolve/main/x4-upscaler-ema.ckpt",
+        default_image_size=512,
+    ),
 ]
 
 MODEL_CONFIG_SHORTCUTS = {m.short_name: m for m in MODEL_CONFIGS}
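For orientation (not part of the diff): the new entry becomes reachable through the same shortcut lookup as every other model. A minimal sketch, assuming the dict lives at imaginairy.config as the hunk above suggests:

    from imaginairy.config import MODEL_CONFIG_SHORTCUTS

    upscale = MODEL_CONFIG_SHORTCUTS["SD-2.0-upscale"]
    print(upscale.config_path)         # configs/stable-diffusion-v2-upscaling.yaml
    print(upscale.default_image_size)  # 512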
@@ -1,6 +1,6 @@
 model:
   base_learning_rate: 1.0e-04
-  target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
+  target: imaginairy.modules.diffusion.ddpm.LatentUpscaleDiffusion
   params:
     parameterization: "v"
     low_scale_key: "lr"
@@ -171,6 +171,7 @@ class CrossAttention(nn.Module):
         # mask = _global_mask_hack.to(torch.bool)
 
         if get_device() == "cuda" or "mps" in get_device():
+            if not XFORMERS_IS_AVAILBLE:
                 return self.forward_splitmem(x, context=context, mask=mask)
 
         h = self.heads
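For context, a minimal sketch of the kind of xformers fast path this flag guards — not necessarily the exact call this commit adds, and it assumes q, k, v are already projected and reshaped to (batch * heads, seq_len, dim_head):

    import torch
    import xformers.ops

    def attention_xformers(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # memory-efficient attention never materializes the full
        # (seq_len x seq_len) attention matrix in VRAM
        return xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

When xformers is not installed, the guard above falls back to forward_splitmem, the chunked-attention path.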
@@ -7,7 +7,7 @@ https://github.com/CompVis/taming-transformers
 """
 import itertools
 import logging
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import partial
 
 import numpy as np
@@ -39,6 +39,18 @@ def disabled_train(self):
     return self
 
 
+def ismap(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
 def uniform_on_device(r1, r2, shape, device):
     return (r1 - r2) * torch.rand(*shape, device=device) + r2
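A quick illustration of what the two new helpers distinguish (not part of the diff; assumes ismap and isimage are imported from the module above): both expect 4-D BCHW tensors, isimage accepts 1- or 3-channel tensors, and ismap matches anything wider than 3 channels:

    import torch

    rgb = torch.zeros(2, 3, 64, 64)   # a batch of RGB images
    seg = torch.zeros(2, 16, 64, 64)  # e.g. a 16-channel conditioning map
    assert isimage(rgb) and not ismap(rgb)
    assert ismap(seg) and not isimage(seg)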
@@ -1324,6 +1336,188 @@ class DiffusionWrapper(pl.LightningModule):
         return out
 
 
+class LatentFinetuneDiffusion(LatentDiffusion):
+    """
+    Basis for different finetunes, such as inpainting or depth2image.
+
+    To disable finetuning mode, set finetune_keys to None
+    """
+
+    def __init__(
+        self,
+        concat_keys: tuple,
+        finetune_keys=(
+            "model.diffusion_model.input_blocks.0.0.weight",
+            "model_ema.diffusion_modelinput_blocks00weight",
+        ),
+        keep_finetune_dims=4,
+        # if model was trained without concat mode before and we would like to keep these channels
+        c_concat_log_start=None,  # to log reconstruction of c_concat codes
+        c_concat_log_end=None,
+        **kwargs,
+    ):
+        ckpt_path = kwargs.pop("ckpt_path", None)
+        ignore_keys = kwargs.pop("ignore_keys", [])
+        super().__init__(**kwargs)
+        self.finetune_keys = finetune_keys
+        self.concat_keys = concat_keys
+        self.keep_dims = keep_finetune_dims
+        self.c_concat_log_start = c_concat_log_start
+        self.c_concat_log_end = c_concat_log_end
+        if self.finetune_keys is not None:
+            assert ckpt_path is not None, "can only finetune from a given checkpoint"
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys)
+
+    def init_from_ckpt(self, path, ignore_keys=tuple(), only_model=False):
+        sd = torch.load(path, map_location="cpu")
+        if "state_dict" in list(sd.keys()):
+            sd = sd["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print(f"Deleting key {k} from state_dict.")
+                    del sd[k]
+
+            # make it explicit: finetune by including extra input channels
+            if self.finetune_keys is not None and k in self.finetune_keys:
+                new_entry = None
+                for name, param in self.named_parameters():
+                    if name in self.finetune_keys:
+                        print(
+                            f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only"
+                        )
+                        new_entry = torch.zeros_like(param)  # zero init
+                assert (
+                    new_entry is not None
+                ), "did not find matching parameter to modify"
+                new_entry[:, : self.keep_dims, ...] = sd[k]
+                sd[k] = new_entry
+
+        missing, unexpected = (
+            self.load_state_dict(sd, strict=False)
+            if not only_model
+            else self.model.load_state_dict(sd, strict=False)
+        )
+        print(
+            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+        )
+        if len(missing) > 0:
+            print(f"Missing Keys: {missing}")
+        if len(unexpected) > 0:
+            print(f"Unexpected Keys: {unexpected}")
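To make the channel-expansion step above concrete, a standalone sketch with hypothetical shapes: a conv weight trained on 4 latent channels is widened to accept extra concat channels, keeping the pretrained slice and zero-initializing the rest:

    import torch

    old_weight = torch.randn(320, 4, 3, 3)  # pretrained UNet input conv: 4 latent channels
    new_entry = torch.zeros(320, 9, 3, 3)   # finetune target: 4 latent + 5 concat channels
    new_entry[:, :4, ...] = old_weight      # keep_finetune_dims=4: copy originals, rest stays zero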
+
+    @torch.no_grad()
+    def log_images(  # noqa
+        self,
+        batch,
+        N=8,
+        n_row=4,
+        sample=True,
+        ddim_steps=200,
+        ddim_eta=1.0,
+        return_keys=None,
+        quantize_denoised=True,
+        inpaint=True,
+        plot_denoise_rows=False,
+        plot_progressive_rows=True,
+        plot_diffusion_rows=True,
+        unconditional_guidance_scale=1.0,
+        unconditional_guidance_label=None,
+        use_ema_scope=True,
+        **kwargs,
+    ):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = {}
+        z, c, x, xrec, xc = self.get_input(
+            batch, self.first_stage_key, bs=N, return_first_stage_outputs=True
+        )
+        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                # xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["class_label", "cls"]:
+                # xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                log["conditioning"] = xc
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
+            log["c_concat_decoded"] = self.decode_first_stage(
+                c_cat[:, self.c_concat_log_start : self.c_concat_log_end]  # noqa
+            )
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = []
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
+            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(
+                    cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                    batch_size=N,
+                    ddim=use_ddim,
+                    ddim_steps=ddim_steps,
+                    eta=ddim_eta,
+                )
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_cross = self.get_unconditional_conditioning(
+                N, unconditional_guidance_label
+            )
+            uc_cat = c_cat
+            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(
+                    cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                    batch_size=N,
+                    ddim=use_ddim,
+                    ddim_steps=ddim_steps,
+                    eta=ddim_eta,
+                    unconditional_guidance_scale=unconditional_guidance_scale,
+                    unconditional_conditioning=uc_full,
+                )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        return log
+
+
 class LatentInpaintDiffusion(LatentDiffusion):
     def __init__(  # noqa
         self,
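One small pattern in log_images worth noting: nullcontext (added to the imports earlier in this commit) accepts and ignores a single positional argument, so `ema_scope("Sampling")` works unchanged whether it is the real EMA scope or the no-op stand-in. A minimal sketch:

    from contextlib import nullcontext

    ema_scope = nullcontext  # stand-in for: self.ema_scope if use_ema_scope else nullcontext
    with ema_scope("Sampling"):
        pass  # runs identically with or without EMA weights swapped in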
@@ -1377,3 +1571,152 @@ class LatentInpaintDiffusion(LatentDiffusion):
         if return_first_stage_outputs:
             return z, all_conds, x, xrec, xc
         return z, all_conds
+
+
+class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
+    """
+    condition on monocular depth estimation
+    """
+
+    def __init__(self, depth_stage_config, concat_keys=("midas_in",), **kwargs):
+        super().__init__(concat_keys=concat_keys, **kwargs)
+        self.depth_model = instantiate_from_config(depth_stage_config)
+        self.depth_stage_key = concat_keys[0]
+
+    @torch.no_grad()
+    def get_input(
+        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+    ):
+        # note: restricted to non-trainable encoders currently
+        assert (
+            not self.cond_stage_trainable
+        ), "trainable cond stages not yet supported for depth2img"
+        z, c, x, xrec, xc = super().get_input(
+            batch,
+            self.first_stage_key,
+            return_first_stage_outputs=True,
+            force_c_encode=True,
+            return_original_cond=True,
+            bs=bs,
+        )
+
+        assert self.concat_keys is not None
+        assert len(self.concat_keys) == 1
+        c_cat = []
+        for ck in self.concat_keys:
+            cc = batch[ck]
+            if bs is not None:
+                cc = cc[:bs]
+            cc = cc.to(self.device)
+            cc = self.depth_model(cc)
+            cc = torch.nn.functional.interpolate(
+                cc,
+                size=z.shape[2:],
+                mode="bicubic",
+                align_corners=False,
+            )
+
+            depth_min = torch.amin(cc, dim=[1, 2, 3], keepdim=True)
+            depth_max = torch.amax(cc, dim=[1, 2, 3], keepdim=True)
+            cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super().log_images(*args, **kwargs)
+        depth = self.depth_model(args[0][self.depth_stage_key])
+        depth_min = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
+        depth_max = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
+        log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0
+        return log
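A worked example of the per-sample depth normalization above (not part of the diff): the affine map 2*(d - min)/(max - min + 0.001) - 1 rescales each depth map to roughly [-1, 1], the 0.001 guarding against division by zero on constant maps:

    import torch

    depth = torch.tensor([[0.2, 2.6, 5.0]])  # hypothetical depth values
    d_min, d_max = depth.amin(), depth.amax()
    print(2.0 * (depth - d_min) / (d_max - d_min + 0.001) - 1.0)
    # tensor([[-1.0000, -0.0002,  0.9996]])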
+
+
+class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
+    """
+    condition on low-res image (and optionally on some spatial noise augmentation)
+    """
+
+    def __init__(
+        self,
+        concat_keys=("lr",),
+        reshuffle_patch_size=None,
+        low_scale_config=None,
+        low_scale_key=None,
+        **kwargs,
+    ):
+        super().__init__(concat_keys=concat_keys, **kwargs)
+        self.reshuffle_patch_size = reshuffle_patch_size
+        self.low_scale_model = None
+        if low_scale_config is not None:
+            print("Initializing a low-scale model")
+            assert low_scale_key is not None
+            self.instantiate_low_stage(low_scale_config)
+            self.low_scale_key = low_scale_key
+
+    def instantiate_low_stage(self, config):
+        model = instantiate_from_config(config)
+        self.low_scale_model = model.eval()
+        self.low_scale_model.train = disabled_train
+        for param in self.low_scale_model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def get_input(
+        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
+    ):
+        # note: restricted to non-trainable encoders currently
+        assert (
+            not self.cond_stage_trainable
+        ), "trainable cond stages not yet supported for upscaling-ft"
+        z, c, x, xrec, xc = super().get_input(
+            batch,
+            self.first_stage_key,
+            return_first_stage_outputs=True,
+            force_c_encode=True,
+            return_original_cond=True,
+            bs=bs,
+        )
+
+        assert self.concat_keys is not None
+        assert len(self.concat_keys) == 1
+        # optionally make spatial noise_level here
+        c_cat = []
+        noise_level = None
+        for ck in self.concat_keys:
+            cc = batch[ck]
+            cc = rearrange(cc, "b h w c -> b c h w")
+            if self.reshuffle_patch_size is not None:
+                assert isinstance(self.reshuffle_patch_size, int)
+                cc = rearrange(
+                    cc,
+                    "b c (p1 h) (p2 w) -> b (p1 p2 c) h w",
+                    p1=self.reshuffle_patch_size,
+                    p2=self.reshuffle_patch_size,
+                )
+            if bs is not None:
+                cc = cc[:bs]
+            cc = cc.to(self.device)
+            if self.low_scale_model is not None and ck == self.low_scale_key:
+                cc, noise_level = self.low_scale_model(cc)
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        if noise_level is not None:
+            all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
+        else:
+            all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super().log_images(*args, **kwargs)
+        log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w")
+        return log
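For orientation, the conditioning dict this class produces has three entries when a low-scale model is configured. A runnable sketch with hypothetical shapes:

    import torch

    B = 2
    c_cat = torch.randn(B, 3, 128, 128)         # (optionally noise-augmented) low-res image
    c = torch.randn(B, 77, 1024)                # cross-attention text conditioning
    noise_level = torch.randint(0, 1000, (B,))  # per-sample augmentation level from the low-scale model
    all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}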
imaginairy/modules/diffusion/upscaling.py (new file, 105 lines)
@@ -0,0 +1,105 @@
+from functools import partial
+
+import numpy as np
+import torch
+from torch import nn
+
+from imaginairy.modules.diffusion.util import extract_into_tensor, make_beta_schedule
+
+
+class AbstractLowScaleModel(nn.Module):
+    # for concatenating a downsampled image to the latent representation
+    def __init__(self, noise_schedule_config=None):
+        super().__init__()
+        if noise_schedule_config is not None:
+            self.register_schedule(**noise_schedule_config)
+
+    def register_schedule(
+        self,
+        beta_schedule="linear",
+        timesteps=1000,
+        linear_start=1e-4,
+        linear_end=2e-2,
+        cosine_s=8e-3,
+    ):
+        betas = make_beta_schedule(
+            beta_schedule,
+            timesteps,
+            linear_start=linear_start,
+            linear_end=linear_end,
+            cosine_s=cosine_s,
+        )
+        alphas = 1.0 - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+        (timesteps,) = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert (
+            alphas_cumprod.shape[0] == self.num_timesteps
+        ), "alphas have to be defined for each timestep"
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer("betas", to_torch(betas))
+        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer(
+            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+        )
+        self.register_buffer(
+            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+        )
+        self.register_buffer(
+            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+        )
+        self.register_buffer(
+            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+        )
+
+    def q_sample(self, x_start, t, noise=None):
+        if noise is None:
+            noise = torch.randn_like(x_start)
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+            * noise
+        )
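q_sample implements the standard closed-form DDPM forward process: with \bar\alpha_t = alphas_cumprod[t], a noisy sample at any timestep t is drawn in one step rather than by iterating t times:

    x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)

which is exactly what the sqrt_alphas_cumprod and sqrt_one_minus_alphas_cumprod buffers registered above encode.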
+
+    def forward(self, x):
+        return x, None
+
+    def decode(self, x):
+        return x
+
+
+class SimpleImageConcat(AbstractLowScaleModel):
+    # no noise level conditioning
+    def __init__(self):
+        super().__init__(noise_schedule_config=None)
+        self.max_noise_level = 0
+
+    def forward(self, x):
+        # fix to constant noise level
+        return x, torch.zeros(x.shape[0], device=x.device).long()
+
+
+class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
+    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
+        super().__init__(noise_schedule_config=noise_schedule_config)
+        self.max_noise_level = max_noise_level
+
+    def forward(self, x, noise_level=None):
+        if noise_level is None:
+            noise_level = torch.randint(
+                0, self.max_noise_level, (x.shape[0],), device=x.device
+            ).long()
+        else:
+            assert isinstance(noise_level, torch.Tensor)
+        z = self.q_sample(x, noise_level)
+        return z, noise_level
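A usage sketch of the noise-augmentation model above, with hypothetical config values (the schedule keys mirror the register_schedule defaults): given a batch of low-res conditioning images, it returns a noised copy plus the sampled per-image noise levels, which LatentUpscaleFinetuneDiffusion passes through as c_adm:

    import torch

    aug = ImageConcatWithNoiseAugmentation(
        noise_schedule_config={"timesteps": 1000, "linear_start": 1e-4, "linear_end": 2e-2},
        max_noise_level=350,
    )
    lr_images = torch.randn(4, 3, 128, 128)  # low-res conditioning batch
    z, noise_level = aug(lr_images)          # z has the same shape as lr_images
    assert z.shape == lr_images.shape and noise_level.shape == (4,)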