refactor: merge img2img and txt2img pipelines

pull/60/head
Authored by Bryce 2 years ago; committed by Bryce Drennan
parent a105dadbc4
commit 72026c8c90

@@ -186,20 +186,20 @@ def imagine(
         seed_everything(prompt.seed)
         model.tile_mode(prompt.tile_mode)
-        uc = None
+        neutral_conditioning = None
         if prompt.prompt_strength != 1.0:
-            uc = model.get_learned_conditioning(1 * [""])
-            log_conditioning(uc, "neutral conditioning")
+            neutral_conditioning = model.get_learned_conditioning(1 * [""])
+            log_conditioning(neutral_conditioning, "neutral conditioning")
         if prompt.conditioning is not None:
-            c = prompt.conditioning
+            positive_conditioning = prompt.conditioning
         else:
             total_weight = sum(wp.weight for wp in prompt.prompts)
-            c = sum(
+            positive_conditioning = sum(
                 model.get_learned_conditioning(wp.text)
                 * (wp.weight / total_weight)
                 for wp in prompt.prompts
             )
-        log_conditioning(c, "positive conditioning")
+        log_conditioning(positive_conditioning, "positive conditioning")

         shape = [
             1,
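
Aside: the renaming above makes the weighted-prompt blending easier to follow. When several weighted prompts are supplied, their embeddings are averaged with weights normalized to sum to 1. A minimal runnable sketch of that blending, using a hypothetical WeightedPrompt stand-in and a dummy embedder in place of model.get_learned_conditioning:

from dataclasses import dataclass

import torch

@dataclass
class WeightedPrompt:  # hypothetical stand-in for the library's weighted prompt
    text: str
    weight: float

def blend_conditioning(prompts, embed):
    # normalize weights so the blend is a true weighted average,
    # mirroring the `wp.weight / total_weight` logic in the hunk above
    total_weight = sum(wp.weight for wp in prompts)
    return sum(embed(wp.text) * (wp.weight / total_weight) for wp in prompts)

def embed(text):  # dummy embedder standing in for model.get_learned_conditioning
    return torch.full((1, 77, 768), float(len(text)))

cond = blend_conditioning(
    [WeightedPrompt("a cat", 2.0), WeightedPrompt("a dog", 1.0)], embed
)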
@@ -209,7 +209,7 @@ def imagine(
         ]
         if prompt.init_image and prompt.sampler_type not in ("ddim", "plms"):
             sampler_type = "plms"
-            logger.info(" Sampler type switched to plms for img2img")
+            logger.info("Sampler type switched to plms for img2img")
         else:
             sampler_type = prompt.sampler_type
@@ -287,36 +287,36 @@ def imagine(
                 # prompt strength gets converted to time encodings,
                 # which means you can't get to true 0 without this hack
                 # (or setting steps=1000)
-                z_enc = noise
+                init_latent_noised = noise
             else:
-                z_enc = sampler.noise_an_image(
+                init_latent_noised = sampler.noise_an_image(
                     init_latent,
                     torch.tensor([t_enc - 1]).to(get_device()),
                     schedule=schedule,
                     noise=noise,
                 )
-            log_latent(z_enc, "z_enc")
+            log_latent(init_latent_noised, "init_latent_noised")
-            # decode it
-            samples = sampler.decode(
-                initial_latent=z_enc,
-                positive_conditioning=c,
-                t_start=t_enc,
-                schedule=schedule,
+            samples = sampler.sample(
+                num_steps=prompt.steps,
+                initial_latent=init_latent_noised,
+                positive_conditioning=positive_conditioning,
+                neutral_conditioning=neutral_conditioning,
                 guidance_scale=prompt.prompt_strength,
-                neutral_conditioning=uc,
+                t_start=t_enc,
                 mask=mask,
                 orig_latent=init_latent,
+                shape=shape,
+                batch_size=1,
             )
         else:
             samples = sampler.sample(
                 num_steps=prompt.steps,
-                positive_conditioning=c,
+                neutral_conditioning=neutral_conditioning,
+                positive_conditioning=positive_conditioning,
+                guidance_scale=prompt.prompt_strength,
                 batch_size=1,
                 shape=shape,
-                guidance_scale=prompt.prompt_strength,
-                neutral_conditioning=uc,
             )
         x_samples = model.decode_first_stage(samples)

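The hunk above is the heart of the merge: sampler.decode() is gone, and img2img now flows through the same sampler.sample() call as txt2img, differing only in its starting latent and t_start. A hedged, self-contained sketch of the unified control flow (FakeSampler is a toy, not the library's class; strength means init-image strength in 0..1):

class FakeSampler:
    def noise_an_image(self, init_latent, t, noise=None):
        # forward-noise the init latent up to (t + 1) steps of the schedule
        return f"{init_latent} noised to step {t + 1}"

    def sample(self, num_steps, initial_latent=None, t_start=None):
        # t_start=None runs the whole schedule; otherwise only the tail of it
        start = num_steps if t_start is None else t_start
        source = initial_latent or "pure noise"
        return f"denoise {source} for {start} of {num_steps} steps"

def generate(sampler, steps, strength=1.0, init_latent=None):
    if init_latent is None:  # txt2img: full schedule from pure noise
        return sampler.sample(num_steps=steps)
    t_enc = int(steps * strength)  # img2img: re-run only part of the schedule
    noised = sampler.noise_an_image(init_latent, t_enc - 1)
    return sampler.sample(num_steps=steps, initial_latent=noised, t_start=t_enc)

print(generate(FakeSampler(), 40))                      # txt2img path
print(generate(FakeSampler(), 40, 0.6, "init latent"))  # img2img path
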
@@ -38,6 +38,7 @@ class DDIMSampler:
         temperature=1.0,
         noise_dropout=0.0,
         initial_latent=None,
+        t_start=None,
         quantize_x0=False,
     ):
         if positive_conditioning.shape[0] != batch_size:
@@ -57,7 +58,7 @@ class DDIMSampler:
         log_latent(initial_latent, "initial latent")

-        timesteps = schedule.ddim_timesteps
+        timesteps = schedule.ddim_timesteps[:t_start]
         time_range = np.flip(timesteps)
         total_steps = timesteps.shape[0]
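
The [:t_start] slice is what lets one sampling loop serve both pipelines: Python slicing with None is a no-op, so txt2img (t_start=None) keeps the full schedule while img2img truncates it to the first t_enc steps before the flip. A quick runnable check with a toy schedule:

import numpy as np

ddim_timesteps = np.arange(1, 1000, 100)  # toy 10-step schedule, not the model's
assert len(ddim_timesteps[:None]) == 10   # t_start=None: full schedule (txt2img)
assert len(ddim_timesteps[:7]) == 7       # t_start=7: partial schedule (img2img)
# np.flip then runs the kept steps from most-noised to least-noised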
@@ -69,8 +70,18 @@ class DDIMSampler:
             if mask is not None:
                 assert orig_latent is not None
-                img_orig = self.model.q_sample(orig_latent, ts)
-                noisy_latent = img_orig * mask + (1.0 - mask) * noisy_latent
+                xdec_orig = self.model.q_sample(orig_latent, ts)
+                log_latent(xdec_orig, "xdec_orig")
+                # this helps prevent the weird disjointed images that can happen with masking
+                hint_strength = 0.8
+                if i < 2:
+                    xdec_orig_with_hints = (
+                        xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
+                    )
+                else:
+                    xdec_orig_with_hints = xdec_orig
+                noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
+                log_latent(noisy_latent, "noisy_latent")

             noisy_latent, predicted_latent = self.p_sample_ddim(
                 noisy_latent=noisy_latent,
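
The new masking block above does more than paste the noised original back outside the mask: for the first two steps it blends 80% of the un-noised original latent into the preserved region. A self-contained sketch of the same blend (hint_strength and the 2-step cutoff are the values from the diff):

import torch

def masked_blend(noisy_latent, orig_latent, noised_orig, mask, i, hint_strength=0.8):
    # early steps: anchor the preserved region to the clean original latent,
    # which the diff's comment credits with preventing disjointed results
    if i < 2:
        hinted = noised_orig * (1 - hint_strength) + orig_latent * hint_strength
    else:
        hinted = noised_orig
    # mask == 1 keeps (hinted) original content; mask == 0 keeps the sample
    return hinted * mask + (1.0 - mask) * noisy_latent

latent = torch.randn(1, 4, 8, 8)
out = masked_blend(latent, torch.randn_like(latent), torch.randn_like(latent),
                   (torch.rand_like(latent) > 0.5).float(), i=0)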
@@ -190,63 +201,3 @@ class DDIMSampler:
             + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, init_latent.shape)
             * noise
         )
-
-    @torch.no_grad()
-    def decode(
-        self,
-        initial_latent,
-        neutral_conditioning,
-        positive_conditioning,
-        guidance_scale,
-        t_start,
-        schedule,
-        temperature=1.0,
-        mask=None,
-        orig_latent=None,
-    ):
-        timesteps = schedule.ddim_timesteps[:t_start]
-        time_range = np.flip(timesteps)
-        total_steps = timesteps.shape[0]
-
-        noisy_latent = initial_latent
-
-        for i, step in enumerate(tqdm(time_range, total=total_steps)):
-            index = total_steps - i - 1
-            ts = torch.full(
-                (initial_latent.shape[0],),
-                step,
-                device=initial_latent.device,
-                dtype=torch.long,
-            )
-
-            if mask is not None:
-                assert orig_latent is not None
-                xdec_orig = self.model.q_sample(orig_latent, ts)
-                log_latent(xdec_orig, "xdec_orig")
-                # this helps prevent the weird disjointed images that can happen with masking
-                hint_strength = 0.8
-                if i < 2:
-                    xdec_orig_with_hints = (
-                        xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
-                    )
-                else:
-                    xdec_orig_with_hints = xdec_orig
-                noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
-                log_latent(noisy_latent, "noisy_latent")
-
-            noisy_latent, predicted_latent = self.p_sample_ddim(
-                noisy_latent=noisy_latent,
-                positive_conditioning=positive_conditioning,
-                time_encoding=ts,
-                schedule=schedule,
-                index=index,
-                guidance_scale=guidance_scale,
-                neutral_conditioning=neutral_conditioning,
-                temperature=temperature,
-            )
-
-            log_latent(noisy_latent, f"noisy_latent {i}")
-            log_latent(predicted_latent, f"predicted_latent {i}")
-
-        return noisy_latent

@@ -41,6 +41,7 @@ class PLMSSampler:
         temperature=1.0,
         noise_dropout=0.0,
         initial_latent=None,
+        t_start=None,
         quantize_denoised=False,
         **kwargs,
     ):
@@ -61,13 +62,18 @@ class PLMSSampler:
         log_latent(initial_latent, "initial latent")

-        timesteps = schedule.ddim_timesteps
+        timesteps = schedule.ddim_timesteps[:t_start]

         time_range = np.flip(timesteps)
         total_steps = timesteps.shape[0]
         old_eps = []
         noisy_latent = initial_latent

+        mask_noise = None
+        if mask is not None:
+            mask_noise = torch.randn_like(noisy_latent, device="cpu").to(
+                noisy_latent.device
+            )

         for i, step in enumerate(tqdm(time_range, total=total_steps)):
             index = total_steps - i - 1
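
Note the new mask_noise: it is drawn once before the loop and reused for every q_sample call in the next hunk, so the preserved region is re-noised consistently from step to step instead of with fresh noise each iteration. Drawing on CPU and then moving to the device mirrors the diff and, presumably, keeps seeded results identical across backends. A minimal sketch:

import torch

def make_mask_noise(latent):
    # one fixed noise tensor, sampled on CPU then moved to the latent's device
    return torch.randn_like(latent, device="cpu").to(latent.device)

latent = torch.randn(1, 4, 8, 8)
mask_noise = make_mask_noise(latent)  # reused by q_sample at every step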
@@ -81,8 +87,18 @@ class PLMSSampler:
             if mask is not None:
                 assert orig_latent is not None
-                img_orig = self.model.q_sample(orig_latent, ts)
-                noisy_latent = img_orig * mask + (1.0 - mask) * noisy_latent
+                xdec_orig = self.model.q_sample(orig_latent, ts, mask_noise)
+                log_latent(xdec_orig, f"xdec_orig i={i} index-{index}")
+                # this helps prevent the weird disjointed images that can happen with masking
+                hint_strength = 0.8
+                if i < 2:
+                    xdec_orig_with_hints = (
+                        xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
+                    )
+                else:
+                    xdec_orig_with_hints = xdec_orig
+                noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
+                log_latent(noisy_latent, f"x_dec {ts}")

             noisy_latent, predicted_latent, noise_pred = self.p_sample_plms(
                 noisy_latent=noisy_latent,
@@ -202,7 +218,6 @@ class PLMSSampler:

     @torch.no_grad()
     def noise_an_image(self, init_latent, t, schedule, noise=None):
-        # replace with ddpm.q_sample?
         # fast, but does not allow for exact reconstruction
         # t serves as an index to gather the correct alphas
         t = t.clamp(0, 1000)
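
For reference, noise_an_image implements the closed-form forward diffusion jump x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps, which is why it is fast but not exactly invertible, as the comment says. A toy sketch with an illustrative schedule (not the model's real alphas):

import torch

def noise_to_timestep(x0, t, alphas_cumprod, noise=None):
    # x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps  (DDPM q_sample)
    noise = torch.randn_like(x0) if noise is None else noise
    a_bar = alphas_cumprod[t]
    return a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise

alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)  # toy schedule
x0 = torch.randn(1, 4, 64, 64)
x500 = noise_to_timestep(x0, torch.tensor(500), alphas_cumprod)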
@@ -216,84 +231,3 @@ class PLMSSampler:
             + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, init_latent.shape)
             * noise
         )
-
-    @torch.no_grad()
-    def decode(
-        self,
-        neutral_conditioning,
-        positive_conditioning,
-        guidance_scale,
-        schedule,
-        initial_latent=None,
-        t_start=None,
-        temperature=1.0,
-        mask=None,
-        orig_latent=None,
-        noise=None,
-    ):
-        timesteps = schedule.ddim_timesteps[:t_start]
-        time_range = np.flip(timesteps)
-        total_steps = timesteps.shape[0]
-
-        x_dec = initial_latent
-        old_eps = []
-        log_latent(x_dec, "x_dec")
-        # not sure what the downside of using the same noise throughout the process would be...
-        # seems to work fine. maybe it runs faster?
-        noise = (
-            torch.randn_like(x_dec, device="cpu").to(x_dec.device)
-            if noise is None
-            else noise
-        )
-        for i, step in enumerate(tqdm(time_range, total=total_steps)):
-            index = total_steps - i - 1
-            ts = torch.full(
-                (initial_latent.shape[0],),
-                step,
-                device=initial_latent.device,
-                dtype=torch.long,
-            )
-            ts_next = torch.full(
-                (initial_latent.shape[0],),
-                time_range[min(i + 1, len(time_range) - 1)],
-                device=self.device,
-                dtype=torch.long,
-            )
-
-            if mask is not None:
-                assert orig_latent is not None
-                xdec_orig = self.model.q_sample(orig_latent, ts, noise)
-                log_latent(xdec_orig, f"xdec_orig i={i} index-{index}")
-                # this helps prevent the weird disjointed images that can happen with masking
-                hint_strength = 0.8
-                if i < 2:
-                    xdec_orig_with_hints = (
-                        xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
-                    )
-                else:
-                    xdec_orig_with_hints = xdec_orig
-                x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
-                log_latent(x_dec, f"x_dec {ts}")
-
-            x_dec, pred_x0, noise_prediction = self.p_sample_plms(
-                noisy_latent=x_dec,
-                guidance_scale=guidance_scale,
-                neutral_conditioning=neutral_conditioning,
-                positive_conditioning=positive_conditioning,
-                time_encoding=ts,
-                schedule=schedule,
-                index=index,
-                temperature=temperature,
-                old_eps=old_eps,
-                t_next=ts_next,
-            )
-
-            old_eps.append(noise_prediction)
-            if len(old_eps) >= 4:
-                old_eps.pop(0)
-            log_latent(x_dec, f"x_dec {i}")
-            log_latent(pred_x0, f"pred_x0 {i}")
-        return x_dec
