feature: k-diffusion samplers

- improved image logging functionality. can just stick log_latent wherever you want
- improved some variable naming
- moved all the samplers together
- vendored k-diffusion library
This commit is contained in:
Bryce 2022-09-14 00:40:25 -07:00
parent 20ac04d9df
commit b4a3b8c2b3
27 changed files with 2317 additions and 132 deletions

View File

@ -77,15 +77,23 @@ vendor_openai_clip:
echo "vendored from git@github.com:openai/CLIP.git" | tee ./imaginairy/vendored/clip/readme.txt
revendorize:
make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip
#make vendorize REPO=git@github.com:xinntao/Real-ESRGAN.git PKG=realesrgan
make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip COMMIT=d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
make vendorize REPO=git@github.com:crowsonkb/k-diffusion.git PKG=k_diffusion COMMIT=1a0703dfb7d24d8806267c3e7ccc4caf67fd1331
#sed -i'' -e 's/^import\sclip/from\simaginairy.vendored\simport\sclip/g' imaginairy/vendored/k_diffusion/evaluation.py
rm imaginairy/vendored/k_diffusion/evaluation.py
touch imaginairy/vendored/k_diffusion/evaluation.py
rm imaginairy/vendored/k_diffusion/config.py
touch imaginairy/vendored/k_diffusion/config.py
# without this most of the k-diffusion samplers didn't work
sed -i'' -e 's#return (x - denoised) / utils.append_dims(sigma, x.ndim)#return (x - denoised) / sigma#g' imaginairy/vendored/k_diffusion/sampling.py
make af
vendorize: ## vendorize a github repo. `make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip`
mkdir -p ./downloads
-cd ./downloads && git clone $(REPO) $(PKG)
cd ./downloads/$(PKG) && git pull
cd ./downloads/$(PKG) && git fetch && git checkout $(COMMIT)
rm -rf ./imaginairy/vendored/$(PKG)
cp -R ./downloads/$(PKG)/$(PKG) imaginairy/vendored/
git --git-dir ./downloads/$(PKG)/.git rev-parse HEAD | tee ./imaginairy/vendored/$(PKG)/clip-commit-hash.txt

View File

@ -141,7 +141,7 @@ imagine_image_files(prompts, outdir="./my-art")
- ✅ init-image at command line
- prompt expansion
- Image Generation Features
- add k-diffusion sampling methods
- add k-diffusion sampling methods
- upscaling
- ✅ realesrgan
- ldm

View File

@ -16,9 +16,9 @@ from transformers import cached_path
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.modules.diffusion.ddim import DDIMSampler
from imaginairy.modules.diffusion.plms import PLMSSampler
from imaginairy.img_log import LatentLoggingContext, log_latent
from imaginairy.safety import is_nsfw, safety_models
from imaginairy.samplers.base import get_sampler
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import (
fix_torch_nn_layer_norm,
@ -57,7 +57,10 @@ def load_model_from_config(config):
def patch_conv(**patch):
"""https://github.com/replicate/cog-stable-diffusion/compare/main...TomMoore515:material_stable_diffusion:main"""
"""
Patch to enable tiling mode
https://github.com/replicate/cog-stable-diffusion/compare/main...TomMoore515:material_stable_diffusion:main
"""
cls = torch.nn.Conv2d
init = cls.__init__
@ -96,34 +99,26 @@ def imagine_image_files(
os.makedirs(outdir, exist_ok=True)
base_count = len(os.listdir(outdir))
step_count = 0
output_file_extension = output_file_extension.lower()
if output_file_extension not in {"jpg", "png"}:
raise ValueError("Must output a png or jpg")
def _record_steps(samples, description, model, prompt):
nonlocal step_count
step_count += 1
samples = model.decode_first_stage(samples)
samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
def _record_step(img, description, step_count, prompt):
steps_path = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}")
os.makedirs(steps_path, exist_ok=True)
for pred_x0 in samples:
pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
filename = f"{base_count:08}_S{prompt.seed}_step{step_count:04}.jpg"
img = Image.fromarray(pred_x0.astype(np.uint8))
draw = ImageDraw.Draw(img)
draw.text((10, 10), str(description))
img.save(os.path.join(steps_path, filename))
filename = f"{base_count:08}_S{prompt.seed}_step{step_count:04}.jpg"
destination = os.path.join(steps_path, filename)
draw = ImageDraw.Draw(img)
draw.text((10, 10), str(description))
img.save(destination)
img_callback = _record_steps if record_step_images else None
for result in imagine(
prompts,
latent_channels=latent_channels,
downsampling_factor=downsampling_factor,
precision=precision,
ddim_eta=ddim_eta,
img_callback=img_callback,
img_callback=_record_step if record_step_images else None,
tile_mode=tile_mode,
):
prompt = result.prompt
@ -164,6 +159,7 @@ def imagine(
prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
_img_callback = None
step_count = 0
precision_scope = (
autocast
@ -172,108 +168,101 @@ def imagine(
)
with (torch.no_grad(), precision_scope(get_device()), fix_torch_nn_layer_norm()):
for prompt in prompts:
logger.info(f"Generating {prompt.prompt_description()}")
seed_everything(prompt.seed)
with LatentLoggingContext(
prompt=prompt, model=model, img_callback=img_callback
):
logger.info(f"Generating {prompt.prompt_description()}")
seed_everything(prompt.seed)
uc = None
if prompt.prompt_strength != 1.0:
uc = model.get_learned_conditioning(1 * [""])
total_weight = sum(wp.weight for wp in prompt.prompts)
c = sum(
[
model.get_learned_conditioning(wp.text) * (wp.weight / total_weight)
for wp in prompt.prompts
uc = None
if prompt.prompt_strength != 1.0:
uc = model.get_learned_conditioning(1 * [""])
total_weight = sum(wp.weight for wp in prompt.prompts)
c = sum(
[
model.get_learned_conditioning(wp.text)
* (wp.weight / total_weight)
for wp in prompt.prompts
]
)
shape = [
latent_channels,
prompt.height // downsampling_factor,
prompt.width // downsampling_factor,
]
)
def _img_callback(samples, description):
if img_callback:
img_callback(samples, description, model, prompt)
start_code = None
sampler = get_sampler(prompt.sampler_type, model)
if prompt.init_image:
generation_strength = 1 - prompt.init_image_strength
ddim_steps = int(prompt.steps / generation_strength)
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)
shape = [
latent_channels,
prompt.height // downsampling_factor,
prompt.width // downsampling_factor,
]
t_enc = int(generation_strength * ddim_steps)
init_image, w, h = img_path_to_torch_image(prompt.init_image)
init_image = init_image.to(get_device())
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(init_image)
)
log_latent(init_latent, "init_latent")
start_code = None
sampler = get_sampler(prompt.sampler_type, model)
if prompt.init_image:
generation_strength = 1 - prompt.init_image_strength
ddim_steps = int(prompt.steps / generation_strength)
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(
init_latent, torch.tensor([t_enc]).to(get_device())
)
log_latent(z_enc, "z_enc")
t_enc = int(generation_strength * ddim_steps)
init_image, w, h = img_path_to_torch_image(prompt.init_image)
init_image = init_image.to(get_device())
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(init_image)
)
_img_callback(init_latent, "init_latent")
# decode it
samples = sampler.decode(
z_enc,
c,
t_enc,
unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc,
img_callback=_img_callback,
)
else:
# encode (scaled latent)
z_enc = sampler.stochastic_encode(
init_latent, torch.tensor([t_enc]).to(get_device())
)
_img_callback(z_enc, "z_enc")
samples, _ = sampler.sample(
num_steps=prompt.steps,
conditioning=c,
batch_size=1,
shape=shape,
unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc,
eta=ddim_eta,
initial_noise_tensor=start_code,
img_callback=_img_callback,
)
# decode it
samples = sampler.decode(
z_enc,
c,
t_enc,
unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc,
img_callback=_img_callback,
)
else:
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
samples, _ = sampler.sample(
S=prompt.steps,
conditioning=c,
batch_size=1,
shape=shape,
unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code,
img_callback=_img_callback,
)
for x_sample in x_samples:
x_sample = 255.0 * rearrange(
x_sample.cpu().numpy(), "c h w -> h w c"
)
x_sample_8_orig = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample_8_orig)
upscaled_img = None
if not IMAGINAIRY_ALLOW_NSFW and is_nsfw(
img, x_sample, half_mode=half_mode
):
logger.info(" ⚠️ Filtering NSFW image")
img = img.filter(ImageFilter.GaussianBlur(radius=40))
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
if prompt.fix_faces:
logger.info(" Fixing 😊 's in 🖼 using GFPGAN...")
img = enhance_faces(img, fidelity=0.2)
if prompt.upscale:
logger.info(" Upscaling 🖼 using real-ESRGAN...")
upscaled_img = upscale_image(img)
for x_sample in x_samples:
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
x_sample_8_orig = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample_8_orig)
upscaled_img = None
if not IMAGINAIRY_ALLOW_NSFW and is_nsfw(
img, x_sample, half_mode=half_mode
):
logger.info(" ⚠️ Filtering NSFW image")
img = img.filter(ImageFilter.GaussianBlur(radius=40))
if prompt.fix_faces:
logger.info(" Fixing 😊 's in 🖼 using GFPGAN...")
img = enhance_faces(img, fidelity=0.2)
if prompt.upscale:
logger.info(" Upscaling 🖼 using real-ESRGAN...")
upscaled_img = upscale_image(img)
yield ImagineResult(img=img, prompt=prompt, upscaled_img=upscaled_img)
yield ImagineResult(
img=img, prompt=prompt, upscaled_img=upscaled_img
)
def prompt_normalized(prompt):
return re.sub(r"[^a-zA-Z0-9.,]+", "_", prompt)[:130]
DOWNLOADED_FILES_PATH = f"{LIB_PATH}/../downloads/"
def get_sampler(sampler_type, model):
sampler_type = sampler_type.upper()
if sampler_type == "PLMS":
return PLMSSampler(model)
elif sampler_type == "DDIM":
return DDIMSampler(model)

