fix: img2img was broken for all samplers except ddim,plms

img2img was broken for all samplers except PLMS and DDIM when the init image strength was greater than ~0.25. It has been this way for a while. Whoops.
pull/279/head
Bryce 1 year ago committed by Bryce Drennan
parent 28c294554b
commit 70c58467c0

@ -380,6 +380,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog
- fix: img2img was broken for all samplers except plms and ddim when init image strength was >~0.25
**10.2.0**
- feature: input raw control images (a pose, canny map, depth map, etc) directly using `--control-image-raw`
This is as opposed to the current behavior of extracting the control signal from an input image via `--control-image`

@ -224,7 +224,6 @@ def _generate_single_image(
from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint
from imaginairy.safety import create_safety_score
from imaginairy.samplers import SAMPLER_LOOKUP
from imaginairy.samplers.base import NoiseSchedule, noise_an_image
from imaginairy.samplers.editing import CFGEditingDenoiser
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import get_device, randn_seeded
@ -290,7 +289,7 @@ def _generate_single_image(
SamplerCls = SAMPLER_LOOKUP[prompt.sampler_type.lower()]
sampler = SamplerCls(model)
mask_latent = mask_image = mask_image_orig = mask_grayscale = None
t_enc = init_latent = init_latent_noised = control_image = None
t_enc = init_latent = control_image = None
starting_image = None
denoiser_cls = None
@ -303,8 +302,9 @@ def _generate_single_image(
if prompt.init_image:
starting_image = prompt.init_image
generation_strength = 1 - prompt.init_image_strength
if model.cond_stage_key == "edit":
t_enc = prompt.steps
if model.cond_stage_key == "edit" or generation_strength >= 1:
t_enc = None
else:
t_enc = int(prompt.steps * generation_strength)
@ -359,24 +359,24 @@ def _generate_single_image(
)
# noise = noise[:, :, : init_latent.shape[2], : init_latent.shape[3]]
schedule = NoiseSchedule(
model_num_timesteps=model.num_timesteps,
ddim_num_steps=prompt.steps,
model_alphas_cumprod=model.alphas_cumprod,
ddim_discretize="uniform",
)
if generation_strength >= 1:
# prompt strength gets converted to time encodings,
# which means you can't get to true 0 without this hack
# (or setting steps=1000)
init_latent_noised = noise
else:
init_latent_noised = noise_an_image(
init_latent,
torch.tensor([t_enc - 1]).to(get_device()),
schedule=schedule,
noise=noise,
)
# schedule = NoiseSchedule(
# model_num_timesteps=model.num_timesteps,
# ddim_num_steps=prompt.steps,
# model_alphas_cumprod=model.alphas_cumprod,
# ddim_discretize="uniform",
# )
# if generation_strength >= 1:
# # prompt strength gets converted to time encodings,
# # which means you can't get to true 0 without this hack
# # (or setting steps=1000)
# init_latent_noised = noise
# else:
# init_latent_noised = noise_an_image(
# init_latent,
# torch.tensor([t_enc - 1]).to(get_device()),
# schedule=schedule,
# noise=noise,
# )
if hasattr(model, "depth_stage_key"):
# depth model
@ -472,7 +472,6 @@ def _generate_single_image(
"c_concat": c_cat_neutral,
"c_crossattn": [neutral_conditioning],
}
log_latent(init_latent_noised, "init_latent_noised")
if (
prompt.allow_compose_phase
@ -483,7 +482,7 @@ def _generate_single_image(
sampler=sampler,
sampler_kwargs={
"num_steps": prompt.steps,
"initial_latent": init_latent_noised,
"noise": noise,
"positive_conditioning": positive_conditioning,
"neutral_conditioning": neutral_conditioning,
"guidance_scale": prompt.prompt_strength,
@ -498,26 +497,12 @@ def _generate_single_image(
if comp_samples is not None:
result_images["composition"] = comp_samples
noise = noise[:, :, : comp_samples.shape[2], : comp_samples.shape[3]]
schedule = NoiseSchedule(
model_num_timesteps=model.num_timesteps,
ddim_num_steps=prompt.steps,
model_alphas_cumprod=model.alphas_cumprod,
ddim_discretize="uniform",
)
t_enc = int(prompt.steps * 0.75)
init_latent_noised = noise_an_image(
comp_samples,
torch.tensor([t_enc - 1]).to(get_device()),
schedule=schedule,
noise=noise,
)
log_latent(comp_samples, "comp_samples")
init_latent = comp_samples
with lc.timing("sampling"):
samples = sampler.sample(
num_steps=prompt.steps,
initial_latent=init_latent_noised,
positive_conditioning=positive_conditioning,
neutral_conditioning=neutral_conditioning,
guidance_scale=prompt.prompt_strength,
@ -527,6 +512,7 @@ def _generate_single_image(
shape=shape,
batch_size=1,
denoiser_cls=denoiser_cls,
noise=noise,
)
if return_latent:
return samples
@ -636,7 +622,7 @@ def _generate_composition_latent(
from imaginairy.enhancers.upscale_riverwing import upscale_latent
from imaginairy.log_utils import log_img, log_latent
b, c, h, w = orig_shape = sampler_kwargs["shape"]
b, c, h, w = sampler_kwargs["shape"]
max_compose_gen_size = 768
shrink_scale = calc_scale_to_fit_within(
height=h,
@ -650,9 +636,9 @@ def _generate_composition_latent(
# shrink everything
new_shape = b, c, int(round(h * shrink_scale)), int(round(w * shrink_scale))
initial_latent = new_kwargs["initial_latent"]
if initial_latent is not None:
initial_latent = F.interpolate(initial_latent, size=new_shape[2:], mode="area")
noise = new_kwargs["noise"]
if noise is not None:
noise = F.interpolate(noise, size=new_shape[2:], mode="nearest-exact")
for cond in [
new_kwargs["positive_conditioning"],
@ -676,7 +662,7 @@ def _generate_composition_latent(
new_kwargs.update(
{
"num_steps": 15,
"initial_latent": initial_latent,
"noise": noise,
"t_start": t_start,
"mask": mask_latent,
"orig_latent": orig_latent,
@ -685,6 +671,8 @@ def _generate_composition_latent(
)
samples = sampler.sample(**new_kwargs)
# while samples.shape[2] < h:
logger.info("Upscaling latent...")
samples = upscale_latent(samples)
log_latent(samples, "upscaled")
img_t = sampler.model.decode_first_stage(samples)

@ -87,3 +87,7 @@ def model_list_cmd():
print(
f"{control_mode.alias: <10} {control_mode.short_name: <18} {control_mode.control_type}"
)
if __name__ == "__main__":
aimg() # noqa

@ -43,18 +43,11 @@ class DDIMSampler(ImageSampler):
orig_latent=None,
temperature=1.0,
noise_dropout=0.0,
initial_latent=None,
noise=None,
t_start=None,
quantize_x0=False,
**kwargs,
):
# print("Sampling with DDIM")
# print("num_steps", num_steps)
# print("shape", shape)
# print("neutral_conditioning", neutral_conditioning)
# print("positive_conditioning", positive_conditioning)
# print("guidance_scale", guidance_scale)
# print("batch_size", batch_size)
schedule = NoiseSchedule(
model_num_timesteps=self.model.num_timesteps,
model_alphas_cumprod=self.model.alphas_cumprod,
@ -62,16 +55,22 @@ class DDIMSampler(ImageSampler):
ddim_discretize="uniform",
)
if initial_latent is None:
initial_latent = torch.randn(shape, device="cpu").to(self.device)
if noise is None:
noise = torch.randn(shape, device="cpu").to(self.device)
log_latent(initial_latent, "initial latent")
log_latent(noise, "initial noise")
timesteps = schedule.ddim_timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
noisy_latent = initial_latent
if orig_latent is not None:
noisy_latent = self.noise_an_image(
init_latent=orig_latent, t=t_start, schedule=schedule, noise=noise
)
else:
noisy_latent = noise
mask_noise = None
if mask is not None:
@ -219,7 +218,8 @@ class DDIMSampler(ImageSampler):
@torch.no_grad()
def noise_an_image(self, init_latent, t, schedule, noise=None):
# t serves as an index to gather the correct alphas
if isinstance(t, int):
t = torch.tensor([t], device=get_device())
t = t.clamp(0, 1000)
sqrt_alphas_cumprod = torch.sqrt(schedule.ddim_alphas)
sqrt_one_minus_alphas_cumprod = schedule.ddim_sqrt_one_minus_alphas

@ -79,7 +79,7 @@ class KDiffusionSampler(ImageSampler, ABC):
batch_size=1,
mask=None,
orig_latent=None,
initial_latent=None,
noise=None,
t_start=None,
denoiser_cls=None,
):
@ -88,22 +88,27 @@ class KDiffusionSampler(ImageSampler, ABC):
# f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
# )
if initial_latent is None:
initial_latent = torch.randn(shape, device="cpu").to(self.device)
if noise is None:
noise = torch.randn(shape, device="cpu").to(self.device)
log_latent(initial_latent, "initial_latent")
log_latent(noise, "initial noise")
if t_start is not None:
t_start = num_steps - t_start + 1
sigmas = self.cv_denoiser.get_sigmas(num_steps)[t_start:]
# if our number of steps is zero, just return the initial latent
if sigmas.nelement() == 0:
if orig_latent is not None:
return orig_latent
return initial_latent
return noise
if orig_latent is not None and t_start is not None:
noisy_latent = noise * sigmas[0] + orig_latent
else:
noisy_latent = noise * sigmas[0]
x = noisy_latent
x = initial_latent * sigmas[0]
log_latent(x, "initial_sigma_noised_tensor")
if denoiser_cls is None:
denoiser_cls = CFGDenoiser
@ -111,9 +116,7 @@ class KDiffusionSampler(ImageSampler, ABC):
mask_noise = None
if mask is not None:
mask_noise = torch.randn_like(initial_latent, device="cpu").to(
initial_latent.device
)
mask_noise = torch.randn_like(x, device="cpu").to(x.device)
def callback(data):
log_latent(data["x"], "noisy_latent")
@ -138,6 +141,12 @@ class KDiffusionSampler(ImageSampler, ABC):
return samples
@torch.no_grad()
def noise_an_image(self, init_latent, t, sigmas, noise=None):
    # Normalizes the timestep index `t`: a plain int is wrapped into a
    # one-element tensor on the device returned by get_device(), and the
    # result is clamped to the range [0, 1000].
    # NOTE(review): no noising step or return statement is visible in this
    # block, so as shown the method returns None and never reads
    # `init_latent`, `sigmas` or `noise` -- confirm whether the body is
    # unfinished or intentionally a stub.
    if isinstance(t, int):
        t = torch.tensor([t], device=get_device())
    t = t.clamp(0, 1000)
class DPMFastSampler(KDiffusionSampler):
short_name = SamplerName.K_DPM_FAST

@ -46,7 +46,7 @@ class PLMSSampler(ImageSampler):
orig_latent=None,
temperature=1.0,
noise_dropout=0.0,
initial_latent=None,
noise=None,
t_start=None,
quantize_denoised=False,
**kwargs,
@ -63,10 +63,10 @@ class PLMSSampler(ImageSampler):
ddim_discretize="uniform",
)
if initial_latent is None:
initial_latent = torch.randn(shape, device="cpu").to(self.device)
if noise is None:
noise = torch.randn(shape, device="cpu").to(self.device)
log_latent(initial_latent, "initial latent")
log_latent(noise, "initial noise")
timesteps = schedule.ddim_timesteps[:t_start]
@ -74,7 +74,14 @@ class PLMSSampler(ImageSampler):
total_steps = timesteps.shape[0]
old_eps = []
noisy_latent = initial_latent
if orig_latent is not None:
noisy_latent = self.noise_an_image(
init_latent=orig_latent, t=t_start, schedule=schedule, noise=noise
)
else:
noisy_latent = noise
mask_noise = None
if mask is not None:
mask_noise = torch.randn_like(noisy_latent, device="cpu").to(
@ -220,8 +227,8 @@ class PLMSSampler(ImageSampler):
@torch.no_grad()
def noise_an_image(self, init_latent, t, schedule, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if isinstance(t, int):
t = torch.tensor([t], device=get_device())
t = t.clamp(0, 1000)
sqrt_alphas_cumprod = torch.sqrt(schedule.ddim_alphas)
sqrt_one_minus_alphas_cumprod = schedule.ddim_sqrt_one_minus_alphas

Binary file not shown.

After

Width:  |  Height:  |  Size: 248 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 247 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 253 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 248 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 243 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 245 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 241 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 245 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 245 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 247 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 247 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 251 KiB

@ -129,6 +129,40 @@ def test_img_to_img_from_url_cats(
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=17000)
def test_img2img_low_noise(
    filename_base_for_outputs,
    sampler_type,
):
    """Compare a masked img2img render against its stored reference image.

    Generates "a white bowl filled with gold coins" from the bowl-of-fruit
    test image with ``init_image_strength=0.5`` and a "replace" mask, then
    checks the result against ``<filename_base_for_outputs>.png``.

    ``sampler_type`` and ``filename_base_for_outputs`` are presumably supplied
    by pytest fixtures/parametrization defined elsewhere -- verify in conftest.
    """
    # Samplers that get their own tolerance; everything else uses the default.
    per_sampler_threshold = {
        "k_dpm_2_a": 26000,
        "k_euler_a": 18000,
        "k_dpm_adaptive": 13000,
    }
    default_threshold = 14000

    source_image = LazyLoadingImage(
        filepath=os.path.join(TESTS_FOLDER, "data", "bowl_of_fruit.jpg")
    )
    # `steps` is not passed, so ImaginePrompt's own default applies.
    prompt = ImaginePrompt(
        "a white bowl filled with gold coins",
        prompt_strength=12,
        init_image=source_image,
        init_image_strength=0.5,
        mask_prompt="(fruit{*2} OR stem{*10} OR fruit stem{*3})",
        mask_mode="replace",
        seed=1,
        sampler_type=sampler_type,
    )

    # Only the first generated result is checked.
    result = next(imagine(prompt))

    if sampler_type in per_sampler_threshold:
        allowed_difference = per_sampler_threshold[sampler_type]
    else:
        allowed_difference = default_threshold

    assert_image_similar_to_expectation(
        result.img,
        img_path=f"{filename_base_for_outputs}.png",
        threshold=allowed_difference,
    )
@pytest.mark.parametrize("init_strength", [0, 0.05, 0.2, 1])
def test_img_to_img_fruit_2_gold(
filename_base_for_outputs,

Loading…
Cancel
Save