From 9e06013adefb2b5bdedfd71b775f77c4fcf05a29 Mon Sep 17 00:00:00 2001 From: Bryce Drennan Date: Fri, 27 Jan 2023 17:18:42 -0800 Subject: [PATCH] feature: save gifs that show image generation process (#218) --- imaginairy/api.py | 68 +++++++++++++++++++++++++++++++++++---- imaginairy/cmds.py | 8 +++-- imaginairy/log_utils.py | 4 +++ imaginairy/schema.py | 4 +++ imaginairy/surprise_me.py | 2 +- 5 files changed, 77 insertions(+), 9 deletions(-) diff --git a/imaginairy/api.py b/imaginairy/api.py index aebf576..b42add4 100755 --- a/imaginairy/api.py +++ b/imaginairy/api.py @@ -16,6 +16,7 @@ from imaginairy.enhancers.face_restoration_codeformer import enhance_faces from imaginairy.enhancers.upscale_realesrgan import upscale_image from imaginairy.img_utils import ( make_gif_image, + model_latents_to_pillow_imgs, pillow_fit_image_within, pillow_img_to_torch_image, ) @@ -63,7 +64,7 @@ def imagine_image_files( record_step_images=False, output_file_extension="jpg", print_caption=False, - make_comparison_gif=False, + make_gif=False, return_filename_type="generated", ): generated_imgs_path = os.path.join(outdir, "generated") @@ -84,6 +85,9 @@ def imagine_image_files( draw.text((10, 10), str(description)) img.save(destination) + if make_gif: + for p in prompts: + p.collect_progress_latents = True result_filenames = [] for result in imagine( prompts, @@ -113,18 +117,50 @@ def imagine_image_files( logger.info(f" [{image_type}] saved to: {filepath}") if image_type == return_filename_type: result_filenames.append(filepath) - if make_comparison_gif and prompt.init_image: + + if make_gif and result.progress_latents: subpath = os.path.join(outdir, "gif") os.makedirs(subpath, exist_ok=True) filepath = os.path.join(subpath, f"{basefilename}.gif") - resized_init_image = pillow_fit_image_within( - prompt.init_image, prompt.width, prompt.height + + transition_length = 1500 + pause_length_ms = 500 + max_fps = 20 + max_frames = int(round(transition_length / 1000 * max_fps)) + + usable_latents = shrink_list(result.progress_latents, max_frames) + progress_imgs = [ + model_latents_to_pillow_imgs(latent)[0] for latent in usable_latents + ] + frames = ( + progress_imgs + + [result.images["generated"]] + + list(reversed(progress_imgs)) ) + progress_duration = int(round(300 / len(frames))) + min_duration = int(1000 / 20) + progress_duration = max(progress_duration, min_duration) + durations = ( + [progress_duration] * len(progress_imgs) + + [pause_length_ms] + + [progress_duration] * len(progress_imgs) + ) + assert len(frames) == len(durations) + if prompt.init_image: + resized_init_image = pillow_fit_image_within( + prompt.init_image, prompt.width, prompt.height + ) + frames = [resized_init_image] + frames + durations = [pause_length_ms] + durations + else: + durations[0] = pause_length_ms + make_gif_image( filepath, - imgs=[result.images["generated"], resized_init_image], - duration=1750, + imgs=frames, + duration=durations, ) + logger.info(f" [gif] saved to: {filepath}") base_count += 1 del result @@ -208,6 +244,11 @@ def _generate_single_image( for_inpainting=prompt.mask_image or prompt.mask_prompt or prompt.outpaint, ) has_depth_channel = hasattr(model, "depth_stage_key") + progress_latents = [] + + def latent_logger(latents): + progress_latents.append(latents) + with ImageLoggingContext( prompt=prompt, model=model, @@ -215,6 +256,9 @@ def _generate_single_image( progress_img_callback=progress_img_callback, progress_img_interval_steps=progress_img_interval_steps, progress_img_interval_min_s=progress_img_interval_min_s, + progress_latent_callback=latent_logger + if prompt.collect_progress_latents + else None, ) as lc: seed_everything(prompt.seed) @@ -480,6 +524,8 @@ def _generate_single_image( img, safety_mode=IMAGINAIRY_SAFETY_MODE, ) + if safety_score.is_filtered: + progress_latents.clear() if not safety_score.is_filtered: if prompt.fix_faces: logger.info("Fixing 😊 's in 🖼 using CodeFormer...") @@ -525,6 +571,7 @@ def _generate_single_image( mask_grayscale=mask_grayscale, depth_image=depth_image_display, timings=lc.get_timings(), + progress_latents=progress_latents.copy(), ) _most_recent_result = result logger.info(f"Image Generated. Timings: {result.timings_str()}") @@ -542,3 +589,12 @@ def _prompts_to_embeddings(prompts, model): def prompt_normalized(prompt): return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:130] + + +def shrink_list(items, max_size): + if len(items) <= max_size: + return items + num_to_remove = len(items) - max_size + interval = int(round(len(items) / num_to_remove)) + + return [val for i, val in enumerate(items) if i % interval != 0] diff --git a/imaginairy/cmds.py b/imaginairy/cmds.py index e8bbf9d..3f6b449 100644 --- a/imaginairy/cmds.py +++ b/imaginairy/cmds.py @@ -223,7 +223,11 @@ logger = logging.getLogger(__name__) help="Print the version and exit.", ) @click.option( - "--gif", "make_gif", default=False, is_flag=True, help="Generate a gif of the edit." + "--gif", + "make_gif", + default=False, + is_flag=True, + help="Generate a gif of the generation.", ) @click.pass_context def imagine_cmd( @@ -708,7 +712,7 @@ def _imagine_cmd( output_file_extension="jpg", print_caption=caption, precision=precision, - make_comparison_gif=make_gif, + make_gif=make_gif, ) diff --git a/imaginairy/log_utils.py b/imaginairy/log_utils.py index a48b905..e037a9b 100644 --- a/imaginairy/log_utils.py +++ b/imaginairy/log_utils.py @@ -79,6 +79,7 @@ class ImageLoggingContext: progress_img_callback=None, progress_img_interval_steps=3, progress_img_interval_min_s=0.1, + progress_latent_callback=None, ): self.prompt = prompt self.model = model @@ -89,6 +90,7 @@ class ImageLoggingContext: self.progress_img_callback = progress_img_callback self.progress_img_interval_steps = progress_img_interval_steps self.progress_img_interval_min_s = progress_img_interval_min_s + self.progress_latent_callback = progress_latent_callback self.start_ts = time.perf_counter() self.timings = {} @@ -124,6 +126,8 @@ class ImageLoggingContext: from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa if "predicted_latent" in description: + if self.progress_latent_callback is not None: + self.progress_latent_callback(latents) if ( self.step_count - self.last_progress_img_step ) > self.progress_img_interval_steps: diff --git a/imaginairy/schema.py b/imaginairy/schema.py index a353312..88da801 100644 --- a/imaginairy/schema.py +++ b/imaginairy/schema.py @@ -116,6 +116,7 @@ class ImaginePrompt: model=config.DEFAULT_MODEL, model_config_path=None, is_intermediate=False, + collect_progress_latents=False, ): self.prompts = self.process_prompt_input(prompt) @@ -164,6 +165,7 @@ class ImaginePrompt: self.model_config_path = model_config_path # we don't want to save intermediate images self.is_intermediate = is_intermediate + self.collect_progress_latents = collect_progress_latents if self.height is None or self.width is None or self.steps is None: SamplerCls = SAMPLER_LOOKUP[self.sampler_type] @@ -263,6 +265,7 @@ class ImagineResult: mask_grayscale=None, depth_image=None, timings=None, + progress_latents=None, ): self.prompt = prompt @@ -284,6 +287,7 @@ class ImagineResult: self.images["depth_image"] = depth_image self.timings = timings + self.progress_latents = progress_latents # for backward compat self.img = img diff --git a/imaginairy/surprise_me.py b/imaginairy/surprise_me.py index e3d4bf9..e792560 100644 --- a/imaginairy/surprise_me.py +++ b/imaginairy/surprise_me.py @@ -191,7 +191,7 @@ def create_surprise_me_images( record_step_images=False, output_file_extension="jpg", print_caption=False, - make_comparison_gif=make_gif, + make_gif=make_gif, ) if make_gif: imgs_path = os.path.join(outdir, "compilations")