View File

@ -100,8 +100,19 @@ def configure_logging(level="INFO"):
@click.option("--fix-faces-method", default="gfpgan", type=click.Choice(["gfpgan"]))
@click.option(
"--sampler-type",
default="PLMS",
type=click.Choice(["PLMS", "DDIM"]),
default="plms",
type=click.Choice(
[
"plms",
"ddim",
"k_lms",
"k_dpm_2",
"k_dpm_2_a",
"k_euler",
"k_euler_a",
"k_heun",
]
),
help="What sampling strategy to use",
)
@click.option("--ddim-eta", default=0.0, type=float)
@ -154,7 +165,7 @@ def imagine_cmd(
logger.info(
f"🤖🧠 imaginAIry received {len(prompt_texts)} prompt(s) and will repeat them {repeats} times to create {total_image_count} images."
)
if init_image and sampler_type != "DDIM":
if init_image and sampler_type == "DDIM":
sampler_type = "DDIM"
prompts = []

55
imaginairy/img_log.py Normal file
View File

@ -0,0 +1,55 @@
import logging
import numpy as np
import torch
from einops import rearrange
from PIL import Image
_CURRENT_LOGGING_CONTEXT = None
logger = logging.getLogger(__name__)
def log_latent(latents, description):
if _CURRENT_LOGGING_CONTEXT is None:
return
if torch.isnan(latents).any() or torch.isinf(latents).any():
logger.error(
"Inf/NaN values showing in transformer."
+ repr(latents)[:50]
+ " "
+ description[:50]
)
_CURRENT_LOGGING_CONTEXT.log_latents(latents, description)
class LatentLoggingContext:
def __init__(self, prompt, model, img_callback=None):
self.prompt = prompt
self.model = model
self.step_count = 0
self.img_callback = img_callback
def __enter__(self):
global _CURRENT_LOGGING_CONTEXT
_CURRENT_LOGGING_CONTEXT = self
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global _CURRENT_LOGGING_CONTEXT
_CURRENT_LOGGING_CONTEXT = None
def log_latents(self, samples, description):
if not self.img_callback:
return
if samples.shape[1] != 4:
# logger.info(f"Didn't save tensor of shape {samples.shape} for {description}")
return
self.step_count += 1
description = f"{description} - {samples.shape}"
samples = self.model.decode_first_stage(samples)
samples = torch.clamp((samples + 1.0) / 2.0, min=0.0, max=1.0)
for pred_x0 in samples:
pred_x0 = 255.0 * rearrange(pred_x0.cpu().numpy(), "c h w -> h w c")
img = Image.fromarray(pred_x0.astype(np.uint8))
self.img_callback(img, description, self.step_count, self.prompt)

View File

View File

@ -0,0 +1,90 @@
import torch
from torch import nn
from imaginairy.samplers.ddim import DDIMSampler
from imaginairy.samplers.kdiff import KDiffusionSampler
from imaginairy.samplers.plms import PLMSSampler
from imaginairy.utils import get_device
_k_sampler_type_lookup = {
"k_dpm_2": "dpm_2",
"k_dpm_2_a": "dpm_2_ancestral",
"k_euler": "euler",
"k_euler_a": "euler_ancestral",
"k_heun": "heun",
"k_lms": "lms",
}
def get_sampler(sampler_type, model):
sampler_type = sampler_type.lower()
if sampler_type == "plms":
return PLMSSampler(model)
elif sampler_type == "ddim":
return DDIMSampler(model)
elif sampler_type.startswith("k_"):
sampler_type = _k_sampler_type_lookup[sampler_type]
return KDiffusionSampler(model, sampler_type)
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
class DiffusionSampler:
"""
wip
hope to enforce an api upon samplers
"""
def __init__(self, noise_prediction_model, sampler_func, device=get_device()):
self.noise_prediction_model = noise_prediction_model
self.cfg_noise_prediction_model = CFGDenoiser(noise_prediction_model)
self.sampler_func = sampler_func
self.device = device
def sample(
self,
num_steps,
text_conditioning,
batch_size,
shape,
unconditional_guidance_scale,
unconditional_conditioning,
eta,
initial_noise_tensor=None,
img_callback=None,
):
size = (batch_size, *shape)
initial_noise_tensor = (
torch.randn(size, device="cpu").to(get_device())
if initial_noise_tensor is None
else initial_noise_tensor
)
sigmas = self.noise_prediction_model.get_sigmas(num_steps)
x = initial_noise_tensor * sigmas[0]
samples = self.sampler_func(
self.cfg_noise_prediction_model,
x,
sigmas,
extra_args={
"cond": text_conditioning,
"uncond": unconditional_conditioning,
"cond_scale": unconditional_guidance_scale,
},
disable=False,
)
return samples, None

View File

@ -5,6 +5,7 @@ import numpy as np
import torch
from tqdm import tqdm
from imaginairy.img_log import log_latent
from imaginairy.modules.diffusion.util import (
extract_into_tensor,
make_ddim_sampling_parameters,
@ -89,7 +90,7 @@ class DDIMSampler:
@torch.no_grad()
def sample(
self,
S,
num_steps,
batch_size,
shape,
conditioning=None,
@ -124,7 +125,7 @@ class DDIMSampler:
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta)
self.make_schedule(ddim_num_steps=num_steps, ddim_eta=eta)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
@ -178,6 +179,7 @@ class DDIMSampler:
img = torch.randn(shape, device="cpu").to(device)
else:
img = x_T
log_latent(img, "initial noise")
if timesteps is None:
timesteps = (
@ -231,9 +233,9 @@ class DDIMSampler:
)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
img_callback(pred_x0, i)
log_latent(img, "img")
log_latent(pred_x0, "pred_x0")
if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)
@ -378,7 +380,7 @@ class DDIMSampler:
# cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
# x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
## x_dec_alt = x_dec + (original_loss * 0.1) ** 2
if img_callback:
img_callback(x_dec, f"x_dec {i}")
img_callback(pred_x0, f"pred_x0 {i}")
log_latent(x_dec, f"x_dec {i}")
log_latent(pred_x0, f"pred_x0 {i}")
return x_dec

View File

@ -0,0 +1,97 @@
import torch
from torch import nn
from imaginairy.img_log import log_latent
from imaginairy.utils import get_device
from imaginairy.vendored.k_diffusion import sampling as k_sampling
from imaginairy.vendored.k_diffusion.external import CompVisDenoiser
class CFGMaskedDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi):
x_in = x
x_in = torch.cat([x_in] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
denoised = uncond + (cond - uncond) * cond_scale
if mask is not None:
assert x0 is not None
img_orig = x0
mask_inv = 1.0 - mask
denoised = (img_orig * mask_inv) + (mask * denoised)
return denoised
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
class KDiffusionSampler:
def __init__(self, model, sampler_name):
self.model = model
self.cv_denoiser = CompVisDenoiser(model)
# self.cfg_denoiser = CompVisDenoiser(self.cv_denoiser)
self.sampler_name = sampler_name
self.sampler_func = getattr(k_sampling, f"sample_{sampler_name}")
def sample(
self,
num_steps,
conditioning,
batch_size,
shape,
unconditional_guidance_scale,
unconditional_conditioning,
eta,
initial_noise_tensor=None,
img_callback=None,
):
size = (batch_size, *shape)
initial_noise_tensor = (
torch.randn(size, device="cpu").to(get_device())
if initial_noise_tensor is None
else initial_noise_tensor
)
log_latent(initial_noise_tensor, "initial_noise_tensor")
sigmas = self.cv_denoiser.get_sigmas(num_steps)
x = initial_noise_tensor * sigmas[0]
log_latent(x, "initial_sigma_noised_tensor")
model_wrap_cfg = CFGDenoiser(self.cv_denoiser)
def callback(data):
log_latent(data["x"], "x")
log_latent(data["denoised"], "denoised")
samples = self.sampler_func(
model_wrap_cfg,
x,
sigmas,
extra_args={
"cond": conditioning,
"uncond": unconditional_conditioning,
"cond_scale": unconditional_guidance_scale,
},
disable=False,
callback=callback,
)
return samples, None

View File

@ -90,7 +90,7 @@ class PLMSSampler(object):
@torch.no_grad()
def sample(
self,
S,
num_steps,
batch_size,
shape,
conditioning=None,
@ -125,7 +125,7 @@ class PLMSSampler(object):
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta)
self.make_schedule(ddim_num_steps=num_steps, ddim_eta=eta)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
@ -248,8 +248,8 @@ class PLMSSampler(object):
if callback:
callback(i)
if img_callback:
img_callback(img, i)
img_callback(pred_x0, i)
img_callback(img, "img")
img_callback(pred_x0, "pred_x0")
if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)

