feature: save gifs that show image generation process (#218)

This commit is contained in:
Bryce Drennan 2023-01-27 17:18:42 -08:00 committed by GitHub
parent 8791e15bec
commit 9e06013ade
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 77 additions and 9 deletions

View File

@ -16,6 +16,7 @@ from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.img_utils import ( from imaginairy.img_utils import (
make_gif_image, make_gif_image,
model_latents_to_pillow_imgs,
pillow_fit_image_within, pillow_fit_image_within,
pillow_img_to_torch_image, pillow_img_to_torch_image,
) )
@ -63,7 +64,7 @@ def imagine_image_files(
record_step_images=False, record_step_images=False,
output_file_extension="jpg", output_file_extension="jpg",
print_caption=False, print_caption=False,
make_comparison_gif=False, make_gif=False,
return_filename_type="generated", return_filename_type="generated",
): ):
generated_imgs_path = os.path.join(outdir, "generated") generated_imgs_path = os.path.join(outdir, "generated")
@ -84,6 +85,9 @@ def imagine_image_files(
draw.text((10, 10), str(description)) draw.text((10, 10), str(description))
img.save(destination) img.save(destination)
if make_gif:
for p in prompts:
p.collect_progress_latents = True
result_filenames = [] result_filenames = []
for result in imagine( for result in imagine(
prompts, prompts,
@ -113,18 +117,50 @@ def imagine_image_files(
logger.info(f" [{image_type}] saved to: {filepath}") logger.info(f" [{image_type}] saved to: {filepath}")
if image_type == return_filename_type: if image_type == return_filename_type:
result_filenames.append(filepath) result_filenames.append(filepath)
if make_comparison_gif and prompt.init_image:
if make_gif and result.progress_latents:
subpath = os.path.join(outdir, "gif") subpath = os.path.join(outdir, "gif")
os.makedirs(subpath, exist_ok=True) os.makedirs(subpath, exist_ok=True)
filepath = os.path.join(subpath, f"{basefilename}.gif") filepath = os.path.join(subpath, f"{basefilename}.gif")
resized_init_image = pillow_fit_image_within(
prompt.init_image, prompt.width, prompt.height transition_length = 1500
pause_length_ms = 500
max_fps = 20
max_frames = int(round(transition_length / 1000 * max_fps))
usable_latents = shrink_list(result.progress_latents, max_frames)
progress_imgs = [
model_latents_to_pillow_imgs(latent)[0] for latent in usable_latents
]
frames = (
progress_imgs
+ [result.images["generated"]]
+ list(reversed(progress_imgs))
) )
progress_duration = int(round(300 / len(frames)))
min_duration = int(1000 / 20)
progress_duration = max(progress_duration, min_duration)
durations = (
[progress_duration] * len(progress_imgs)
+ [pause_length_ms]
+ [progress_duration] * len(progress_imgs)
)
assert len(frames) == len(durations)
if prompt.init_image:
resized_init_image = pillow_fit_image_within(
prompt.init_image, prompt.width, prompt.height
)
frames = [resized_init_image] + frames
durations = [pause_length_ms] + durations
else:
durations[0] = pause_length_ms
make_gif_image( make_gif_image(
filepath, filepath,
imgs=[result.images["generated"], resized_init_image], imgs=frames,
duration=1750, duration=durations,
) )
logger.info(f" [gif] saved to: {filepath}")
base_count += 1 base_count += 1
del result del result
@ -208,6 +244,11 @@ def _generate_single_image(
for_inpainting=prompt.mask_image or prompt.mask_prompt or prompt.outpaint, for_inpainting=prompt.mask_image or prompt.mask_prompt or prompt.outpaint,
) )
has_depth_channel = hasattr(model, "depth_stage_key") has_depth_channel = hasattr(model, "depth_stage_key")
progress_latents = []
def latent_logger(latents):
progress_latents.append(latents)
with ImageLoggingContext( with ImageLoggingContext(
prompt=prompt, prompt=prompt,
model=model, model=model,
@ -215,6 +256,9 @@ def _generate_single_image(
progress_img_callback=progress_img_callback, progress_img_callback=progress_img_callback,
progress_img_interval_steps=progress_img_interval_steps, progress_img_interval_steps=progress_img_interval_steps,
progress_img_interval_min_s=progress_img_interval_min_s, progress_img_interval_min_s=progress_img_interval_min_s,
progress_latent_callback=latent_logger
if prompt.collect_progress_latents
else None,
) as lc: ) as lc:
seed_everything(prompt.seed) seed_everything(prompt.seed)
@ -480,6 +524,8 @@ def _generate_single_image(
img, img,
safety_mode=IMAGINAIRY_SAFETY_MODE, safety_mode=IMAGINAIRY_SAFETY_MODE,
) )
if safety_score.is_filtered:
progress_latents.clear()
if not safety_score.is_filtered: if not safety_score.is_filtered:
if prompt.fix_faces: if prompt.fix_faces:
logger.info("Fixing 😊 's in 🖼 using CodeFormer...") logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
@ -525,6 +571,7 @@ def _generate_single_image(
mask_grayscale=mask_grayscale, mask_grayscale=mask_grayscale,
depth_image=depth_image_display, depth_image=depth_image_display,
timings=lc.get_timings(), timings=lc.get_timings(),
progress_latents=progress_latents.copy(),
) )
_most_recent_result = result _most_recent_result = result
logger.info(f"Image Generated. Timings: {result.timings_str()}") logger.info(f"Image Generated. Timings: {result.timings_str()}")
@ -542,3 +589,12 @@ def _prompts_to_embeddings(prompts, model):
def prompt_normalized(prompt): def prompt_normalized(prompt):
return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:130] return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:130]
def shrink_list(items, max_size):
    """Return at most ``max_size`` entries of ``items``, evenly spaced, in order.

    If ``items`` already fits within ``max_size`` it is returned unchanged
    (the same object, not a copy). Otherwise exactly ``max_size`` entries are
    picked at evenly spaced positions, always starting with the first entry.

    The previous modulo-based thinning had three failure modes: it returned an
    empty list when more than roughly two thirds of the entries had to be
    dropped (the interval rounded to 1), it could return more than
    ``max_size`` entries for other ratios, and it could divide by zero when
    the interval rounded to 0.

    Args:
        items: a sized, indexable sequence (e.g. a list).
        max_size: maximum number of entries to keep.

    Returns:
        ``items`` itself when no shrinking is needed, otherwise a new list
        with ``max_size`` entries (empty when ``max_size`` is not positive).
    """
    total = len(items)
    if total <= max_size:
        return items
    if max_size <= 0:
        return []
    # total > max_size here, so consecutive positions differ by at least one
    # and every selected index is unique and strictly increasing.
    return [items[(slot * total) // max_size] for slot in range(max_size)]

View File

@ -223,7 +223,11 @@ logger = logging.getLogger(__name__)
help="Print the version and exit.", help="Print the version and exit.",
) )
@click.option( @click.option(
"--gif", "make_gif", default=False, is_flag=True, help="Generate a gif of the edit." "--gif",
"make_gif",
default=False,
is_flag=True,
help="Generate a gif of the generation.",
) )
@click.pass_context @click.pass_context
def imagine_cmd( def imagine_cmd(
@ -708,7 +712,7 @@ def _imagine_cmd(
output_file_extension="jpg", output_file_extension="jpg",
print_caption=caption, print_caption=caption,
precision=precision, precision=precision,
make_comparison_gif=make_gif, make_gif=make_gif,
) )

View File

@ -79,6 +79,7 @@ class ImageLoggingContext:
progress_img_callback=None, progress_img_callback=None,
progress_img_interval_steps=3, progress_img_interval_steps=3,
progress_img_interval_min_s=0.1, progress_img_interval_min_s=0.1,
progress_latent_callback=None,
): ):
self.prompt = prompt self.prompt = prompt
self.model = model self.model = model
@ -89,6 +90,7 @@ class ImageLoggingContext:
self.progress_img_callback = progress_img_callback self.progress_img_callback = progress_img_callback
self.progress_img_interval_steps = progress_img_interval_steps self.progress_img_interval_steps = progress_img_interval_steps
self.progress_img_interval_min_s = progress_img_interval_min_s self.progress_img_interval_min_s = progress_img_interval_min_s
self.progress_latent_callback = progress_latent_callback
self.start_ts = time.perf_counter() self.start_ts = time.perf_counter()
self.timings = {} self.timings = {}
@ -124,6 +126,8 @@ class ImageLoggingContext:
from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa
if "predicted_latent" in description: if "predicted_latent" in description:
if self.progress_latent_callback is not None:
self.progress_latent_callback(latents)
if ( if (
self.step_count - self.last_progress_img_step self.step_count - self.last_progress_img_step
) > self.progress_img_interval_steps: ) > self.progress_img_interval_steps:

