imaginAIry/imaginairy/vendored/k_diffusion/external.py

207 lines
7.1 KiB
Python
Raw Normal View History

import math
import torch
from torch import nn
from . import sampling, utils
class VDenoiser(nn.Module):
"""A v-diffusion-pytorch model wrapper for k-diffusion."""
def __init__(self, inner_model):
super().__init__()
self.inner_model = inner_model
self.sigma_data = 1.0
def get_scalings(self, sigma):
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = -sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
return c_skip, c_out, c_in
def sigma_to_t(self, sigma):
return sigma.atan() / math.pi * 2
def t_to_sigma(self, t):
return (t * math.pi / 2).tan()
def loss(self, input, noise, sigma, **kwargs):
2023-01-09 04:45:58 +00:00
c_skip, c_out, c_in = (
utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)
2023-01-09 04:45:58 +00:00
)
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.inner_model(
noised_input * c_in, self.sigma_to_t(sigma), **kwargs
)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
2023-01-09 04:45:58 +00:00
c_skip, c_out, c_in = (
utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)
2023-01-09 04:45:58 +00:00
)
return (
self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out
+ input * c_skip
)
class DiscreteSchedule(nn.Module):
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
levels."""
def __init__(self, sigmas, quantize):
super().__init__()
self.register_buffer("sigmas", sigmas)
self.register_buffer("log_sigmas", sigmas.log())
self.quantize = quantize
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def get_sigmas(self, n=None):
if n is None:
return sampling.append_zero(self.sigmas.flip(0))
t_max = len(self.sigmas) - 1
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
return sampling.append_zero(self.t_to_sigma(t))
def sigma_to_t(self, sigma, quantize=None):
quantize = self.quantize if quantize is None else quantize
log_sigma = sigma.log()
dists = log_sigma - self.log_sigmas[:, None]
if quantize:
return dists.abs().argmin(dim=0).view(sigma.shape)
low_idx = (
dists.ge(0)
.cumsum(dim=0)
.argmax(dim=0)
.clamp(max=self.log_sigmas.shape[0] - 2)
)
high_idx = low_idx + 1
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)
t = (1 - w) * low_idx + w * high_idx
return t.view(sigma.shape)
def t_to_sigma(self, t):
t = t.float()
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
noise)."""
def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.0
def get_scalings(self, sigma):
c_out = -sigma
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
return c_out, c_in
def get_eps(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def loss(self, input, noise, sigma, **kwargs):
2023-01-09 04:45:58 +00:00
c_out, c_in = (
utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)
2023-01-09 04:45:58 +00:00
)
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
return (eps - noise).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
2023-01-09 04:45:58 +00:00
c_out, c_in = (
utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)
2023-01-09 04:45:58 +00:00
)
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
return input + eps * c_out
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for OpenAI diffusion models."""
def __init__(
self, model, diffusion, quantize=False, has_learned_sigmas=True, device="cpu"
):
alphas_cumprod = torch.tensor(
diffusion.alphas_cumprod, device=device, dtype=torch.float32
)
super().__init__(model, alphas_cumprod, quantize=quantize)
self.has_learned_sigmas = has_learned_sigmas
def get_eps(self, *args, **kwargs):
model_output = self.inner_model(*args, **kwargs)
if self.has_learned_sigmas:
return model_output.chunk(2, dim=1)[0]
return model_output
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for CompVis diffusion models."""
def __init__(self, model, quantize=False, device="cpu"):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_eps(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs)
class DiscreteVDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output v."""
def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.0
def get_scalings(self, sigma):
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = -sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
return c_skip, c_out, c_in
def get_v(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def loss(self, input, noise, sigma, **kwargs):
2023-01-09 04:45:58 +00:00
c_skip, c_out, c_in = (
utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)
2023-01-09 04:45:58 +00:00
)
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
2023-01-09 04:45:58 +00:00
c_skip, c_out, c_in = (
utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)
2023-01-09 04:45:58 +00:00
)
return (
self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out
+ input * c_skip
)
class CompVisVDenoiser(DiscreteVDDPMDenoiser):
"""A wrapper for CompVis diffusion models that output v."""
def __init__(self, model, quantize=False, device="cpu"):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond)