|
|
@ -43,33 +43,33 @@ def uniform_on_device(r1, r2, shape, device):
|
|
|
|
class DDPM(pl.LightningModule):
|
|
|
|
class DDPM(pl.LightningModule):
|
|
|
|
# classic DDPM with Gaussian diffusion, in image space
|
|
|
|
# classic DDPM with Gaussian diffusion, in image space
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
unet_config,
|
|
|
|
unet_config,
|
|
|
|
timesteps=1000,
|
|
|
|
timesteps=1000,
|
|
|
|
beta_schedule="linear",
|
|
|
|
beta_schedule="linear",
|
|
|
|
loss_type="l2",
|
|
|
|
loss_type="l2",
|
|
|
|
ckpt_path=None,
|
|
|
|
ckpt_path=None,
|
|
|
|
ignore_keys=[],
|
|
|
|
ignore_keys=[],
|
|
|
|
load_only_unet=False,
|
|
|
|
load_only_unet=False,
|
|
|
|
monitor="val/loss",
|
|
|
|
monitor="val/loss",
|
|
|
|
first_stage_key="image",
|
|
|
|
first_stage_key="image",
|
|
|
|
image_size=256,
|
|
|
|
image_size=256,
|
|
|
|
channels=3,
|
|
|
|
channels=3,
|
|
|
|
log_every_t=100,
|
|
|
|
log_every_t=100,
|
|
|
|
clip_denoised=True,
|
|
|
|
clip_denoised=True,
|
|
|
|
linear_start=1e-4,
|
|
|
|
linear_start=1e-4,
|
|
|
|
linear_end=2e-2,
|
|
|
|
linear_end=2e-2,
|
|
|
|
cosine_s=8e-3,
|
|
|
|
cosine_s=8e-3,
|
|
|
|
given_betas=None,
|
|
|
|
given_betas=None,
|
|
|
|
original_elbo_weight=0.0,
|
|
|
|
original_elbo_weight=0.0,
|
|
|
|
v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
|
|
|
|
v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
|
|
|
|
l_simple_weight=1.0,
|
|
|
|
l_simple_weight=1.0,
|
|
|
|
conditioning_key=None,
|
|
|
|
conditioning_key=None,
|
|
|
|
parameterization="eps", # all assuming fixed variance schedules
|
|
|
|
parameterization="eps", # all assuming fixed variance schedules
|
|
|
|
scheduler_config=None,
|
|
|
|
scheduler_config=None,
|
|
|
|
use_positional_encodings=False,
|
|
|
|
use_positional_encodings=False,
|
|
|
|
learn_logvar=False,
|
|
|
|
learn_logvar=False,
|
|
|
|
logvar_init=0.0,
|
|
|
|
logvar_init=0.0,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
assert parameterization in [
|
|
|
|
assert parameterization in [
|
|
|
@ -122,13 +122,13 @@ class DDPM(pl.LightningModule):
|
|
|
|
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
|
|
|
|
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
|
|
|
|
|
|
|
|
|
|
|
|
def register_schedule(
|
|
|
|
def register_schedule(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
given_betas=None,
|
|
|
|
given_betas=None,
|
|
|
|
beta_schedule="linear",
|
|
|
|
beta_schedule="linear",
|
|
|
|
timesteps=1000,
|
|
|
|
timesteps=1000,
|
|
|
|
linear_start=1e-4,
|
|
|
|
linear_start=1e-4,
|
|
|
|
linear_end=2e-2,
|
|
|
|
linear_end=2e-2,
|
|
|
|
cosine_s=8e-3,
|
|
|
|
cosine_s=8e-3,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
if given_betas is not None:
|
|
|
|
if given_betas is not None:
|
|
|
|
betas = given_betas
|
|
|
|
betas = given_betas
|
|
|
@ -149,7 +149,7 @@ class DDPM(pl.LightningModule):
|
|
|
|
self.linear_start = linear_start
|
|
|
|
self.linear_start = linear_start
|
|
|
|
self.linear_end = linear_end
|
|
|
|
self.linear_end = linear_end
|
|
|
|
assert (
|
|
|
|
assert (
|
|
|
|
alphas_cumprod.shape[0] == self.num_timesteps
|
|
|
|
alphas_cumprod.shape[0] == self.num_timesteps
|
|
|
|
), "alphas have to be defined for each timestep"
|
|
|
|
), "alphas have to be defined for each timestep"
|
|
|
|
|
|
|
|
|
|
|
|
to_torch = partial(torch.tensor, dtype=torch.float32)
|
|
|
|
to_torch = partial(torch.tensor, dtype=torch.float32)
|
|
|
@ -175,7 +175,7 @@ class DDPM(pl.LightningModule):
|
|
|
|
|
|
|
|
|
|
|
|
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
|
|
|
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
|
|
|
posterior_variance = (1 - self.v_posterior) * betas * (
|
|
|
|
posterior_variance = (1 - self.v_posterior) * betas * (
|
|
|
|
1.0 - alphas_cumprod_prev
|
|
|
|
1.0 - alphas_cumprod_prev
|
|
|
|
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
|
|
|
|
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
|
|
|
|
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
|
|
|
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
|
|
|
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
|
|
|
self.register_buffer("posterior_variance", to_torch(posterior_variance))
|
|
|
@ -196,17 +196,17 @@ class DDPM(pl.LightningModule):
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if self.parameterization == "eps":
|
|
|
|
if self.parameterization == "eps":
|
|
|
|
lvlb_weights = self.betas ** 2 / (
|
|
|
|
lvlb_weights = self.betas**2 / (
|
|
|
|
2
|
|
|
|
2
|
|
|
|
* self.posterior_variance
|
|
|
|
* self.posterior_variance
|
|
|
|
* to_torch(alphas)
|
|
|
|
* to_torch(alphas)
|
|
|
|
* (1 - self.alphas_cumprod)
|
|
|
|
* (1 - self.alphas_cumprod)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
elif self.parameterization == "x0":
|
|
|
|
elif self.parameterization == "x0":
|
|
|
|
lvlb_weights = (
|
|
|
|
lvlb_weights = (
|
|
|
|
0.5
|
|
|
|
0.5
|
|
|
|
* np.sqrt(torch.Tensor(alphas_cumprod))
|
|
|
|
* np.sqrt(torch.Tensor(alphas_cumprod))
|
|
|
|
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
|
|
|
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
|
|
|
|
)
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
raise NotImplementedError("mu not supported")
|
|
|
|
raise NotImplementedError("mu not supported")
|
|
|
@ -216,26 +216,27 @@ class DDPM(pl.LightningModule):
|
|
|
|
assert not torch.isnan(self.lvlb_weights).all()
|
|
|
|
assert not torch.isnan(self.lvlb_weights).all()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LatentDiffusion(DDPM):
|
|
|
|
class LatentDiffusion(DDPM):
|
|
|
|
"""main class"""
|
|
|
|
"""main class"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
first_stage_config,
|
|
|
|
first_stage_config,
|
|
|
|
cond_stage_config,
|
|
|
|
cond_stage_config,
|
|
|
|
num_timesteps_cond=None,
|
|
|
|
num_timesteps_cond=None,
|
|
|
|
cond_stage_key="image",
|
|
|
|
cond_stage_key="image",
|
|
|
|
cond_stage_trainable=False,
|
|
|
|
cond_stage_trainable=False,
|
|
|
|
concat_mode=True,
|
|
|
|
concat_mode=True,
|
|
|
|
cond_stage_forward=None,
|
|
|
|
cond_stage_forward=None,
|
|
|
|
conditioning_key=None,
|
|
|
|
conditioning_key=None,
|
|
|
|
scale_factor=1.0,
|
|
|
|
scale_factor=1.0,
|
|
|
|
scale_by_std=False,
|
|
|
|
scale_by_std=False,
|
|
|
|
*args,
|
|
|
|
*args,
|
|
|
|
**kwargs,
|
|
|
|
**kwargs,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
self.num_timesteps_cond = 1 if num_timesteps_cond is None else num_timesteps_cond
|
|
|
|
self.num_timesteps_cond = (
|
|
|
|
|
|
|
|
1 if num_timesteps_cond is None else num_timesteps_cond
|
|
|
|
|
|
|
|
)
|
|
|
|
self.scale_by_std = scale_by_std
|
|
|
|
self.scale_by_std = scale_by_std
|
|
|
|
assert self.num_timesteps_cond <= kwargs["timesteps"]
|
|
|
|
assert self.num_timesteps_cond <= kwargs["timesteps"]
|
|
|
|
# for backwards compatibility after implementation of DiffusionWrapper
|
|
|
|
# for backwards compatibility after implementation of DiffusionWrapper
|
|
|
@ -269,7 +270,7 @@ class LatentDiffusion(DDPM):
|
|
|
|
self.restarted_from_ckpt = True
|
|
|
|
self.restarted_from_ckpt = True
|
|
|
|
|
|
|
|
|
|
|
|
def make_cond_schedule(
|
|
|
|
def make_cond_schedule(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
self.cond_ids = torch.full(
|
|
|
|
self.cond_ids = torch.full(
|
|
|
|
size=(self.num_timesteps,),
|
|
|
|
size=(self.num_timesteps,),
|
|
|
@ -282,13 +283,13 @@ class LatentDiffusion(DDPM):
|
|
|
|
self.cond_ids[: self.num_timesteps_cond] = ids
|
|
|
|
self.cond_ids[: self.num_timesteps_cond] = ids
|
|
|
|
|
|
|
|
|
|
|
|
def register_schedule(
|
|
|
|
def register_schedule(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
given_betas=None,
|
|
|
|
given_betas=None,
|
|
|
|
beta_schedule="linear",
|
|
|
|
beta_schedule="linear",
|
|
|
|
timesteps=1000,
|
|
|
|
timesteps=1000,
|
|
|
|
linear_start=1e-4,
|
|
|
|
linear_start=1e-4,
|
|
|
|
linear_end=2e-2,
|
|
|
|
linear_end=2e-2,
|
|
|
|
cosine_s=8e-3,
|
|
|
|
cosine_s=8e-3,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
super().register_schedule(
|
|
|
|
super().register_schedule(
|
|
|
|
given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
|
|
|
|
given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
|
|
|
@ -327,7 +328,7 @@ class LatentDiffusion(DDPM):
|
|
|
|
self.cond_stage_model = model
|
|
|
|
self.cond_stage_model = model
|
|
|
|
|
|
|
|
|
|
|
|
def _get_denoise_row_from_list(
|
|
|
|
def _get_denoise_row_from_list(
|
|
|
|
self, samples, desc="", force_no_decoder_quantization=False
|
|
|
|
self, samples, desc="", force_no_decoder_quantization=False
|
|
|
|
):
|
|
|
|
):
|
|
|
|
denoise_row = []
|
|
|
|
denoise_row = []
|
|
|
|
for zd in tqdm(samples, desc=desc):
|
|
|
|
for zd in tqdm(samples, desc=desc):
|
|
|
@ -357,7 +358,7 @@ class LatentDiffusion(DDPM):
|
|
|
|
def get_learned_conditioning(self, c):
|
|
|
|
def get_learned_conditioning(self, c):
|
|
|
|
if self.cond_stage_forward is None:
|
|
|
|
if self.cond_stage_forward is None:
|
|
|
|
if hasattr(self.cond_stage_model, "encode") and callable(
|
|
|
|
if hasattr(self.cond_stage_model, "encode") and callable(
|
|
|
|
self.cond_stage_model.encode
|
|
|
|
self.cond_stage_model.encode
|
|
|
|
):
|
|
|
|
):
|
|
|
|
c = self.cond_stage_model.encode(c)
|
|
|
|
c = self.cond_stage_model.encode(c)
|
|
|
|
if isinstance(c, DiagonalGaussianDistribution):
|
|
|
|
if isinstance(c, DiagonalGaussianDistribution):
|
|
|
@ -414,7 +415,7 @@ class LatentDiffusion(DDPM):
|
|
|
|
return weighting
|
|
|
|
return weighting
|
|
|
|
|
|
|
|
|
|
|
|
def get_fold_unfold(
|
|
|
|
def get_fold_unfold(
|
|
|
|
self, x, kernel_size, stride, uf=1, df=1
|
|
|
|
self, x, kernel_size, stride, uf=1, df=1
|
|
|
|
): # todo load once not every time, shorten code
|
|
|
|
): # todo load once not every time, shorten code
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
:param x: img of size (bs, c, h, w)
|
|
|
|
:param x: img of size (bs, c, h, w)
|
|
|
@ -499,14 +500,14 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
@torch.no_grad()
|
|
|
|
def get_input(
|
|
|
|
def get_input(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
batch,
|
|
|
|
batch,
|
|
|
|
k,
|
|
|
|
k,
|
|
|
|
return_first_stage_outputs=False,
|
|
|
|
return_first_stage_outputs=False,
|
|
|
|
force_c_encode=False,
|
|
|
|
force_c_encode=False,
|
|
|
|
cond_key=None,
|
|
|
|
cond_key=None,
|
|
|
|
return_original_cond=False,
|
|
|
|
return_original_cond=False,
|
|
|
|
bs=None,
|
|
|
|
bs=None,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
x = super().get_input(batch, k)
|
|
|
|
x = super().get_input(batch, k)
|
|
|
|
if bs is not None:
|
|
|
|
if bs is not None:
|
|
|
@ -631,6 +632,52 @@ class LatentDiffusion(DDPM):
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
return self.first_stage_model.decode(z)
|
|
|
|
return self.first_stage_model.decode(z)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
|
|
|
def encode_first_stage(self, x):
|
|
|
|
|
|
|
|
if hasattr(self, "split_input_params"):
|
|
|
|
|
|
|
|
if self.split_input_params["patch_distributed_vq"]:
|
|
|
|
|
|
|
|
ks = self.split_input_params["ks"] # eg. (128, 128)
|
|
|
|
|
|
|
|
stride = self.split_input_params["stride"] # eg. (64, 64)
|
|
|
|
|
|
|
|
df = self.split_input_params["vqf"]
|
|
|
|
|
|
|
|
self.split_input_params["original_image_size"] = x.shape[-2:]
|
|
|
|
|
|
|
|
bs, nc, h, w = x.shape
|
|
|
|
|
|
|
|
if ks[0] > h or ks[1] > w:
|
|
|
|
|
|
|
|
ks = (min(ks[0], h), min(ks[1], w))
|
|
|
|
|
|
|
|
print("reducing Kernel")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if stride[0] > h or stride[1] > w:
|
|
|
|
|
|
|
|
stride = (min(stride[0], h), min(stride[1], w))
|
|
|
|
|
|
|
|
print("reducing stride")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fold, unfold, normalization, weighting = self.get_fold_unfold(
|
|
|
|
|
|
|
|
x, ks, stride, df=df
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
z = unfold(x) # (bn, nc * prod(**ks), L)
|
|
|
|
|
|
|
|
# Reshape to img shape
|
|
|
|
|
|
|
|
z = z.view(
|
|
|
|
|
|
|
|
(z.shape[0], -1, ks[0], ks[1], z.shape[-1])
|
|
|
|
|
|
|
|
) # (bn, nc, ks[0], ks[1], L )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
output_list = [
|
|
|
|
|
|
|
|
self.first_stage_model.encode(z[:, :, :, :, i])
|
|
|
|
|
|
|
|
for i in range(z.shape[-1])
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
o = torch.stack(output_list, axis=-1)
|
|
|
|
|
|
|
|
o = o * weighting
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Reverse reshape to img shape
|
|
|
|
|
|
|
|
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
|
|
|
|
|
|
|
|
# stitch crops together
|
|
|
|
|
|
|
|
decoded = fold(o)
|
|
|
|
|
|
|
|
decoded = decoded / normalization
|
|
|
|
|
|
|
|
return decoded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
return self.first_stage_model.encode(x)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
return self.first_stage_model.encode(x)
|
|
|
|
|
|
|
|
|
|
|
|
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
|
|
|
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(cond, dict):
|
|
|
|
if isinstance(cond, dict):
|
|
|
@ -664,8 +711,8 @@ class LatentDiffusion(DDPM):
|
|
|
|
z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
|
|
|
|
z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
|
|
|
|
|
|
|
|
|
|
|
|
if (
|
|
|
|
if (
|
|
|
|
self.cond_stage_key in ["image", "LR_image", "segmentation", "bbox_img"]
|
|
|
|
self.cond_stage_key in ["image", "LR_image", "segmentation", "bbox_img"]
|
|
|
|
and self.model.conditioning_key
|
|
|
|
and self.model.conditioning_key
|
|
|
|
): # todo check for completeness
|
|
|
|
): # todo check for completeness
|
|
|
|
c_key = next(iter(cond.keys())) # get key
|
|
|
|
c_key = next(iter(cond.keys())) # get key
|
|
|
|
c = next(iter(cond.values())) # get value
|
|
|
|
c = next(iter(cond.values())) # get value
|
|
|
@ -681,7 +728,7 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
|
|
|
|
|
|
|
elif self.cond_stage_key == "coordinates_bbox":
|
|
|
|
elif self.cond_stage_key == "coordinates_bbox":
|
|
|
|
assert (
|
|
|
|
assert (
|
|
|
|
"original_image_size" in self.split_input_params
|
|
|
|
"original_image_size" in self.split_input_params
|
|
|
|
), "BoudingBoxRescaling is missing original_image_size"
|
|
|
|
), "BoudingBoxRescaling is missing original_image_size"
|
|
|
|
|
|
|
|
|
|
|
|
# assuming padding of unfold is always 0 and its dilation is always 1
|
|
|
|
# assuming padding of unfold is always 0 and its dilation is always 1
|
|
|
@ -776,16 +823,16 @@ class LatentDiffusion(DDPM):
|
|
|
|
return x_recon
|
|
|
|
return x_recon
|
|
|
|
|
|
|
|
|
|
|
|
def p_mean_variance(
|
|
|
|
def p_mean_variance(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
x,
|
|
|
|
x,
|
|
|
|
c,
|
|
|
|
c,
|
|
|
|
t,
|
|
|
|
t,
|
|
|
|
clip_denoised: bool,
|
|
|
|
clip_denoised: bool,
|
|
|
|
return_codebook_ids=False,
|
|
|
|
return_codebook_ids=False,
|
|
|
|
quantize_denoised=False,
|
|
|
|
quantize_denoised=False,
|
|
|
|
return_x0=False,
|
|
|
|
return_x0=False,
|
|
|
|
score_corrector=None,
|
|
|
|
score_corrector=None,
|
|
|
|
corrector_kwargs=None,
|
|
|
|
corrector_kwargs=None,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
t_in = t
|
|
|
|
t_in = t
|
|
|
|
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
|
|
|
|
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
|
|
|
@ -822,19 +869,19 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
@torch.no_grad()
|
|
|
|
def p_sample(
|
|
|
|
def p_sample(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
x,
|
|
|
|
x,
|
|
|
|
c,
|
|
|
|
c,
|
|
|
|
t,
|
|
|
|
t,
|
|
|
|
clip_denoised=False,
|
|
|
|
clip_denoised=False,
|
|
|
|
repeat_noise=False,
|
|
|
|
repeat_noise=False,
|
|
|
|
return_codebook_ids=False,
|
|
|
|
return_codebook_ids=False,
|
|
|
|
quantize_denoised=False,
|
|
|
|
quantize_denoised=False,
|
|
|
|
return_x0=False,
|
|
|
|
return_x0=False,
|
|
|
|
temperature=1.0,
|
|
|
|
temperature=1.0,
|
|
|
|
noise_dropout=0.0,
|
|
|
|
noise_dropout=0.0,
|
|
|
|
score_corrector=None,
|
|
|
|
score_corrector=None,
|
|
|
|
corrector_kwargs=None,
|
|
|
|
corrector_kwargs=None,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
b, *_, device = *x.shape, x.device
|
|
|
|
b, *_, device = *x.shape, x.device
|
|
|
|
outputs = self.p_mean_variance(
|
|
|
|
outputs = self.p_mean_variance(
|
|
|
@ -864,7 +911,7 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
|
|
|
|
|
|
|
if return_codebook_ids:
|
|
|
|
if return_codebook_ids:
|
|
|
|
return model_mean + nonzero_mask * (
|
|
|
|
return model_mean + nonzero_mask * (
|
|
|
|
0.5 * model_log_variance
|
|
|
|
0.5 * model_log_variance
|
|
|
|
).exp() * noise, logits.argmax(dim=1)
|
|
|
|
).exp() * noise, logits.argmax(dim=1)
|
|
|
|
if return_x0:
|
|
|
|
if return_x0:
|
|
|
|
return (
|
|
|
|
return (
|
|
|
|