View File

@ -0,0 +1,12 @@
from . import (
augmentation,
config,
evaluation,
external,
gns,
layers,
models,
sampling,
utils,
)
from .layers import Denoiser

View File

@ -0,0 +1,113 @@
import math
import operator
from functools import reduce
import numpy as np
import torch
from skimage import transform
from torch import nn
def translate2d(tx, ty):
mat = [[1, 0, tx], [0, 1, ty], [0, 0, 1]]
return torch.tensor(mat, dtype=torch.float32)
def scale2d(sx, sy):
mat = [[sx, 0, 0], [0, sy, 0], [0, 0, 1]]
return torch.tensor(mat, dtype=torch.float32)
def rotate2d(theta):
mat = [
[torch.cos(theta), torch.sin(-theta), 0],
[torch.sin(theta), torch.cos(theta), 0],
[0, 0, 1],
]
return torch.tensor(mat, dtype=torch.float32)
class KarrasAugmentationPipeline:
def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1 / 8):
self.a_prob = a_prob
self.a_scale = a_scale
self.a_aniso = a_aniso
self.a_trans = a_trans
def __call__(self, image):
h, w = image.size
mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]
# x-flip
a0 = torch.randint(2, []).float()
mats.append(scale2d(1 - 2 * a0, 1))
# y-flip
do = (torch.rand([]) < self.a_prob).float()
a1 = torch.randint(2, []).float() * do
mats.append(scale2d(1, 1 - 2 * a1))
# scaling
do = (torch.rand([]) < self.a_prob).float()
a2 = torch.randn([]) * do
mats.append(scale2d(self.a_scale**a2, self.a_scale**a2))
# rotation
do = (torch.rand([]) < self.a_prob).float()
a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
mats.append(rotate2d(-a3))
# anisotropy
do = (torch.rand([]) < self.a_prob).float()
a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
a5 = torch.randn([]) * do
mats.append(rotate2d(a4))
mats.append(scale2d(self.a_aniso**a5, self.a_aniso**-a5))
mats.append(rotate2d(-a4))
# translation
do = (torch.rand([]) < self.a_prob).float()
a6 = torch.randn([]) * do
a7 = torch.randn([]) * do
mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))
# form the transformation matrix and conditioning vector
mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
mat = reduce(operator.matmul, mats)
cond = torch.stack(
[a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7]
)
# apply the transformation
image_orig = np.array(image, dtype=np.float32) / 255
if image_orig.ndim == 2:
image_orig = image_orig[..., None]
tf = transform.AffineTransform(mat.numpy())
image = transform.warp(
image_orig,
tf.inverse,
order=3,
mode="reflect",
cval=0.5,
clip=False,
preserve_range=True,
)
image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
return image, image_orig, cond
class KarrasAugmentWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
if aug_cond is None:
aug_cond = input.new_zeros([input.shape[0], 9])
if mapping_cond is None:
mapping_cond = aug_cond
else:
mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)
def set_skip_stages(self, skip_stages):
return self.inner_model.set_skip_stages(skip_stages)
def set_patch_size(self, patch_size):
return self.inner_model.set_patch_size(patch_size)

View File

@ -0,0 +1 @@
1a0703dfb7d24d8806267c3e7ccc4caf67fd1331

View File

