mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
4610d7f01d
add more upscaling code (that doesn't yet work)
106 lines
3.5 KiB
Python
106 lines
3.5 KiB
Python
from functools import partial
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
|
|
from imaginairy.modules.diffusion.util import extract_into_tensor, make_beta_schedule
|
|
|
|
|
|
class AbstractLowScaleModel(nn.Module):
    """Base class for preparing a low-resolution image for latent concatenation.

    The base implementation is a pass-through: ``forward`` hands the input back
    with no noise level, and ``decode`` is the identity. When a noise schedule
    config is supplied, the diffusion schedule buffers needed by ``q_sample``
    are registered on the module.
    """

    def __init__(self, noise_schedule_config=None):
        """
        Args:
            noise_schedule_config: optional mapping of keyword arguments for
                ``register_schedule``. When None, no schedule buffers are created
                and ``q_sample`` cannot be used.
        """
        super().__init__()
        if noise_schedule_config is None:
            return
        self.register_schedule(**noise_schedule_config)

    def register_schedule(
        self,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        """Build a beta schedule and register its derived quantities as float32 buffers.

        Args:
            beta_schedule: schedule name forwarded to ``make_beta_schedule``.
            timesteps: requested step count forwarded to ``make_beta_schedule``;
                ``num_timesteps`` is then read back from the length of the betas.
            linear_start: forwarded to ``make_beta_schedule`` and stored on the module.
            linear_end: forwarded to ``make_beta_schedule`` and stored on the module.
            cosine_s: forwarded to ``make_beta_schedule``.
        """
        betas = make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
        alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
        # Shifted cumulative product: prepend 1.0 and drop the last entry.
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        # Unpacking also enforces that the betas are one-dimensional.
        (n_steps,) = betas.shape
        self.num_timesteps = int(n_steps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        def as_float32(values):
            return torch.tensor(values, dtype=torch.float32)

        # Registration order determines the buffer (and state_dict key) order.
        # The sqrt coefficients are the ones q_sample reads.
        schedule_buffers = (
            ("betas", betas),
            ("alphas_cumprod", alphas_cumprod),
            ("alphas_cumprod_prev", alphas_cumprod_prev),
            ("sqrt_alphas_cumprod", np.sqrt(alphas_cumprod)),
            ("sqrt_one_minus_alphas_cumprod", np.sqrt(1.0 - alphas_cumprod)),
            ("log_one_minus_alphas_cumprod", np.log(1.0 - alphas_cumprod)),
            ("sqrt_recip_alphas_cumprod", np.sqrt(1.0 / alphas_cumprod)),
            ("sqrt_recipm1_alphas_cumprod", np.sqrt(1.0 / alphas_cumprod - 1)),
        )
        for buffer_name, values in schedule_buffers:
            self.register_buffer(buffer_name, as_float32(values))

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to timestep(s) ``t`` in closed form.

        Computes ``sqrt(alphas_cumprod[t]) * x_start
        + sqrt(1 - alphas_cumprod[t]) * noise``, using the buffers from
        ``register_schedule``.

        Args:
            x_start: clean input tensor.
            t: timestep index tensor, gathered by ``extract_into_tensor``
                (presumably one index per batch element — confirm against that helper).
            noise: optional noise tensor; standard normal noise shaped like
                ``x_start`` is drawn when omitted.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        target_shape = x_start.shape
        signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, target_shape)
        noise_scale = extract_into_tensor(
            self.sqrt_one_minus_alphas_cumprod, t, target_shape
        )
        return signal_scale * x_start + noise_scale * noise

    def forward(self, x):
        """Return ``x`` unchanged, paired with ``None`` as the noise level."""
        noise_level = None
        return x, noise_level

    def decode(self, x):
        """Identity decode: return ``x`` as-is."""
        return x
|
|
|
|
|
|
class SimpleImageConcat(AbstractLowScaleModel):
    """Low-scale model without noise-level conditioning.

    The image is returned untouched and every sample is reported at noise
    level 0.
    """

    def __init__(self):
        # A None config means the base class registers no noise schedule.
        super().__init__(noise_schedule_config=None)
        self.max_noise_level = 0

    def forward(self, x):
        """Return ``(x, levels)`` where ``levels`` is an int64 zero tensor of length ``x.shape[0]``."""
        batch_size = x.shape[0]
        constant_levels = torch.zeros(batch_size, dtype=torch.long, device=x.device)
        return x, constant_levels
|
|
|
|
|
|
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    """Noise-augment the low-resolution image before it is concatenated.

    The image is pushed through the closed-form forward diffusion step
    (``q_sample``) at a per-sample noise level. That level is returned alongside
    the noised image so the caller can condition on it.
    """

    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        """
        Args:
            noise_schedule_config: mapping of keyword arguments passed to the base
                class's ``register_schedule``.
            max_noise_level: exclusive upper bound for randomly sampled noise levels.
            to_cuda: unused here; still accepted so existing callers that pass it
                keep working.
        """
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level

    def forward(self, x, noise_level=None):
        """Return ``(noised_x, noise_level)``.

        Args:
            x: input image batch; dimension 0 is treated as the batch size.
            noise_level: timestep(s) to noise at.
                - ``None``: one level per sample is drawn uniformly from
                  ``[0, max_noise_level)``.
                - ``torch.Tensor``: used as-is.
                - ``int``: broadcast to an int64 tensor of length ``x.shape[0]``
                  on ``x.device``.

        Raises:
            TypeError: if ``noise_level`` is not None, an int, or a torch.Tensor.
        """
        batch_size = x.shape[0]
        if noise_level is None:
            noise_level = torch.randint(
                0, self.max_noise_level, (batch_size,), device=x.device
            ).long()
        elif isinstance(noise_level, int) and not isinstance(noise_level, bool):
            # Same shape, dtype and device as the randomly sampled levels above.
            noise_level = torch.full(
                (batch_size,), noise_level, dtype=torch.long, device=x.device
            )
        elif not isinstance(noise_level, torch.Tensor):
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            raise TypeError(
                "noise_level must be None, an int, or a torch.Tensor; "
                f"got {type(noise_level).__name__}"
            )
        z = self.q_sample(x, noise_level)
        return z, noise_level
|