fix: k-diff samplers made more stable by skipping second to last step

https://github.com/crowsonkb/k-diffusion/issues/43#issuecomment-1305195666

See discussion here as well: https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/3483
pull/279/head
Bryce 1 year ago committed by Bryce Drennan
parent 70c58467c0
commit dbc6a249a6

@ -65,9 +65,10 @@ class DDIMSampler(ImageSampler):
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
if orig_latent is not None:
# t_start is none if init image strength set to 0
if orig_latent is not None and t_start is not None:
noisy_latent = self.noise_an_image(
init_latent=orig_latent, t=t_start, schedule=schedule, noise=noise
init_latent=orig_latent, t=t_start - 1, schedule=schedule, noise=noise
)
else:
noisy_latent = noise

@ -14,6 +14,7 @@ from imaginairy.samplers.base import (
from imaginairy.utils import get_device
from imaginairy.vendored.k_diffusion import sampling as k_sampling
from imaginairy.vendored.k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from imaginairy.vendored.k_diffusion.sampling import get_sigmas_karras
class StandardCompVisDenoiser(CompVisDenoiser):
@ -96,12 +97,17 @@ class KDiffusionSampler(ImageSampler, ABC):
t_start = num_steps - t_start + 1
sigmas = self.cv_denoiser.get_sigmas(num_steps)[t_start:]
# see https://github.com/crowsonkb/k-diffusion/issues/43#issuecomment-1305195666
if self.short_name in (SamplerName.K_DPM_2, SamplerName.K_DPMPP_2M, SamplerName.K_DPM_2_ANCESTRAL):
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
# if our number of steps is zero, just return the initial latent
if sigmas.nelement() == 0:
if orig_latent is not None:
return orig_latent
return noise
# t_start is none if init image strength set to 0
if orig_latent is not None and t_start is not None:
noisy_latent = noise * sigmas[0] + orig_latent
else:
@ -141,12 +147,6 @@ class KDiffusionSampler(ImageSampler, ABC):
return samples
@torch.no_grad()
def noise_an_image(self, init_latent, t, sigmas, noise=None):
if isinstance(t, int):
t = torch.tensor([t], device=get_device())
t = t.clamp(0, 1000)
class DPMFastSampler(KDiffusionSampler):
short_name = SamplerName.K_DPM_FAST

@ -75,9 +75,10 @@ class PLMSSampler(ImageSampler):
old_eps = []
if orig_latent is not None:
# t_start is none if init image strength set to 0
if orig_latent is not None and t_start is not None:
noisy_latent = self.noise_an_image(
init_latent=orig_latent, t=t_start, schedule=schedule, noise=noise
init_latent=orig_latent, t=t_start - 1, schedule=schedule, noise=noise
)
else:
noisy_latent = noise

Loading…
Cancel
Save