@ -0,0 +1,145 @@
import math
import torch
from torch import nn
from . import sampling, utils
class VDenoiser(nn.Module):
"""A v-diffusion-pytorch model wrapper for k-diffusion."""
def __init__(self, inner_model):
super().__init__()
self.inner_model = inner_model
self.sigma_data = 1.0
def get_scalings(self, sigma):
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = -sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
return c_skip, c_out, c_in
def sigma_to_t(self, sigma):
return sigma.atan() / math.pi * 2
def t_to_sigma(self, t):
return (t * math.pi / 2).tan()
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [
utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)
]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.inner_model(
noised_input * c_in, self.sigma_to_t(sigma), **kwargs
)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [
utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)
]
return (
self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out
+ input * c_skip
)
class DiscreteSchedule(nn.Module):
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
levels."""
def __init__(self, sigmas, quantize):
super().__init__()
self.register_buffer("sigmas", sigmas)
self.quantize = quantize
def get_sigmas(self, n=None):
if n is None:
return sampling.append_zero(self.sigmas.flip(0))
t_max = len(self.sigmas) - 1
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
return sampling.append_zero(self.t_to_sigma(t))
def sigma_to_t(self, sigma, quantize=None):
quantize = self.quantize if quantize is None else quantize
dists = torch.abs(sigma - self.sigmas[:, None])
if quantize:
return torch.argmin(dists, dim=0).view(sigma.shape)
low_idx, high_idx = torch.sort(
torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0
)[0]
low, high = self.sigmas[low_idx], self.sigmas[high_idx]
w = (low - sigma) / (low - high)
w = w.clamp(0, 1)
t = (1 - w) * low_idx + w * high_idx
return t.view(sigma.shape)
def t_to_sigma(self, t):
t = t.float()
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
return (1 - w) * self.sigmas[low_idx] + w * self.sigmas[high_idx]
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
noise)."""
def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.0
def get_scalings(self, sigma):
c_out = -sigma
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
return c_out, c_in
def get_eps(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def loss(self, input, noise, sigma, **kwargs):
c_out, c_in = [
utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)
]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
return (eps - noise).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_out, c_in = [
utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)
]
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
return input + eps * c_out
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for OpenAI diffusion models."""
def __init__(
self, model, diffusion, quantize=False, has_learned_sigmas=True, device="cpu"
):
alphas_cumprod = torch.tensor(
diffusion.alphas_cumprod, device=device, dtype=torch.float32
)
super().__init__(model, alphas_cumprod, quantize=quantize)
self.has_learned_sigmas = has_learned_sigmas
def get_eps(self, *args, **kwargs):
model_output = self.inner_model(*args, **kwargs)
if self.has_learned_sigmas:
return model_output.chunk(2, dim=1)[0]
return model_output
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for CompVis diffusion models."""
def __init__(self, model, quantize=False, device="cpu"):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_eps(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs)

View File