View File

@ -116,6 +116,7 @@ class ImaginePrompt:
model=config.DEFAULT_MODEL, model=config.DEFAULT_MODEL,
model_config_path=None, model_config_path=None,
is_intermediate=False, is_intermediate=False,
collect_progress_latents=False,
): ):
self.prompts = self.process_prompt_input(prompt) self.prompts = self.process_prompt_input(prompt)
@ -164,6 +165,7 @@ class ImaginePrompt:
self.model_config_path = model_config_path self.model_config_path = model_config_path
# we don't want to save intermediate images # we don't want to save intermediate images
self.is_intermediate = is_intermediate self.is_intermediate = is_intermediate
self.collect_progress_latents = collect_progress_latents
if self.height is None or self.width is None or self.steps is None: if self.height is None or self.width is None or self.steps is None:
SamplerCls = SAMPLER_LOOKUP[self.sampler_type] SamplerCls = SAMPLER_LOOKUP[self.sampler_type]
@ -263,6 +265,7 @@ class ImagineResult:
mask_grayscale=None, mask_grayscale=None,
depth_image=None, depth_image=None,
timings=None, timings=None,
progress_latents=None,
): ):
self.prompt = prompt self.prompt = prompt
@ -284,6 +287,7 @@ class ImagineResult:
self.images["depth_image"] = depth_image self.images["depth_image"] = depth_image
self.timings = timings self.timings = timings
self.progress_latents = progress_latents
# for backward compat # for backward compat
self.img = img self.img = img

View File

@ -191,7 +191,7 @@ def create_surprise_me_images(
record_step_images=False, record_step_images=False,
output_file_extension="jpg", output_file_extension="jpg",
print_caption=False, print_caption=False,
make_comparison_gif=make_gif, make_gif=make_gif,
) )
if make_gif: if make_gif:
imgs_path = os.path.join(outdir, "compilations") imgs_path = os.path.join(outdir, "compilations")