mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-11-05 12:00:15 +00:00
feature: save gifs that show image generation process (#218)
This commit is contained in:
parent
8791e15bec
commit
9e06013ade
@ -16,6 +16,7 @@ from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
|
|||||||
from imaginairy.enhancers.upscale_realesrgan import upscale_image
|
from imaginairy.enhancers.upscale_realesrgan import upscale_image
|
||||||
from imaginairy.img_utils import (
|
from imaginairy.img_utils import (
|
||||||
make_gif_image,
|
make_gif_image,
|
||||||
|
model_latents_to_pillow_imgs,
|
||||||
pillow_fit_image_within,
|
pillow_fit_image_within,
|
||||||
pillow_img_to_torch_image,
|
pillow_img_to_torch_image,
|
||||||
)
|
)
|
||||||
@ -63,7 +64,7 @@ def imagine_image_files(
|
|||||||
record_step_images=False,
|
record_step_images=False,
|
||||||
output_file_extension="jpg",
|
output_file_extension="jpg",
|
||||||
print_caption=False,
|
print_caption=False,
|
||||||
make_comparison_gif=False,
|
make_gif=False,
|
||||||
return_filename_type="generated",
|
return_filename_type="generated",
|
||||||
):
|
):
|
||||||
generated_imgs_path = os.path.join(outdir, "generated")
|
generated_imgs_path = os.path.join(outdir, "generated")
|
||||||
@ -84,6 +85,9 @@ def imagine_image_files(
|
|||||||
draw.text((10, 10), str(description))
|
draw.text((10, 10), str(description))
|
||||||
img.save(destination)
|
img.save(destination)
|
||||||
|
|
||||||
|
if make_gif:
|
||||||
|
for p in prompts:
|
||||||
|
p.collect_progress_latents = True
|
||||||
result_filenames = []
|
result_filenames = []
|
||||||
for result in imagine(
|
for result in imagine(
|
||||||
prompts,
|
prompts,
|
||||||
@ -113,18 +117,50 @@ def imagine_image_files(
|
|||||||
logger.info(f" [{image_type}] saved to: {filepath}")
|
logger.info(f" [{image_type}] saved to: {filepath}")
|
||||||
if image_type == return_filename_type:
|
if image_type == return_filename_type:
|
||||||
result_filenames.append(filepath)
|
result_filenames.append(filepath)
|
||||||
if make_comparison_gif and prompt.init_image:
|
|
||||||
|
if make_gif and result.progress_latents:
|
||||||
subpath = os.path.join(outdir, "gif")
|
subpath = os.path.join(outdir, "gif")
|
||||||
os.makedirs(subpath, exist_ok=True)
|
os.makedirs(subpath, exist_ok=True)
|
||||||
filepath = os.path.join(subpath, f"{basefilename}.gif")
|
filepath = os.path.join(subpath, f"{basefilename}.gif")
|
||||||
resized_init_image = pillow_fit_image_within(
|
|
||||||
prompt.init_image, prompt.width, prompt.height
|
transition_length = 1500
|
||||||
|
pause_length_ms = 500
|
||||||
|
max_fps = 20
|
||||||
|
max_frames = int(round(transition_length / 1000 * max_fps))
|
||||||
|
|
||||||
|
usable_latents = shrink_list(result.progress_latents, max_frames)
|
||||||
|
progress_imgs = [
|
||||||
|
model_latents_to_pillow_imgs(latent)[0] for latent in usable_latents
|
||||||
|
]
|
||||||
|
frames = (
|
||||||
|
progress_imgs
|
||||||
|
+ [result.images["generated"]]
|
||||||
|
+ list(reversed(progress_imgs))
|
||||||
)
|
)
|
||||||
|
progress_duration = int(round(300 / len(frames)))
|
||||||
|
min_duration = int(1000 / 20)
|
||||||
|
progress_duration = max(progress_duration, min_duration)
|
||||||
|
durations = (
|
||||||
|
[progress_duration] * len(progress_imgs)
|
||||||
|
+ [pause_length_ms]
|
||||||
|
+ [progress_duration] * len(progress_imgs)
|
||||||
|
)
|
||||||
|
assert len(frames) == len(durations)
|
||||||
|
if prompt.init_image:
|
||||||
|
resized_init_image = pillow_fit_image_within(
|
||||||
|
prompt.init_image, prompt.width, prompt.height
|
||||||
|
)
|
||||||
|
frames = [resized_init_image] + frames
|
||||||
|
durations = [pause_length_ms] + durations
|
||||||
|
else:
|
||||||
|
durations[0] = pause_length_ms
|
||||||
|
|
||||||
make_gif_image(
|
make_gif_image(
|
||||||
filepath,
|
filepath,
|
||||||
imgs=[result.images["generated"], resized_init_image],
|
imgs=frames,
|
||||||
duration=1750,
|
duration=durations,
|
||||||
)
|
)
|
||||||
|
logger.info(f" [gif] saved to: {filepath}")
|
||||||
base_count += 1
|
base_count += 1
|
||||||
del result
|
del result
|
||||||
|
|
||||||
@ -208,6 +244,11 @@ def _generate_single_image(
|
|||||||
for_inpainting=prompt.mask_image or prompt.mask_prompt or prompt.outpaint,
|
for_inpainting=prompt.mask_image or prompt.mask_prompt or prompt.outpaint,
|
||||||
)
|
)
|
||||||
has_depth_channel = hasattr(model, "depth_stage_key")
|
has_depth_channel = hasattr(model, "depth_stage_key")
|
||||||
|
progress_latents = []
|
||||||
|
|
||||||
|
def latent_logger(latents):
|
||||||
|
progress_latents.append(latents)
|
||||||
|
|
||||||
with ImageLoggingContext(
|
with ImageLoggingContext(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=model,
|
model=model,
|
||||||
@ -215,6 +256,9 @@ def _generate_single_image(
|
|||||||
progress_img_callback=progress_img_callback,
|
progress_img_callback=progress_img_callback,
|
||||||
progress_img_interval_steps=progress_img_interval_steps,
|
progress_img_interval_steps=progress_img_interval_steps,
|
||||||
progress_img_interval_min_s=progress_img_interval_min_s,
|
progress_img_interval_min_s=progress_img_interval_min_s,
|
||||||
|
progress_latent_callback=latent_logger
|
||||||
|
if prompt.collect_progress_latents
|
||||||
|
else None,
|
||||||
) as lc:
|
) as lc:
|
||||||
seed_everything(prompt.seed)
|
seed_everything(prompt.seed)
|
||||||
|
|
||||||
@ -480,6 +524,8 @@ def _generate_single_image(
|
|||||||
img,
|
img,
|
||||||
safety_mode=IMAGINAIRY_SAFETY_MODE,
|
safety_mode=IMAGINAIRY_SAFETY_MODE,
|
||||||
)
|
)
|
||||||
|
if safety_score.is_filtered:
|
||||||
|
progress_latents.clear()
|
||||||
if not safety_score.is_filtered:
|
if not safety_score.is_filtered:
|
||||||
if prompt.fix_faces:
|
if prompt.fix_faces:
|
||||||
logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
|
logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
|
||||||
@ -525,6 +571,7 @@ def _generate_single_image(
|
|||||||
mask_grayscale=mask_grayscale,
|
mask_grayscale=mask_grayscale,
|
||||||
depth_image=depth_image_display,
|
depth_image=depth_image_display,
|
||||||
timings=lc.get_timings(),
|
timings=lc.get_timings(),
|
||||||
|
progress_latents=progress_latents.copy(),
|
||||||
)
|
)
|
||||||
_most_recent_result = result
|
_most_recent_result = result
|
||||||
logger.info(f"Image Generated. Timings: {result.timings_str()}")
|
logger.info(f"Image Generated. Timings: {result.timings_str()}")
|
||||||
@ -542,3 +589,12 @@ def _prompts_to_embeddings(prompts, model):
|
|||||||
|
|
||||||
def prompt_normalized(prompt):
|
def prompt_normalized(prompt):
|
||||||
return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:130]
|
return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "_", prompt)[:130]
|
||||||
|
|
||||||
|
|
||||||
|
def shrink_list(items, max_size):
    """Return at most ``max_size`` evenly spaced elements of ``items``.

    Args:
        items: A sized iterable (typically a list) to thin out.
        max_size: Maximum number of elements the result may contain.

    Returns:
        ``items`` itself (the same object, not a copy) when it already fits
        within ``max_size``. Otherwise a new list holding exactly
        ``max_size`` elements, in their original order, sampled at evenly
        spaced positions starting with the first element. An empty list is
        returned when ``max_size`` is zero or negative.
    """
    total = len(items)
    if total <= max_size:
        return items
    if max_size <= 0:
        # Nothing may be kept; also guards the index arithmetic below.
        return []

    # Because total > max_size, consecutive values of ``i * total // max_size``
    # differ by at least 1, so this yields exactly ``max_size`` distinct
    # indices in the range [0, total). Integer arithmetic avoids the rounding
    # surprises of ``round()`` (banker's rounding) on .5 ratios.
    keep = {i * total // max_size for i in range(max_size)}

    return [val for i, val in enumerate(items) if i in keep]
|
||||||
|
@ -223,7 +223,11 @@ logger = logging.getLogger(__name__)
|
|||||||
help="Print the version and exit.",
|
help="Print the version and exit.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--gif", "make_gif", default=False, is_flag=True, help="Generate a gif of the edit."
|
"--gif",
|
||||||
|
"make_gif",
|
||||||
|
default=False,
|
||||||
|
is_flag=True,
|
||||||
|
help="Generate a gif of the generation.",
|
||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def imagine_cmd(
|
def imagine_cmd(
|
||||||
@ -708,7 +712,7 @@ def _imagine_cmd(
|
|||||||
output_file_extension="jpg",
|
output_file_extension="jpg",
|
||||||
print_caption=caption,
|
print_caption=caption,
|
||||||
precision=precision,
|
precision=precision,
|
||||||
make_comparison_gif=make_gif,
|
make_gif=make_gif,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,6 +79,7 @@ class ImageLoggingContext:
|
|||||||
progress_img_callback=None,
|
progress_img_callback=None,
|
||||||
progress_img_interval_steps=3,
|
progress_img_interval_steps=3,
|
||||||
progress_img_interval_min_s=0.1,
|
progress_img_interval_min_s=0.1,
|
||||||
|
progress_latent_callback=None,
|
||||||
):
|
):
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
self.model = model
|
self.model = model
|
||||||
@ -89,6 +90,7 @@ class ImageLoggingContext:
|
|||||||
self.progress_img_callback = progress_img_callback
|
self.progress_img_callback = progress_img_callback
|
||||||
self.progress_img_interval_steps = progress_img_interval_steps
|
self.progress_img_interval_steps = progress_img_interval_steps
|
||||||
self.progress_img_interval_min_s = progress_img_interval_min_s
|
self.progress_img_interval_min_s = progress_img_interval_min_s
|
||||||
|
self.progress_latent_callback = progress_latent_callback
|
||||||
|
|
||||||
self.start_ts = time.perf_counter()
|
self.start_ts = time.perf_counter()
|
||||||
self.timings = {}
|
self.timings = {}
|
||||||
@ -124,6 +126,8 @@ class ImageLoggingContext:
|
|||||||
from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa
|
from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa
|
||||||
|
|
||||||
if "predicted_latent" in description:
|
if "predicted_latent" in description:
|
||||||
|
if self.progress_latent_callback is not None:
|
||||||
|
self.progress_latent_callback(latents)
|
||||||
if (
|
if (
|
||||||
self.step_count - self.last_progress_img_step
|
self.step_count - self.last_progress_img_step
|
||||||
) > self.progress_img_interval_steps:
|
) > self.progress_img_interval_steps:
|
||||||
|
@ -116,6 +116,7 @@ class ImaginePrompt:
|
|||||||
model=config.DEFAULT_MODEL,
|
model=config.DEFAULT_MODEL,
|
||||||
model_config_path=None,
|
model_config_path=None,
|
||||||
is_intermediate=False,
|
is_intermediate=False,
|
||||||
|
collect_progress_latents=False,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.prompts = self.process_prompt_input(prompt)
|
self.prompts = self.process_prompt_input(prompt)
|
||||||
@ -164,6 +165,7 @@ class ImaginePrompt:
|
|||||||
self.model_config_path = model_config_path
|
self.model_config_path = model_config_path
|
||||||
# we don't want to save intermediate images
|
# we don't want to save intermediate images
|
||||||
self.is_intermediate = is_intermediate
|
self.is_intermediate = is_intermediate
|
||||||
|
self.collect_progress_latents = collect_progress_latents
|
||||||
|
|
||||||
if self.height is None or self.width is None or self.steps is None:
|
if self.height is None or self.width is None or self.steps is None:
|
||||||
SamplerCls = SAMPLER_LOOKUP[self.sampler_type]
|
SamplerCls = SAMPLER_LOOKUP[self.sampler_type]
|
||||||
@ -263,6 +265,7 @@ class ImagineResult:
|
|||||||
mask_grayscale=None,
|
mask_grayscale=None,
|
||||||
depth_image=None,
|
depth_image=None,
|
||||||
timings=None,
|
timings=None,
|
||||||
|
progress_latents=None,
|
||||||
):
|
):
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
|
|
||||||
@ -284,6 +287,7 @@ class ImagineResult:
|
|||||||
self.images["depth_image"] = depth_image
|
self.images["depth_image"] = depth_image
|
||||||
|
|
||||||
self.timings = timings
|
self.timings = timings
|
||||||
|
self.progress_latents = progress_latents
|
||||||
|
|
||||||
# for backward compat
|
# for backward compat
|
||||||
self.img = img
|
self.img = img
|
||||||
|
@ -191,7 +191,7 @@ def create_surprise_me_images(
|
|||||||
record_step_images=False,
|
record_step_images=False,
|
||||||
output_file_extension="jpg",
|
output_file_extension="jpg",
|
||||||
print_caption=False,
|
print_caption=False,
|
||||||
make_comparison_gif=make_gif,
|
make_gif=make_gif,
|
||||||
)
|
)
|
||||||
if make_gif:
|
if make_gif:
|
||||||
imgs_path = os.path.join(outdir, "compilations")
|
imgs_path = os.path.join(outdir, "compilations")
|
||||||
|
Loading…
Reference in New Issue
Block a user