@ -0,0 +1,115 @@
import torch
from torch import nn
class DDPGradientStatsHook:
def __init__(self, ddp_module):
try:
ddp_module.register_comm_hook(self, self._hook_fn)
except AttributeError:
raise ValueError(
"DDPGradientStatsHook does not support non-DDP wrapped modules"
)
self._clear_state()
def _clear_state(self):
self.bucket_sq_norms_small_batch = []
self.bucket_sq_norms_large_batch = []
@staticmethod
def _hook_fn(self, bucket):
buf = bucket.buffer()
self.bucket_sq_norms_small_batch.append(buf.pow(2).sum())
fut = torch.distributed.all_reduce(
buf, op=torch.distributed.ReduceOp.AVG, async_op=True
).get_future()
def callback(fut):
buf = fut.value()[0]
self.bucket_sq_norms_large_batch.append(buf.pow(2).sum())
return buf
return fut.then(callback)
def get_stats(self):
sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
self._clear_state()
stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
return stats[0].item(), stats[1].item()
class GradientNoiseScale:
"""Calculates the gradient noise scale (1 / SNR), or critical batch size,
from _An Empirical Model of Large-Batch Training_,
https://arxiv.org/abs/1812.06162).
Args:
beta (float): The decay factor for the exponential moving averages used to
calculate the gradient noise scale.
Default: 0.9998
eps (float): Added for numerical stability.
Default: 1e-8
"""
def __init__(self, beta=0.9998, eps=1e-8):
self.beta = beta
self.eps = eps
self.ema_sq_norm = 0.0
self.ema_var = 0.0
self.beta_cumprod = 1.0
self.gradient_noise_scale = float("nan")
def state_dict(self):
"""Returns the state of the object as a :class:`dict`."""
return dict(self.__dict__.items())
def load_state_dict(self, state_dict):
"""Loads the object's state.
Args:
state_dict (dict): object state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
def update(
self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch
):
"""Updates the state with a new batch's gradient statistics, and returns the
current gradient noise scale.
Args:
sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
per sample gradients.
sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
per sample gradients.
n_small_batch (int): The batch size of the individual microbatch or per sample
gradients (1 if per sample).
n_large_batch (int): The total batch size of the mean of the microbatch or
per sample gradients.
"""
est_sq_norm = (
n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch
) / (n_large_batch - n_small_batch)
est_var = (sq_norm_small_batch - sq_norm_large_batch) / (
1 / n_small_batch - 1 / n_large_batch
)
self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
self.beta_cumprod *= self.beta
self.gradient_noise_scale = max(self.ema_var, self.eps) / max(
self.ema_sq_norm, self.eps
)
return self.gradient_noise_scale
def get_gns(self):
"""Returns the current gradient noise scale."""
return self.gradient_noise_scale
def get_stats(self):
"""Returns the current (debiased) estimates of the squared mean gradient
and gradient variance."""
return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (
1 - self.beta_cumprod
)

View File

@ -0,0 +1,296 @@
import math
import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
from . import utils
# Karras et al. preconditioned denoiser
class Denoiser(nn.Module):
"""A Karras et al. preconditioner for denoising diffusion models."""
def __init__(self, inner_model, sigma_data=1.0):
super().__init__()
self.inner_model = inner_model
self.sigma_data = sigma_data
def get_scalings(self, sigma):
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
return c_skip, c_out, c_in
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [
utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)
]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.inner_model(noised_input * c_in, sigma, **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [
utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)
]
return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip
class DenoiserWithVariance(Denoiser):
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [
utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)
]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output, logvar = self.inner_model(
noised_input * c_in, sigma, return_variance=True, **kwargs
)
logvar = utils.append_dims(logvar, model_output.ndim)
target = (input - c_skip * noised_input) / c_out
losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2
return losses.flatten(1).mean(1)
# Residual blocks
class ResidualBlock(nn.Module):
def __init__(self, *main, skip=None):
super().__init__()
self.main = nn.Sequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input):
return self.main(input) + self.skip(input)
# Noise level (and other) conditioning
class ConditionedModule(nn.Module):
pass
class UnconditionedModule(ConditionedModule):
def __init__(self, module):
self.module = module
def forward(self, input, cond):
return self.module(input)
class ConditionedSequential(nn.Sequential, ConditionedModule):
def forward(self, input, cond):
for module in self:
if isinstance(module, ConditionedModule):
input = module(input, cond)
else:
input = module(input)
return input
class ConditionedResidualBlock(ConditionedModule):
def __init__(self, *main, skip=None):
super().__init__()
self.main = ConditionedSequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input, cond):
skip = (
self.skip(input, cond)
if isinstance(self.skip, ConditionedModule)
else self.skip(input)
)
return self.main(input, cond) + skip
class AdaGN(ConditionedModule):
def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key="cond"):
super().__init__()
self.num_groups = num_groups
self.eps = eps
self.cond_key = cond_key
self.mapper = nn.Linear(feats_in, c_out * 2)
def forward(self, input, cond):
weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1)
input = F.group_norm(input, self.num_groups, eps=self.eps)
return torch.addcmul(
utils.append_dims(bias, input.ndim),
input,
utils.append_dims(weight, input.ndim) + 1,
)
# Attention
class SelfAttention2d(ConditionedModule):
def __init__(self, c_in, n_head, norm, dropout_rate=0.0):
super().__init__()
assert c_in % n_head == 0
self.norm_in = norm(c_in)
self.n_head = n_head
self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
self.out_proj = nn.Conv2d(c_in, c_in, 1)
self.dropout = nn.Dropout(dropout_rate)
def forward(self, input, cond):
n, c, h, w = input.shape
qkv = self.qkv_proj(self.norm_in(input, cond))
qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
q, k, v = qkv.chunk(3, dim=1)
scale = k.shape[3] ** -0.25
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
att = self.dropout(att)
y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
return input + self.out_proj(y)
class CrossAttention2d(ConditionedModule):
def __init__(
self,
c_dec,
c_enc,
n_head,
norm_dec,
dropout_rate=0.0,
cond_key="cross",
cond_key_padding="cross_padding",
):
super().__init__()
assert c_dec % n_head == 0
self.cond_key = cond_key
self.cond_key_padding = cond_key_padding
self.norm_enc = nn.LayerNorm(c_enc)
self.norm_dec = norm_dec(c_dec)
self.n_head = n_head
self.q_proj = nn.Conv2d(c_dec, c_dec, 1)
self.kv_proj = nn.Linear(c_enc, c_dec * 2)
self.out_proj = nn.Conv2d(c_dec, c_dec, 1)
self.dropout = nn.Dropout(dropout_rate)
def forward(self, input, cond):
n, c, h, w = input.shape
q = self.q_proj(self.norm_dec(input, cond))
q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3)
kv = self.kv_proj(self.norm_enc(cond[self.cond_key]))
kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2)
k, v = kv.chunk(2, dim=1)
scale = k.shape[3] ** -0.25
att = (q * scale) @ (k.transpose(2, 3) * scale)
att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000
att = att.softmax(3)
att = self.dropout(att)
y = (att @ v).transpose(2, 3)
y = y.contiguous().view([n, c, h, w])
return input + self.out_proj(y)
# Downsampling/upsampling
_kernels = {
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
"cubic": [
-0.01171875,
-0.03515625,
0.11328125,
0.43359375,
0.43359375,
0.11328125,
-0.03515625,
-0.01171875,
],
"lanczos3": [
0.003689131001010537,
0.015056144446134567,
-0.03399861603975296,
-0.066637322306633,
0.13550527393817902,
0.44638532400131226,
0.44638532400131226,
0.13550527393817902,
-0.066637322306633,
-0.03399861603975296,
0.015056144446134567,
0.003689131001010537,
],
}
_kernels["bilinear"] = _kernels["linear"]
_kernels["bicubic"] = _kernels["cubic"]
class Downsample2d(nn.Module):
def __init__(self, kernel="linear", pad_mode="reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([_kernels[kernel]])
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d)
def forward(self, x):
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
weight = x.new_zeros(
[x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]
)
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv2d(x, weight, stride=2)
class Upsample2d(nn.Module):
def __init__(self, kernel="linear", pad_mode="reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([_kernels[kernel]]) * 2
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d)
def forward(self, x):
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = x.new_zeros(
[x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]
)
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
# Embeddings
class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, std=1.0):
super().__init__()
assert out_features % 2 == 0
self.register_buffer(
"weight", torch.randn([out_features // 2, in_features]) * std
)
def forward(self, input):
f = 2 * math.pi * input @ self.weight.T
return torch.cat([f.cos(), f.sin()], dim=-1)
# U-Nets
class UNet(ConditionedModule):
def __init__(self, d_blocks, u_blocks, skip_stages=0):
super().__init__()
self.d_blocks = nn.ModuleList(d_blocks)
self.u_blocks = nn.ModuleList(u_blocks)
self.skip_stages = skip_stages
def forward(self, input, cond):
skips = []
for block in self.d_blocks[self.skip_stages :]:
input = block(input, cond)
skips.append(input)
for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))):
input = block(input, cond, skip if i > 0 else None)
return input

View File

@ -0,0 +1 @@
from .image_v1 import ImageDenoiserModelV1

View File

@ -0,0 +1,305 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from .. import layers, utils
def orthogonal_(module):
nn.init.orthogonal_(module.weight)
return module
class ResConvBlock(layers.ConditionedResidualBlock):
def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.0):
skip = (
None
if c_in == c_out
else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False))
)
super().__init__(
layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)),
nn.GELU(),
nn.Conv2d(c_in, c_mid, 3, padding=1),
nn.Dropout2d(dropout_rate, inplace=True),
layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)),
nn.GELU(),
nn.Conv2d(c_mid, c_out, 3, padding=1),
nn.Dropout2d(dropout_rate, inplace=True),
skip=skip,
)
class DBlock(layers.ConditionedSequential):
def __init__(
self,
n_layers,
feats_in,
c_in,
c_mid,
c_out,
group_size=32,
head_size=64,
dropout_rate=0.0,
downsample=False,
self_attn=False,
cross_attn=False,
c_enc=0,
):
modules = [nn.Identity()]
for i in range(n_layers):
my_c_in = c_in if i == 0 else c_mid
my_c_out = c_mid if i < n_layers - 1 else c_out
modules.append(
ResConvBlock(
feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate
)
)
if self_attn:
norm = lambda c_in: layers.AdaGN(
feats_in, c_in, max(1, my_c_out // group_size)
)
modules.append(
layers.SelfAttention2d(
my_c_out, max(1, my_c_out // head_size), norm, dropout_rate
)
)
if cross_attn:
norm = lambda c_in: layers.AdaGN(
feats_in, c_in, max(1, my_c_out // group_size)
)
modules.append(
layers.CrossAttention2d(
my_c_out,
c_enc,
max(1, my_c_out // head_size),
norm,
dropout_rate,
)
)
super().__init__(*modules)
self.set_downsample(downsample)
def set_downsample(self, downsample):
self[0] = layers.Downsample2d() if downsample else nn.Identity()
return self
class UBlock(layers.ConditionedSequential):
def __init__(
self,
n_layers,
feats_in,
c_in,
c_mid,
c_out,
group_size=32,
head_size=64,
dropout_rate=0.0,
upsample=False,
self_attn=False,
cross_attn=False,
c_enc=0,
):
modules = []
for i in range(n_layers):
my_c_in = c_in if i == 0 else c_mid
my_c_out = c_mid if i < n_layers - 1 else c_out
modules.append(
ResConvBlock(
feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate
)
)
if self_attn:
norm = lambda c_in: layers.AdaGN(
feats_in, c_in, max(1, my_c_out // group_size)
)
modules.append(
layers.SelfAttention2d(
my_c_out, max(1, my_c_out // head_size), norm, dropout_rate
)
)
if cross_attn:
norm = lambda c_in: layers.AdaGN(
feats_in, c_in, max(1, my_c_out // group_size)
)
modules.append(
layers.CrossAttention2d(
my_c_out,
c_enc,
max(1, my_c_out // head_size),
norm,
dropout_rate,
)
)
modules.append(nn.Identity())
super().__init__(*modules)
self.set_upsample(upsample)
def forward(self, input, cond, skip=None):
if skip is not None:
input = torch.cat([input, skip], dim=1)
return super().forward(input, cond)
def set_upsample(self, upsample):
self[-1] = layers.Upsample2d() if upsample else nn.Identity()
return self
class MappingNet(nn.Sequential):
def __init__(self, feats_in, feats_out, n_layers=2):
layers = []
for i in range(n_layers):
layers.append(
orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out))
)
layers.append(nn.GELU())
super().__init__(*layers)
class ImageDenoiserModelV1(nn.Module):
def __init__(
self,
c_in,
feats_in,
depths,
channels,
self_attn_depths,
cross_attn_depths=None,
mapping_cond_dim=0,
unet_cond_dim=0,
cross_cond_dim=0,
dropout_rate=0.0,
patch_size=1,
skip_stages=0,
has_variance=False,
):
super().__init__()
self.c_in = c_in
self.channels = channels
self.unet_cond_dim = unet_cond_dim
self.patch_size = patch_size
self.has_variance = has_variance
self.timestep_embed = layers.FourierFeatures(1, feats_in)
if mapping_cond_dim > 0:
self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
self.mapping = MappingNet(feats_in, feats_in)
self.proj_in = nn.Conv2d(
(c_in + unet_cond_dim) * self.patch_size**2,
channels[max(0, skip_stages - 1)],
1,
)
self.proj_out = nn.Conv2d(
channels[max(0, skip_stages - 1)],
c_in * self.patch_size**2 + (1 if self.has_variance else 0),
1,
)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
if cross_cond_dim == 0:
cross_attn_depths = [False] * len(self_attn_depths)
d_blocks, u_blocks = [], []
for i in range(len(depths)):
my_c_in = channels[max(0, i - 1)]
d_blocks.append(
DBlock(
depths[i],
feats_in,
my_c_in,
channels[i],
channels[i],
downsample=i > skip_stages,
self_attn=self_attn_depths[i],
cross_attn=cross_attn_depths[i],
c_enc=cross_cond_dim,
dropout_rate=dropout_rate,
)
)
for i in range(len(depths)):
my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
my_c_out = channels[max(0, i - 1)]
u_blocks.append(
UBlock(
depths[i],
feats_in,
my_c_in,
channels[i],
my_c_out,
upsample=i > skip_stages,
self_attn=self_attn_depths[i],
cross_attn=cross_attn_depths[i],
c_enc=cross_cond_dim,
dropout_rate=dropout_rate,
)
)
self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages)
def forward(
self,
input,
sigma,
mapping_cond=None,
unet_cond=None,
cross_cond=None,
cross_cond_padding=None,
return_variance=False,
):
c_noise = sigma.log() / 4
timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
mapping_cond_embed = (
torch.zeros_like(timestep_embed)
if mapping_cond is None
else self.mapping_cond(mapping_cond)
)
mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
cond = {"cond": mapping_out}
if unet_cond is not None:
input = torch.cat([input, unet_cond], dim=1)
if cross_cond is not None:
cond["cross"] = cross_cond
cond["cross_padding"] = cross_cond_padding
if self.patch_size > 1:
input = F.pixel_unshuffle(input, self.patch_size)
input = self.proj_in(input)
input = self.u_net(input, cond)
input = self.proj_out(input)
if self.has_variance:
input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1)
if self.patch_size > 1:
input = F.pixel_shuffle(input, self.patch_size)
if self.has_variance and return_variance:
return input, logvar
return input
def set_skip_stages(self, skip_stages):
self.proj_in = nn.Conv2d(
self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1
)
self.proj_out = nn.Conv2d(
self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1
)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
self.u_net.skip_stages = skip_stages
for i, block in enumerate(self.u_net.d_blocks):
block.set_downsample(i > skip_stages)
for i, block in enumerate(reversed(self.u_net.u_blocks)):
block.set_upsample(i > skip_stages)
return self
def set_patch_size(self, patch_size):
self.patch_size = patch_size
self.proj_in = nn.Conv2d(
(self.c_in + self.unet_cond_dim) * self.patch_size**2,
self.channels[max(0, self.u_net.skip_stages - 1)],
1,
)
self.proj_out = nn.Conv2d(
self.channels[max(0, self.u_net.skip_stages - 1)],
self.c_in * self.patch_size**2 + (1 if self.has_variance else 0),
1,
)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)

View File

@ -0,0 +1 @@
vendored from git@github.com:crowsonkb/k-diffusion.git

View File

@ -0,0 +1,333 @@
import math
import torch
from scipy import integrate
from torchdiffeq import odeint
from tqdm.auto import tqdm, trange
from . import utils
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
"""Constructs the noise schedule of Karras et al. (2022)."""
ramp = torch.linspace(0, 1, n)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return append_zero(sigmas).to(device)
def get_sigmas_exponential(n, sigma_min, sigma_max, device="cpu"):
"""Constructs an exponential noise schedule."""
sigmas = torch.linspace(
math.log(sigma_max), math.log(sigma_min), n, device=device
).exp()
return append_zero(sigmas)
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device="cpu"):
"""Constructs a continuous VP noise schedule."""
t = torch.linspace(1, eps_s, n, device=device)
sigmas = torch.sqrt(torch.exp(beta_d * t**2 / 2 + beta_min * t) - 1)
return append_zero(sigmas)
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / sigma
def get_ancestral_step(sigma_from, sigma_to):
"""Calculates the noise level (sigma_down) to step down to and the amount
of noise to add (sigma_up) when doing an ancestral sampling step."""
sigma_up = (
sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2
) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
return sigma_down, sigma_up
@torch.no_grad()
def sample_euler(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
s_churn=0.0,
s_tmin=0.0,
s_tmax=float("inf"),
s_noise=1.0,
):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = (
min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
if s_tmin <= sigmas[i] <= s_tmax
else 0.0
)
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback(
{
"x": x,
"i": i,
"sigma": sigmas[i],
"sigma_hat": sigma_hat,
"denoised": denoised,
}
)
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
return x
@torch.no_grad()
def sample_euler_ancestral(
model, x, sigmas, extra_args=None, callback=None, disable=None
):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
if callback is not None:
callback(
{
"x": x,
"i": i,
"sigma": sigmas[i],
"sigma_hat": sigmas[i],
"denoised": denoised,
}
)
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
x = x + torch.randn_like(x) * sigma_up
return x
@torch.no_grad()
def sample_heun(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
s_churn=0.0,
s_tmin=0.0,
s_tmax=float("inf"),
s_noise=1.0,
):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = (
min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
if s_tmin <= sigmas[i] <= s_tmax
else 0.0
)
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback(
{
"x": x,
"i": i,
"sigma": sigmas[i],
"sigma_hat": sigma_hat,
"denoised": denoised,
}
)
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# Euler method
x = x + d * dt
else:
# Heun's method
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
return x
@torch.no_grad()
def sample_dpm_2(
model,
x,
sigmas,
extra_args=None,
callback=None,
disable=None,
s_churn=0.0,
s_tmin=0.0,
s_tmax=float("inf"),
s_noise=1.0,
):
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = (
min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
if s_tmin <= sigmas[i] <= s_tmax
else 0.0
)
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback(
{
"x": x,
"i": i,
"sigma": sigmas[i],
"sigma_hat": sigma_hat,
"denoised": denoised,
}
)
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
dt_1 = sigma_mid - sigma_hat
dt_2 = sigmas[i + 1] - sigma_hat
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
return x
@torch.no_grad()
def sample_dpm_2_ancestral(
model, x, sigmas, extra_args=None, callback=None, disable=None
):
"""Ancestral sampling with DPM-Solver inspired second-order steps."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
if callback is not None:
callback(
{
"x": x,
"i": i,
"sigma": sigmas[i],
"sigma_hat": sigmas[i],
"denoised": denoised,
}
)
d = to_d(x, sigmas[i], denoised)
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
dt_1 = sigma_mid - sigmas[i]
dt_2 = sigma_down - sigmas[i]
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = x + torch.randn_like(x) * sigma_up
return x
def linear_multistep_coeff(order, t, i, j):
if order - 1 > i:
raise ValueError(f"Order {order} too high for step {i}")
def fn(tau):
prod = 1.0
for k in range(order):
if j == k:
continue
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
return prod
return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigmas_cpu = sigmas.detach().cpu().numpy()
ds = []
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
d = to_d(x, sigmas[i], denoised)
ds.append(d)
if len(ds) > order:
ds.pop(0)
if callback is not None:
callback(
{
"x": x,
"i": i,
"sigma": sigmas[i],
"sigma_hat": sigmas[i],
"denoised": denoised,
}
)
cur_order = min(i + 1, order)
coeffs = [
linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
for j in range(cur_order)
]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
return x
@torch.no_grad()
def log_likelihood(
model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4
):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
v = torch.randint_like(x, 2) * 2 - 1
fevals = 0
def ode_fn(sigma, x):
nonlocal fevals
with torch.enable_grad():
x = x[0].detach().requires_grad_()
denoised = model(x, sigma * s_in, **extra_args)
d = to_d(x, sigma, denoised)
fevals += 1
grad = torch.autograd.grad((d * v).sum(), x)[0]
d_ll = (v * grad).flatten(1).sum(1)
return d.detach(), d_ll
x_min = x, x.new_zeros([x.shape[0]])
t = x.new_tensor([sigma_min, sigma_max])
sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method="dopri5")
latent, delta_ll = sol[0][-1], sol[1][-1]
ll_prior = (
torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
)
return ll_prior + delta_ll, {"fevals": fevals}

