|
|
@ -18,39 +18,38 @@ from imaginairy.utils import get_device
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def to_torch(x):
|
|
|
|
|
|
|
|
return x.clone().detach().to(torch.float32).to(get_device())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PLMSSchedule:
|
|
|
|
class PLMSSchedule:
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
ddpm_num_timesteps, # 1000?
|
|
|
|
model_num_timesteps, # 1000?
|
|
|
|
|
|
|
|
model_alphas_cumprod,
|
|
|
|
ddim_num_steps, # prompt.steps?
|
|
|
|
ddim_num_steps, # prompt.steps?
|
|
|
|
alphas_cumprod,
|
|
|
|
|
|
|
|
ddim_discretize="uniform",
|
|
|
|
ddim_discretize="uniform",
|
|
|
|
):
|
|
|
|
):
|
|
|
|
device = get_device()
|
|
|
|
device = get_device()
|
|
|
|
|
|
|
|
if model_alphas_cumprod.shape[0] != model_num_timesteps:
|
|
|
|
|
|
|
|
raise ValueError("alphas have to be defined for each timestep")
|
|
|
|
|
|
|
|
|
|
|
|
assert (
|
|
|
|
self.alphas_cumprod = to_torch(model_alphas_cumprod)
|
|
|
|
alphas_cumprod.shape[0] == ddpm_num_timesteps
|
|
|
|
|
|
|
|
), "alphas have to be defined for each timestep"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def to_torch(x):
|
|
|
|
|
|
|
|
return x.clone().detach().to(torch.float32).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.alphas_cumprod = to_torch(alphas_cumprod)
|
|
|
|
|
|
|
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
|
|
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
|
|
|
self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
|
|
|
|
self.sqrt_alphas_cumprod = to_torch(np.sqrt(model_alphas_cumprod.cpu()))
|
|
|
|
self.sqrt_one_minus_alphas_cumprod = to_torch(
|
|
|
|
self.sqrt_one_minus_alphas_cumprod = to_torch(
|
|
|
|
np.sqrt(1.0 - alphas_cumprod.cpu())
|
|
|
|
np.sqrt(1.0 - model_alphas_cumprod.cpu())
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
self.ddim_timesteps = make_ddim_timesteps(
|
|
|
|
self.ddim_timesteps = make_ddim_timesteps(
|
|
|
|
ddim_discr_method=ddim_discretize,
|
|
|
|
ddim_discr_method=ddim_discretize,
|
|
|
|
num_ddim_timesteps=ddim_num_steps,
|
|
|
|
num_ddim_timesteps=ddim_num_steps,
|
|
|
|
num_ddpm_timesteps=ddpm_num_timesteps,
|
|
|
|
num_ddpm_timesteps=model_num_timesteps,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# ddim sampling parameters
|
|
|
|
# ddim sampling parameters
|
|
|
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
|
|
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
|
|
|
alphacums=alphas_cumprod.cpu(),
|
|
|
|
alphacums=model_alphas_cumprod.cpu(),
|
|
|
|
ddim_timesteps=self.ddim_timesteps,
|
|
|
|
ddim_timesteps=self.ddim_timesteps,
|
|
|
|
eta=0.0,
|
|
|
|
eta=0.0,
|
|
|
|
)
|
|
|
|
)
|
|
|
@ -63,7 +62,14 @@ class PLMSSchedule:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PLMSSampler:
|
|
|
|
class PLMSSampler:
|
|
|
|
"""probabilistic least-mean-squares"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
probabilistic least-mean-squares
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Provenance:
|
|
|
|
|
|
|
|
https://github.com/CompVis/latent-diffusion/commit/f0c4e092c156986e125f48c61a0edd38ba8ad059
|
|
|
|
|
|
|
|
https://arxiv.org/abs/2202.09778
|
|
|
|
|
|
|
|
https://github.com/luping-liu/PNDM
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, model):
|
|
|
|
def __init__(self, model):
|
|
|
|
self.model = model
|
|
|
|
self.model = model
|
|
|
@ -73,106 +79,89 @@ class PLMSSampler:
|
|
|
|
def sample(
|
|
|
|
def sample(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
num_steps,
|
|
|
|
num_steps,
|
|
|
|
batch_size,
|
|
|
|
|
|
|
|
shape,
|
|
|
|
shape,
|
|
|
|
conditioning=None,
|
|
|
|
neutral_conditioning,
|
|
|
|
callback=None,
|
|
|
|
positive_conditioning,
|
|
|
|
img_callback=None,
|
|
|
|
guidance_scale=1.0,
|
|
|
|
quantize_x0=False,
|
|
|
|
batch_size=1,
|
|
|
|
eta=0.0,
|
|
|
|
|
|
|
|
mask=None,
|
|
|
|
mask=None,
|
|
|
|
orig_latent=None,
|
|
|
|
orig_latent=None,
|
|
|
|
temperature=1.0,
|
|
|
|
temperature=1.0,
|
|
|
|
noise_dropout=0.0,
|
|
|
|
noise_dropout=0.0,
|
|
|
|
initial_latent=None,
|
|
|
|
initial_latent=None,
|
|
|
|
unconditional_guidance_scale=1.0,
|
|
|
|
|
|
|
|
unconditional_conditioning=None,
|
|
|
|
|
|
|
|
timesteps=None,
|
|
|
|
|
|
|
|
quantize_denoised=False,
|
|
|
|
quantize_denoised=False,
|
|
|
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
|
|
|
|
|
|
|
**kwargs,
|
|
|
|
**kwargs,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
if conditioning.shape[0] != batch_size:
|
|
|
|
if positive_conditioning.shape[0] != batch_size:
|
|
|
|
logger.warning(
|
|
|
|
raise ValueError(
|
|
|
|
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
|
|
|
|
f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
schedule = PLMSSchedule(
|
|
|
|
schedule = PLMSSchedule(
|
|
|
|
ddpm_num_timesteps=self.model.num_timesteps,
|
|
|
|
model_num_timesteps=self.model.num_timesteps,
|
|
|
|
ddim_num_steps=num_steps,
|
|
|
|
ddim_num_steps=num_steps,
|
|
|
|
alphas_cumprod=self.model.alphas_cumprod,
|
|
|
|
model_alphas_cumprod=self.model.alphas_cumprod,
|
|
|
|
ddim_discretize="uniform",
|
|
|
|
ddim_discretize="uniform",
|
|
|
|
)
|
|
|
|
)
|
|
|
|
device = self.device
|
|
|
|
|
|
|
|
# batch_size = shape[0]
|
|
|
|
|
|
|
|
if initial_latent is None:
|
|
|
|
if initial_latent is None:
|
|
|
|
initial_latent = torch.randn(shape, device="cpu").to(device)
|
|
|
|
initial_latent = torch.randn(shape, device="cpu").to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
log_latent(initial_latent, "initial latent")
|
|
|
|
log_latent(initial_latent, "initial latent")
|
|
|
|
if timesteps is None:
|
|
|
|
|
|
|
|
timesteps = schedule.ddim_timesteps
|
|
|
|
timesteps = schedule.ddim_timesteps
|
|
|
|
elif timesteps is not None:
|
|
|
|
|
|
|
|
subset_end = (
|
|
|
|
|
|
|
|
int(
|
|
|
|
|
|
|
|
min(timesteps / schedule.ddim_timesteps.shape[0], 1)
|
|
|
|
|
|
|
|
* schedule.ddim_timesteps.shape[0]
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
- 1
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
timesteps = schedule.ddim_timesteps[:subset_end]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
time_range = np.flip(timesteps)
|
|
|
|
time_range = np.flip(timesteps)
|
|
|
|
total_steps = timesteps.shape[0]
|
|
|
|
total_steps = timesteps.shape[0]
|
|
|
|
logger.debug(f"Running PLMS Sampling with {total_steps} timesteps")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
iterator = tqdm(time_range, desc=" PLMS Sampler", total=total_steps)
|
|
|
|
|
|
|
|
old_eps = []
|
|
|
|
old_eps = []
|
|
|
|
img = initial_latent
|
|
|
|
noisy_latent = initial_latent
|
|
|
|
|
|
|
|
|
|
|
|
for i, step in enumerate(iterator):
|
|
|
|
for i, step in enumerate(tqdm(time_range, total=total_steps)):
|
|
|
|
index = total_steps - i - 1
|
|
|
|
index = total_steps - i - 1
|
|
|
|
ts = torch.full((batch_size,), step, device=device, dtype=torch.long)
|
|
|
|
ts = torch.full((batch_size,), step, device=self.device, dtype=torch.long)
|
|
|
|
ts_next = torch.full(
|
|
|
|
ts_next = torch.full(
|
|
|
|
(batch_size,),
|
|
|
|
(batch_size,),
|
|
|
|
time_range[min(i + 1, len(time_range) - 1)],
|
|
|
|
time_range[min(i + 1, len(time_range) - 1)],
|
|
|
|
device=device,
|
|
|
|
device=self.device,
|
|
|
|
dtype=torch.long,
|
|
|
|
dtype=torch.long,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if mask is not None:
|
|
|
|
if mask is not None:
|
|
|
|
assert orig_latent is not None
|
|
|
|
assert orig_latent is not None
|
|
|
|
img_orig = self.model.q_sample(orig_latent, ts)
|
|
|
|
img_orig = self.model.q_sample(orig_latent, ts)
|
|
|
|
img = img_orig * mask + (1.0 - mask) * img
|
|
|
|
noisy_latent = img_orig * mask + (1.0 - mask) * noisy_latent
|
|
|
|
|
|
|
|
|
|
|
|
img, pred_x0, noise_prediction = self.p_sample_plms(
|
|
|
|
noisy_latent, predicted_latent, noise_pred = self.p_sample_plms(
|
|
|
|
img,
|
|
|
|
noisy_latent=noisy_latent,
|
|
|
|
conditioning,
|
|
|
|
neutral_conditioning=neutral_conditioning,
|
|
|
|
ts,
|
|
|
|
positive_conditioning=positive_conditioning,
|
|
|
|
|
|
|
|
guidance_scale=guidance_scale,
|
|
|
|
|
|
|
|
time_encoding=ts,
|
|
|
|
schedule=schedule,
|
|
|
|
schedule=schedule,
|
|
|
|
index=index,
|
|
|
|
index=index,
|
|
|
|
quantize_denoised=quantize_denoised,
|
|
|
|
quantize_denoised=quantize_denoised,
|
|
|
|
temperature=temperature,
|
|
|
|
temperature=temperature,
|
|
|
|
noise_dropout=noise_dropout,
|
|
|
|
noise_dropout=noise_dropout,
|
|
|
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
|
|
|
|
|
|
|
unconditional_conditioning=unconditional_conditioning,
|
|
|
|
|
|
|
|
old_eps=old_eps,
|
|
|
|
old_eps=old_eps,
|
|
|
|
t_next=ts_next,
|
|
|
|
t_next=ts_next,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
old_eps.append(noise_prediction)
|
|
|
|
old_eps.append(noise_pred)
|
|
|
|
if len(old_eps) >= 4:
|
|
|
|
if len(old_eps) >= 4:
|
|
|
|
old_eps.pop(0)
|
|
|
|
old_eps.pop(0)
|
|
|
|
if callback:
|
|
|
|
|
|
|
|
callback(i)
|
|
|
|
|
|
|
|
if img_callback:
|
|
|
|
|
|
|
|
img_callback(img, "img")
|
|
|
|
|
|
|
|
img_callback(pred_x0, "pred_x0")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return img
|
|
|
|
log_latent(noisy_latent, "noisy_latent")
|
|
|
|
|
|
|
|
log_latent(predicted_latent, "predicted_latent")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return noisy_latent
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
@torch.no_grad()
|
|
|
|
def p_sample_plms(
|
|
|
|
def p_sample_plms(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
noisy_latent,
|
|
|
|
noisy_latent,
|
|
|
|
|
|
|
|
neutral_conditioning,
|
|
|
|
positive_conditioning,
|
|
|
|
positive_conditioning,
|
|
|
|
|
|
|
|
guidance_scale,
|
|
|
|
time_encoding,
|
|
|
|
time_encoding,
|
|
|
|
schedule: PLMSSchedule,
|
|
|
|
schedule: PLMSSchedule,
|
|
|
|
index,
|
|
|
|
index,
|
|
|
@ -180,20 +169,19 @@ class PLMSSampler:
|
|
|
|
quantize_denoised=False,
|
|
|
|
quantize_denoised=False,
|
|
|
|
temperature=1.0,
|
|
|
|
temperature=1.0,
|
|
|
|
noise_dropout=0.0,
|
|
|
|
noise_dropout=0.0,
|
|
|
|
unconditional_guidance_scale=1.0,
|
|
|
|
|
|
|
|
unconditional_conditioning=None,
|
|
|
|
|
|
|
|
old_eps=None,
|
|
|
|
old_eps=None,
|
|
|
|
t_next=None,
|
|
|
|
t_next=None,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
batch_size = noisy_latent.shape[0]
|
|
|
|
assert guidance_scale >= 1
|
|
|
|
noise_prediction = get_noise_prediction(
|
|
|
|
noise_pred = get_noise_prediction(
|
|
|
|
denoise_func=self.model.apply_model,
|
|
|
|
denoise_func=self.model.apply_model,
|
|
|
|
noisy_latent=noisy_latent,
|
|
|
|
noisy_latent=noisy_latent,
|
|
|
|
time_encoding=time_encoding,
|
|
|
|
time_encoding=time_encoding,
|
|
|
|
neutral_conditioning=unconditional_conditioning,
|
|
|
|
neutral_conditioning=neutral_conditioning,
|
|
|
|
positive_conditioning=positive_conditioning,
|
|
|
|
positive_conditioning=positive_conditioning,
|
|
|
|
signal_amplification=unconditional_guidance_scale,
|
|
|
|
signal_amplification=guidance_scale,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
batch_size = noisy_latent.shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
def get_x_prev_and_pred_x0(e_t, index):
|
|
|
|
def get_x_prev_and_pred_x0(e_t, index):
|
|
|
|
# select parameters corresponding to the currently considered timestep
|
|
|
|
# select parameters corresponding to the currently considered timestep
|
|
|
@ -232,38 +220,33 @@ class PLMSSampler:
|
|
|
|
|
|
|
|
|
|
|
|
if len(old_eps) == 0:
|
|
|
|
if len(old_eps) == 0:
|
|
|
|
# Pseudo Improved Euler (2nd order)
|
|
|
|
# Pseudo Improved Euler (2nd order)
|
|
|
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(noise_prediction, index)
|
|
|
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(noise_pred, index)
|
|
|
|
e_t_next = get_noise_prediction(
|
|
|
|
e_t_next = get_noise_prediction(
|
|
|
|
denoise_func=self.model.apply_model,
|
|
|
|
denoise_func=self.model.apply_model,
|
|
|
|
noisy_latent=x_prev,
|
|
|
|
noisy_latent=x_prev,
|
|
|
|
time_encoding=t_next,
|
|
|
|
time_encoding=t_next,
|
|
|
|
neutral_conditioning=unconditional_conditioning,
|
|
|
|
neutral_conditioning=neutral_conditioning,
|
|
|
|
positive_conditioning=positive_conditioning,
|
|
|
|
positive_conditioning=positive_conditioning,
|
|
|
|
signal_amplification=unconditional_guidance_scale,
|
|
|
|
signal_amplification=guidance_scale,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
e_t_prime = (noise_prediction + e_t_next) / 2
|
|
|
|
e_t_prime = (noise_pred + e_t_next) / 2
|
|
|
|
elif len(old_eps) == 1:
|
|
|
|
elif len(old_eps) == 1:
|
|
|
|
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
|
|
|
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
|
|
|
e_t_prime = (3 * noise_prediction - old_eps[-1]) / 2
|
|
|
|
e_t_prime = (3 * noise_pred - old_eps[-1]) / 2
|
|
|
|
elif len(old_eps) == 2:
|
|
|
|
elif len(old_eps) == 2:
|
|
|
|
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
|
|
|
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
|
|
|
e_t_prime = (
|
|
|
|
e_t_prime = (23 * noise_pred - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
|
|
|
23 * noise_prediction - 16 * old_eps[-1] + 5 * old_eps[-2]
|
|
|
|
|
|
|
|
) / 12
|
|
|
|
|
|
|
|
elif len(old_eps) >= 3:
|
|
|
|
elif len(old_eps) >= 3:
|
|
|
|
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
|
|
|
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
|
|
|
e_t_prime = (
|
|
|
|
e_t_prime = (
|
|
|
|
55 * noise_prediction
|
|
|
|
55 * noise_pred - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
|
|
|
|
- 59 * old_eps[-1]
|
|
|
|
|
|
|
|
+ 37 * old_eps[-2]
|
|
|
|
|
|
|
|
- 9 * old_eps[-3]
|
|
|
|
|
|
|
|
) / 24
|
|
|
|
) / 24
|
|
|
|
|
|
|
|
|
|
|
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
|
|
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
|
|
|
log_latent(x_prev, "x_prev")
|
|
|
|
log_latent(x_prev, "x_prev")
|
|
|
|
log_latent(pred_x0, "pred_x0")
|
|
|
|
log_latent(pred_x0, "pred_x0")
|
|
|
|
|
|
|
|
|
|
|
|
return x_prev, pred_x0, noise_prediction
|
|
|
|
return x_prev, pred_x0, noise_pred
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
@torch.no_grad()
|
|
|
|
def noise_an_image(self, init_latent, t, schedule, noise=None):
|
|
|
|
def noise_an_image(self, init_latent, t, schedule, noise=None):
|
|
|
@ -285,25 +268,22 @@ class PLMSSampler:
|
|
|
|
@torch.no_grad()
|
|
|
|
@torch.no_grad()
|
|
|
|
def decode(
|
|
|
|
def decode(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
cond,
|
|
|
|
neutral_conditioning,
|
|
|
|
|
|
|
|
positive_conditioning,
|
|
|
|
|
|
|
|
guidance_scale,
|
|
|
|
schedule,
|
|
|
|
schedule,
|
|
|
|
initial_latent=None,
|
|
|
|
initial_latent=None,
|
|
|
|
t_start=None,
|
|
|
|
t_start=None,
|
|
|
|
unconditional_guidance_scale=1.0,
|
|
|
|
|
|
|
|
unconditional_conditioning=None,
|
|
|
|
|
|
|
|
img_callback=None,
|
|
|
|
|
|
|
|
temperature=1.0,
|
|
|
|
temperature=1.0,
|
|
|
|
mask=None,
|
|
|
|
mask=None,
|
|
|
|
orig_latent=None,
|
|
|
|
orig_latent=None,
|
|
|
|
noise=None,
|
|
|
|
noise=None,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
device = self.device
|
|
|
|
|
|
|
|
timesteps = schedule.ddim_timesteps[:t_start]
|
|
|
|
timesteps = schedule.ddim_timesteps[:t_start]
|
|
|
|
|
|
|
|
|
|
|
|
time_range = np.flip(timesteps)
|
|
|
|
time_range = np.flip(timesteps)
|
|
|
|
total_steps = timesteps.shape[0]
|
|
|
|
total_steps = timesteps.shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
iterator = tqdm(time_range, desc="PLMS img2img", total=total_steps)
|
|
|
|
|
|
|
|
x_dec = initial_latent
|
|
|
|
x_dec = initial_latent
|
|
|
|
old_eps = []
|
|
|
|
old_eps = []
|
|
|
|
log_latent(x_dec, "x_dec")
|
|
|
|
log_latent(x_dec, "x_dec")
|
|
|
@ -315,7 +295,7 @@ class PLMSSampler:
|
|
|
|
if noise is None
|
|
|
|
if noise is None
|
|
|
|
else noise
|
|
|
|
else noise
|
|
|
|
)
|
|
|
|
)
|
|
|
|
for i, step in enumerate(iterator):
|
|
|
|
for i, step in enumerate(tqdm(time_range, total=total_steps)):
|
|
|
|
index = total_steps - i - 1
|
|
|
|
index = total_steps - i - 1
|
|
|
|
ts = torch.full(
|
|
|
|
ts = torch.full(
|
|
|
|
(initial_latent.shape[0],),
|
|
|
|
(initial_latent.shape[0],),
|
|
|
@ -326,7 +306,7 @@ class PLMSSampler:
|
|
|
|
ts_next = torch.full(
|
|
|
|
ts_next = torch.full(
|
|
|
|
(initial_latent.shape[0],),
|
|
|
|
(initial_latent.shape[0],),
|
|
|
|
time_range[min(i + 1, len(time_range) - 1)],
|
|
|
|
time_range[min(i + 1, len(time_range) - 1)],
|
|
|
|
device=device,
|
|
|
|
device=self.device,
|
|
|
|
dtype=torch.long,
|
|
|
|
dtype=torch.long,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
@ -346,13 +326,13 @@ class PLMSSampler:
|
|
|
|
log_latent(x_dec, f"x_dec {ts}")
|
|
|
|
log_latent(x_dec, f"x_dec {ts}")
|
|
|
|
|
|
|
|
|
|
|
|
x_dec, pred_x0, noise_prediction = self.p_sample_plms(
|
|
|
|
x_dec, pred_x0, noise_prediction = self.p_sample_plms(
|
|
|
|
x_dec,
|
|
|
|
noisy_latent=x_dec,
|
|
|
|
cond,
|
|
|
|
guidance_scale=guidance_scale,
|
|
|
|
ts,
|
|
|
|
neutral_conditioning=neutral_conditioning,
|
|
|
|
|
|
|
|
positive_conditioning=positive_conditioning,
|
|
|
|
|
|
|
|
time_encoding=ts,
|
|
|
|
schedule=schedule,
|
|
|
|
schedule=schedule,
|
|
|
|
index=index,
|
|
|
|
index=index,
|
|
|
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
|
|
|
|
|
|
|
unconditional_conditioning=unconditional_conditioning,
|
|
|
|
|
|
|
|
temperature=temperature,
|
|
|
|
temperature=temperature,
|
|
|
|
old_eps=old_eps,
|
|
|
|
old_eps=old_eps,
|
|
|
|
t_next=ts_next,
|
|
|
|
t_next=ts_next,
|
|
|
@ -362,11 +342,6 @@ class PLMSSampler:
|
|
|
|
if len(old_eps) >= 4:
|
|
|
|
if len(old_eps) >= 4:
|
|
|
|
old_eps.pop(0)
|
|
|
|
old_eps.pop(0)
|
|
|
|
|
|
|
|
|
|
|
|
if img_callback:
|
|
|
|
|
|
|
|
img_callback(x_dec, "x_dec")
|
|
|
|
|
|
|
|
img_callback(pred_x0, "pred_x0")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log_latent(x_dec, f"x_dec {i}")
|
|
|
|
log_latent(x_dec, f"x_dec {i}")
|
|
|
|
log_latent(x_dec, f"e_t {i}")
|
|
|
|
|
|
|
|
log_latent(pred_x0, f"pred_x0 {i}")
|
|
|
|
log_latent(pred_x0, f"pred_x0 {i}")
|
|
|
|
return x_dec
|
|
|
|
return x_dec
|
|
|
|