From a46424c673ef1057849153f79e3a5ee4dcaacf2d Mon Sep 17 00:00:00 2001
From: Bryce
Date: Tue, 20 Sep 2022 08:42:00 -0700
Subject: [PATCH] feature: img2img now supported with PLMS (instead of just DDIM)

Kinda hacky copy/paste from the DDIM sampler. Needs cleanup.
---
 README.md                                   |  30 +++++-
 imaginairy/api.py                           |   5 +-
 imaginairy/cmds.py                          |   3 -
 imaginairy/enhancers/describe_image_blip.py |   2 +-
 imaginairy/modules/diffusion/ddpm.py        |   7 +-
 imaginairy/samplers/plms.py                 | 104 +++++++++++++++++++-
 tests/test_experiments.py                   |  39 +++++++-
 tests/test_imagine.py                       |  27 ++++-
 8 files changed, 194 insertions(+), 23 deletions(-)

diff --git a/README.md b/README.md
index 81e12bf..bf4ca57 100644
--- a/README.md
+++ b/README.md
@@ -162,9 +162,27 @@ docker build . -t imaginairy
 docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -v $HOME/.cache/torch:/root/.cache/torch -v `pwd`/outputs:/outputs imaginairy /bin/bash
 ```
 
-## Improvements from CompVis
-  - img2img actually does # of steps you specify
+## ChangeLog
+
+**1.5.0**
+ - img2img now supported with PLMS (instead of just DDIM)
+ - added image captioning feature: `aimg describe dog.jpg` => `a brown dog sitting on grass`
+ - added new command-line tool `aimg` for additional image manipulation functionality
+
+**1.4.0**
+ - support multiple additive targets for masking with the `|` symbol. Example: "fruit|stem|fruit stem"
+
+**1.3.0**
+ - added prompt-based image editing. Example: "fruit => gold coins"
+ - improved test coverage
+
+**1.2.0**
+ - allow URLs as init images
+
+**previous**
+ - img2img actually runs the number of steps you specify
  - performance optimizations
+ - numerous other changes
 
 ## Models Used
  - CLIP - https://openai.com/blog/clip/
@@ -205,6 +223,9 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
   - ✅ realesrgan
   - ldm - https://github.com/lowfuel/progrock-stable
+  - stable super-res?
+    - todo: try with 1-0-0-0 mask at full image resolution (re-encoding the entire image + predicted image at every step)
+    - todo: use a gaussian pyramid and only include the "high-detail" level of the pyramid into the next step
 - ✅ face enhancers
   - ✅ gfpgan - https://github.com/TencentARC/GFPGAN
   - ✅ codeformer - https://github.com/sczhou/CodeFormer
 - image describe feature
   - https://github.com/pharmapsychotic/clip-interrogator (blip + clip)
   - https://github.com/KaiyangZhou/CoOp
 - outpainting
-- inpainting
+- ✅ inpainting
   - https://github.com/andreas128/RePaint
 - img2img but keeps img stable
   - https://www.reddit.com/r/StableDiffusion/comments/xboy90/a_better_way_of_doing_img2img_by_finding_the/
   - https://gist.github.com/trygvebw/c71334dd127d537a15e9d59790f7f5e1
   - https://github.com/pesser/stable-diffusion/commit/bbb52981460707963e2a62160890d7ecbce00e79
 - CPU support
-  - img2img for plms?
+  - ✅ img2img for plms
+  - img2img for kdiff functions
 - images as actual prompts instead of just init images
   - requires model fine-tuning since SD1.4 expects 77x768 text encoding input
   - https://twitter.com/Buntworthy/status/1566744186153484288
diff --git a/imaginairy/api.py b/imaginairy/api.py
index e25fec6..416fe93 100755
--- a/imaginairy/api.py
+++ b/imaginairy/api.py
@@ -215,8 +215,9 @@ def imagine(
             prompt.height // downsampling_factor,
             prompt.width // downsampling_factor,
         ]
-        if prompt.init_image:
-            sampler_type = "ddim"
+        if prompt.init_image and prompt.sampler_type not in ("ddim", "plms"):
+            sampler_type = "plms"
+            logger.info(" Sampler type switched to plms for img2img")
         else:
             sampler_type = prompt.sampler_type
         start_code = None
diff --git a/imaginairy/cmds.py b/imaginairy/cmds.py
index 3f6948d..0a29f4b 100644
--- a/imaginairy/cmds.py
+++ b/imaginairy/cmds.py
@@ -185,9 +185,6 @@ def imagine_cmd(
     logger.info(
         f"🤖🧠 imaginAIry received {len(prompt_texts)} prompt(s) and will repeat them {repeats} times to create {total_image_count} images."
     )
-    if init_image and sampler_type != "ddim":
-        sampler_type = "ddim"
-        logger.info(" Sampler type switched to ddim for img2img")
 
     if init_image and init_image.startswith("http"):
         init_image = LazyLoadingImage(url=init_image)
diff --git a/imaginairy/enhancers/describe_image_blip.py b/imaginairy/enhancers/describe_image_blip.py
index 9be24f8..8881415 100644
--- a/imaginairy/enhancers/describe_image_blip.py
+++ b/imaginairy/enhancers/describe_image_blip.py
@@ -18,7 +18,7 @@ BLIP_EVAL_SIZE = 384
 
 @lru_cache()
 def blip_model():
-    from imaginairy import PKG_ROOT
+    from imaginairy import PKG_ROOT  # noqa
 
     config_path = os.path.join(
         PKG_ROOT, "vendored", "blip", "configs", "med_config.json"
diff --git a/imaginairy/modules/diffusion/ddpm.py b/imaginairy/modules/diffusion/ddpm.py
index a16c699..337ddc8 100644
--- a/imaginairy/modules/diffusion/ddpm.py
+++ b/imaginairy/modules/diffusion/ddpm.py
@@ -255,7 +255,8 @@ class LatentDiffusion(DDPM):
         self.cond_stage_key = cond_stage_key
         try:
             self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
-        except:
+        except:  # noqa
+            logger.exception("Bad num downs?")
             self.num_downs = 0
         if not scale_by_std:
             self.scale_factor = scale_factor
@@ -639,7 +640,7 @@ class LatentDiffusion(DDPM):
 
             ks = self.split_input_params["ks"]  # eg. (128, 128)
             stride = self.split_input_params["stride"]  # eg. (64, 64)
-            h, w = x_noisy.shape[-2:]
+            h, w = x_noisy.shape[-2:]  # noqa
 
             fold, unfold, normalization, weighting = self.get_fold_unfold(
                 x_noisy, ks, stride
@@ -711,7 +712,7 @@ class LatentDiffusion(DDPM):
 
                 # tokenize crop coordinates for the bounding boxes of the respective patches
                 patch_limits_tknzd = [
-                    torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[
+                    torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[  # noqa
                         None
                     ].to(  # noqa
                         self.device
diff --git a/imaginairy/samplers/plms.py b/imaginairy/samplers/plms.py
index 096819e..7e9bb1c 100644
--- a/imaginairy/samplers/plms.py
+++ b/imaginairy/samplers/plms.py
@@ -5,7 +5,9 @@ import numpy as np
 import torch
 from tqdm import tqdm
 
+from imaginairy.img_log import log_latent
 from imaginairy.modules.diffusion.util import (
+    extract_into_tensor,
     make_ddim_sampling_parameters,
     make_ddim_timesteps,
     noise_like,
@@ -172,7 +174,7 @@ class PLMSSampler:
             img = torch.randn(shape, device="cpu").to(device)
         else:
             img = x_T
-
+        log_latent(img, "initial img")
         if timesteps is None:
             timesteps = (
                 self.ddpm_num_timesteps
@@ -217,7 +219,7 @@ class PLMSSampler:
                 )  # TODO: deterministic forward pass?
                 img = img_orig * mask + (1.0 - mask) * img
 
-            outs = self.p_sample_plms(
+            img, pred_x0, e_t = self.p_sample_plms(
                 img,
                 cond,
                 ts,
@@ -233,7 +235,6 @@ class PLMSSampler:
                 old_eps=old_eps,
                 t_next=ts_next,
             )
-            img, pred_x0, e_t = outs
             old_eps.append(e_t)
             if len(old_eps) >= 4:
                 old_eps.pop(0)
@@ -277,7 +278,11 @@ class PLMSSampler:
             t_in = torch.cat([t] * 2)
             c_in = torch.cat([unconditional_conditioning, c])
             e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+            log_latent(e_t_uncond, "noise pred uncond")
+            log_latent(e_t, "noise pred cond")
+
             e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+            log_latent(e_t, "noise pred combined")
 
             if score_corrector is not None:
                 assert self.model.parameterization == "eps"
@@ -326,6 +331,7 @@ class PLMSSampler:
             return x_prev, pred_x0
 
         e_t = get_model_output(x, t)
+
         if len(old_eps) == 0:
             # Pseudo Improved Euler (2nd order)
             x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
@@ -344,5 +350,97 @@ class PLMSSampler:
             ) / 24
             x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
 
+        log_latent(x_prev, "x_prev")
+        log_latent(pred_x0, "pred_x0")
+
         return x_prev, pred_x0, e_t
+
+    @torch.no_grad()
+    def stochastic_encode(self, init_latent, t, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
+
+        sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+        sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(init_latent, device="cpu").to(get_device())
+        return (
+            extract_into_tensor(sqrt_alphas_cumprod, t, init_latent.shape) * init_latent
+            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, init_latent.shape)
+            * noise
+        )
+
+    @torch.no_grad()
+    def decode(
+        self,
+        x_latent,
+        cond,
+        t_start,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        img_callback=None,
+        score_corrector=None,
+        temperature=1.0,
+        mask=None,
+        orig_latent=None,
+    ):
+        timesteps = self.ddim_timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+
+        iterator = tqdm(time_range, desc="PLMS altering image", total=total_steps)
+        x_dec = x_latent
+        old_eps = []
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full(
+                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
+            )
+            ts_next = torch.full(
+                (x_latent.shape[0],),
+                time_range[min(i + 1, len(time_range) - 1)],
+                device=x_latent.device,
+                dtype=torch.long,
+            )
+
+            if mask is not None:
+                assert orig_latent is not None
+                xdec_orig = self.model.q_sample(orig_latent, ts)
+                log_latent(xdec_orig, "xdec_orig")
+                log_latent(xdec_orig * mask, "masked_xdec_orig")
+                x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
+                log_latent(x_dec, "x_dec")
+
+            x_dec, pred_x0, e_t = self.p_sample_plms(
+                x_dec,
+                cond,
+                ts,
+                index=index,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=unconditional_conditioning,
+                temperature=temperature,
+                old_eps=old_eps,
+                t_next=ts_next,
+            )
+            # original_loss = ((x_dec - x_latent).abs().mean()*70)
+            # sigma_t = torch.full((1, 1, 1, 1), self.ddim_sigmas[index], device=get_device())
+            # x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
+            # cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
+            # x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
+            ## x_dec_alt = x_dec + (original_loss * 0.1) ** 2
+
+            old_eps.append(e_t)
+            if len(old_eps) >= 4:
+                old_eps.pop(0)
+
+            if img_callback:
+                img_callback(x_dec, "x_dec")
+                img_callback(pred_x0, "pred_x0")
+
+            log_latent(x_dec, f"x_dec {i}")
+            log_latent(pred_x0, f"pred_x0 {i}")
+
+        return x_dec
diff --git a/tests/test_experiments.py b/tests/test_experiments.py
index 5a9768b..9dfa001 100644
--- a/tests/test_experiments.py
+++ b/tests/test_experiments.py
@@ -56,10 +56,8 @@ def experiment_step_repeats():
     sampler.make_schedule(1000)
     img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
 
-    init_image, _, h = pillow_img_to_torch_image(
+    init_image, _, _ = pillow_img_to_torch_image(
         img,
-        max_height=512,
-        max_width=512,
     )
     init_image = init_image.to(get_device())
     init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))
@@ -119,3 +117,38 @@ def experiment_repeated_img_2_img():
         img = result.img
         os.makedirs(outdir, exist_ok=True)
         img.save(f"{outdir}/{step_num:04}.png")
+
+
+def experiment_superresolution():
+    """
+    Try to trick it into making a superresolution image.
+
+    Did not work; the resulting image was more blurry.
+
+    # i put this into the api.py file hardcoded
+    row_a = torch.tensor([1, 0]).repeat(32)
+    row_b = torch.tensor([0, 1]).repeat(32)
+    grid = torch.stack([row_a, row_b]).repeat(32, 1)
+    mask = grid
+    mask = mask.to(get_device())
+    """
+
+    description = "a black and white photo of a dog's face"
+    # image was a quarter of the existing image
+    img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/../outputs/dog02.jpg")
+
+    # todo: try with a 1-0-0-0 mask at full image resolution (re-encoding the entire image + predicted image at every step)
+    # todo: use a gaussian pyramid and only include the "high-detail" level of the pyramid into the next step
+
+    prompt = ImaginePrompt(
+        description,
+        init_image=img,
+        init_image_strength=0.8,
+        width=512,
+        height=512,
+        steps=50,
+        seed=1,
+        sampler_type="DDIM",
+    )
+    out_folder = f"{TESTS_FOLDER}/test_output"
+    imagine_image_files(prompt, outdir=out_folder)
diff --git a/tests/test_imagine.py b/tests/test_imagine.py
index a69ccd2..5090891 100644
--- a/tests/test_imagine.py
+++ b/tests/test_imagine.py
@@ -45,7 +45,23 @@ def test_imagine(sampler_type, expected_md5):
     assert result.md5() == expected_md5
 
 
-def test_img_to_img():
+device_sampler_type_test_cases_img_2_img = {
+    "mps:0": {
+        ("plms", "54656a7f449cb73b99436e61470172b3"),
+        ("ddim", "87d04423f6d03ddfc065cabc62e3909c"),
+    },
+    "cuda": {
+        ("plms", "efba8b836b51d262dbf72284844869f8"),
+        ("ddim", "a62878000ad3b581a11dd3fb329dc7d2"),
"a62878000ad3b581a11dd3fb329dc7d2"), + }, +} +sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img[ + get_device() +] + + +@pytest.mark.parametrize("sampler_type,expected_md5", sampler_type_test_cases_img_2_img) +def test_img_to_img(sampler_type, expected_md5): prompt = ImaginePrompt( "a photo of a beach", init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg", @@ -54,10 +70,13 @@ def test_img_to_img(): height=512, steps=5, seed=1, - sampler_type="DDIM", + sampler_type=sampler_type, ) - out_folder = f"{TESTS_FOLDER}/test_output" - imagine_image_files(prompt, outdir=out_folder) + result = next(imagine(prompt)) + result.img.save( + f"{TESTS_FOLDER}/test_output/sampler_type_{sampler_type.upper()}_img2img_beach.jpg" + ) + assert result.md5() == expected_md5 def test_img_to_img_from_url():