View File

@ -0,0 +1,221 @@
import math
import torch
from scipy import integrate
from torchdiffeq import odeint
from tqdm.auto import tqdm, trange
from . import utils
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
"""Constructs the noise schedule of Karras et al. (2022)."""
ramp = torch.linspace(0, 1, n)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return append_zero(sigmas).to(device)
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
"""Constructs an exponential noise schedule."""
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
return append_zero(sigmas)
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
"""Constructs a continuous VP noise schedule."""
t = torch.linspace(1, eps_s, n, device=device)
sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
return append_zero(sigmas)
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / utils.append_dims(sigma, x.ndim)
def get_ancestral_step(sigma_from, sigma_to):
"""Calculates the noise level (sigma_down) to step down to and the amount
of noise to add (sigma_up) when doing an ancestral sampling step."""
sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
return sigma_down, sigma_up
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
return x
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
x = x + torch.randn_like(x) * sigma_up
return x
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# Euler method
x = x + d * dt
else:
# Heun's method
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
return x
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
dt_1 = sigma_mid - sigma_hat
dt_2 = sigmas[i + 1] - sigma_hat
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
return x
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""Ancestral sampling with DPM-Solver inspired second-order steps."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
dt_1 = sigma_mid - sigmas[i]
dt_2 = sigma_down - sigmas[i]
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = x + torch.randn_like(x) * sigma_up
return x
def linear_multistep_coeff(order, t, i, j):
if order - 1 > i:
raise ValueError(f'Order {order} too high for step {i}')
def fn(tau):
prod = 1.
for k in range(order):
if j == k:
continue
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
return prod
return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigmas_cpu = sigmas.detach().cpu().numpy()
ds = []
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
d = to_d(x, sigmas[i], denoised)
ds.append(d)
if len(ds) > order:
ds.pop(0)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
return x
@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
v = torch.randint_like(x, 2) * 2 - 1
fevals = 0
def ode_fn(sigma, x):
nonlocal fevals
with torch.enable_grad():
x = x[0].detach().requires_grad_()
denoised = model(x, sigma * s_in, **extra_args)
d = to_d(x, sigma, denoised)
fevals += 1
grad = torch.autograd.grad((d * v).sum(), x)[0]
d_ll = (v * grad).flatten(1).sum(1)
return d.detach(), d_ll
x_min = x, x.new_zeros([x.shape[0]])
t = x.new_tensor([sigma_min, sigma_max])
sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
latent, delta_ll = sol[0][-1], sol[1][-1]
ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
return ll_prior + delta_ll, {'fevals': fevals}

View File

@ -0,0 +1,385 @@
import hashlib
import math
import shutil
import urllib
import warnings
from contextlib import contextmanager
from pathlib import Path
import torch
from PIL import Image
from torch import nn, optim
from torch.utils import data
from torchvision.transforms import functional as TF
def from_pil_image(x):
"""Converts from a PIL image to a tensor."""
x = TF.to_tensor(x)
if x.ndim == 2:
x = x[..., None]
return x * 2 - 1
def to_pil_image(x):
"""Converts from a tensor to a PIL image."""
if x.ndim == 4:
assert x.shape[0] == 1
x = x[0]
if x.shape[0] == 1:
x = x[0]
return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2)
def hf_datasets_augs_helper(examples, transform, image_key, mode="RGB"):
"""Apply passed in transforms for HuggingFace Datasets."""
images = [transform(image.convert(mode)) for image in examples[image_key]]
return {image_key: images}
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
return x[(...,) + (None,) * dims_to_append]
def n_params(module):
"""Returns the number of trainable parameters in a module."""
return sum(p.numel() for p in module.parameters())
def download_file(path, url, digest=None):
"""Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
if not path.exists():
with urllib.request.urlopen(url) as response, open(path, "wb") as f:
shutil.copyfileobj(response, f)
if digest is not None:
file_digest = hashlib.sha256(open(path, "rb").read()).hexdigest()
if digest != file_digest:
raise OSError(f"hash of {path} (url: {url}) failed to validate")
return path
@contextmanager
def train_mode(model, mode=True):
"""A context manager that places a model into training mode and restores
the previous mode on exit."""
modes = [module.training for module in model.modules()]
try:
yield model.train(mode)
finally:
for i, module in enumerate(model.modules()):
module.training = modes[i]
def eval_mode(model):
"""A context manager that places a model into evaluation mode and restores
the previous mode on exit."""
return train_mode(model, False)
@torch.no_grad()
def ema_update(model, averaged_model, decay):
"""Incorporates updated model parameters into an exponential moving averaged
version of a model. It should be called after each optimizer step."""
model_params = dict(model.named_parameters())
averaged_params = dict(averaged_model.named_parameters())
assert model_params.keys() == averaged_params.keys()
for name, param in model_params.items():
averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)
model_buffers = dict(model.named_buffers())
averaged_buffers = dict(averaged_model.named_buffers())
assert model_buffers.keys() == averaged_buffers.keys()
for name, buf in model_buffers.items():
averaged_buffers[name].copy_(buf)
class EMAWarmup:
"""Implements an EMA warmup using an inverse decay schedule.
If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 1.
min_value (float): The minimum EMA decay rate. Default: 0.
max_value (float): The maximum EMA decay rate. Default: 1.
start_at (int): The epoch to start averaging at. Default: 0.
last_epoch (int): The index of last epoch. Default: 0.
"""
def __init__(
self,
inv_gamma=1.0,
power=1.0,
min_value=0.0,
max_value=1.0,
start_at=0,
last_epoch=0,
):
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.max_value = max_value
self.start_at = start_at
self.last_epoch = last_epoch
def state_dict(self):
"""Returns the state of the class as a :class:`dict`."""
return dict(self.__dict__.items())
def load_state_dict(self, state_dict):
"""Loads the class's state.
Args:
state_dict (dict): scaler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
def get_value(self):
"""Gets the current EMA decay rate."""
epoch = max(0, self.last_epoch - self.start_at)
value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
return 0.0 if epoch < 0 else min(self.max_value, max(self.min_value, value))
def step(self):
"""Updates the step count."""
self.last_epoch += 1
class InverseLR(optim.lr_scheduler._LRScheduler):
"""Implements an inverse decay learning rate schedule with an optional exponential
warmup. When last_epoch=-1, sets initial lr as lr.
inv_gamma is the number of steps/epochs required for the learning rate to decay to
(1 / 2)**power of its original value.
Args:
optimizer (Optimizer): Wrapped optimizer.
inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
power (float): Exponential factor of learning rate decay. Default: 1.
warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
Default: 0.
min_lr (float): The minimum learning rate. Default: 0.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
"""
def __init__(
self,
optimizer,
inv_gamma=1.0,
power=1.0,
warmup=0.0,
min_lr=0.0,
last_epoch=-1,
verbose=False,
):
self.inv_gamma = inv_gamma
self.power = power
if not 0.0 <= warmup < 1:
raise ValueError("Invalid value for warmup")
self.warmup = warmup
self.min_lr = min_lr
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`."
)
return self._get_closed_form_lr()
def _get_closed_form_lr(self):
warmup = 1 - self.warmup ** (self.last_epoch + 1)
lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
return [
warmup * max(self.min_lr, base_lr * lr_mult) for base_lr in self.base_lrs
]
class ExponentialLR(optim.lr_scheduler._LRScheduler):
"""Implements an exponential learning rate schedule with an optional exponential
warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
continuously by decay (default 0.5) every num_steps steps.
Args:
optimizer (Optimizer): Wrapped optimizer.
num_steps (float): The number of steps to decay the learning rate by decay in.
decay (float): The factor by which to decay the learning rate every num_steps
steps. Default: 0.5.
warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
Default: 0.
min_lr (float): The minimum learning rate. Default: 0.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
"""
def __init__(
self,
optimizer,
num_steps,
decay=0.5,
warmup=0.0,
min_lr=0.0,
last_epoch=-1,
verbose=False,
):
self.num_steps = num_steps
self.decay = decay
if not 0.0 <= warmup < 1:
raise ValueError("Invalid value for warmup")
self.warmup = warmup
self.min_lr = min_lr
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`."
)
return self._get_closed_form_lr()
def _get_closed_form_lr(self):
warmup = 1 - self.warmup ** (self.last_epoch + 1)
lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
return [
warmup * max(self.min_lr, base_lr * lr_mult) for base_lr in self.base_lrs
]
def rand_log_normal(shape, loc=0.0, scale=1.0, device="cpu", dtype=torch.float32):
"""Draws samples from an lognormal distribution."""
return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()
def rand_log_logistic(
shape,
loc=0.0,
scale=1.0,
min_value=0.0,
max_value=float("inf"),
device="cpu",
dtype=torch.float32,
):
"""Draws samples from an optionally truncated log-logistic distribution."""
min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
u = (
torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf)
+ min_cdf
)
return u.logit().mul(scale).add(loc).exp().to(dtype)
def rand_log_uniform(shape, min_value, max_value, device="cpu", dtype=torch.float32):
"""Draws samples from an log-uniform distribution."""
min_value = math.log(min_value)
max_value = math.log(max_value)
return (
torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value)
+ min_value
).exp()
def rand_v_diffusion(
shape,
sigma_data=1.0,
min_value=0.0,
max_value=float("inf"),
device="cpu",
dtype=torch.float32,
):
"""Draws samples from a truncated v-diffusion training timestep distribution."""
min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
return torch.tan(u * math.pi / 2) * sigma_data
class FolderOfImages(data.Dataset):
"""Recursively finds all images in a directory. It does not support
classes/targets."""
IMG_EXTENSIONS = {
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
}
def __init__(self, root, transform=None):
super().__init__()
self.root = Path(root)
self.transform = nn.Identity() if transform is None else transform
self.paths = sorted(
path
for path in self.root.rglob("*")
if path.suffix.lower() in self.IMG_EXTENSIONS
)
def __repr__(self):
return f'FolderOfImages(root="{self.root}", len: {len(self)})'
def __len__(self):
return len(self.paths)
def __getitem__(self, key):
path = self.paths[key]
with open(path, "rb") as f:
image = Image.open(f).convert("RGB")
image = self.transform(image)
return (image,)
class CSVLogger:
def __init__(self, filename, columns):
self.filename = Path(filename)
self.columns = columns
if self.filename.exists():
self.file = open(self.filename, "a")
else:
self.file = open(self.filename, "w")
self.write(*self.columns)
def write(self, *args):
print(*args, sep=",", file=self.file, flush=True)
@contextmanager
def tf32_mode(cudnn=None, matmul=None):
"""A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
cudnn_old = torch.backends.cudnn.allow_tf32
matmul_old = torch.backends.cuda.matmul.allow_tf32
try:
if cudnn is not None:
torch.backends.cudnn.allow_tf32 = cudnn
if matmul is not None:
torch.backends.cuda.matmul.allow_tf32 = matmul
yield
finally:
if cudnn is not None:
torch.backends.cudnn.allow_tf32 = cudnn_old
if matmul is not None:
torch.backends.cuda.matmul.allow_tf32 = matmul_old

View File

@ -32,7 +32,7 @@ black==22.8.0
# via -r requirements-dev.in
cachetools==5.2.0
# via google-auth
certifi==2022.6.15.1
certifi==2022.6.15.2
# via requests
charset-normalizer==2.1.1
# via
@ -98,7 +98,7 @@ huggingface-hub==0.9.1
# via
# diffusers
# transformers
idna==3.3
idna==3.4
# via
# requests
# yarl
@ -209,7 +209,7 @@ platformdirs==2.5.2
# pylint
pluggy==1.0.0
# via pytest
protobuf==3.19.4
protobuf==3.19.5
# via
# tb-nightly
# tensorboard
@ -257,7 +257,7 @@ pyyaml==6.0
# transformers
realesrgan==0.2.5.0
# via imaginAIry (setup.py)
regex==2022.9.11
regex==2022.9.13
# via
# diffusers
# transformers
@ -285,6 +285,7 @@ scipy==1.9.1
# filterpy
# gfpgan
# scikit-image
# torchdiffeq
six==1.16.0
# via
# google-auth
@ -292,7 +293,7 @@ six==1.16.0
# python-dateutil
snowballstemmer==2.2.0
# via pydocstyle
tb-nightly==2.11.0a20220912
tb-nightly==2.11.0a20220913
# via
# basicsr
# gfpgan
@ -327,8 +328,11 @@ torch==1.12.1
# kornia
# pytorch-lightning
# realesrgan
# torchdiffeq
# torchmetrics
# torchvision
torchdiffeq==0.2.3
# via imaginAIry (setup.py)
torchmetrics==0.6.0
# via
# imaginAIry (setup.py)

View File

@ -31,6 +31,7 @@ setup(
"pytorch-lightning==1.4.2",
"omegaconf==2.1.1",
"einops==0.3.0",
"torchdiffeq",
"transformers==4.19.2",
"torchmetrics==0.6.0",
"torchvision>=0.13.1",