From fd2ed3211527e722bbd5073518f820f296adf66c Mon Sep 17 00:00:00 2001 From: jaydrennan Date: Fri, 1 Mar 2024 16:55:42 -0800 Subject: [PATCH] feature: makes generating videos more programmatic. the generate_video function previously involved logic for saving the new file and wouldn't return anything. now it will return a list of the generated samples. --- imaginairy/api/generate.py | 5 +- imaginairy/api/video_sample.py | 103 +++++++++----------------------- imaginairy/cli/videogen.py | 106 ++++++++++++++++++++++++++++----- 3 files changed, 122 insertions(+), 92 deletions(-) diff --git a/imaginairy/api/generate.py b/imaginairy/api/generate.py index 3806b34..72724da 100755 --- a/imaginairy/api/generate.py +++ b/imaginairy/api/generate.py @@ -58,6 +58,7 @@ def imagine_image_files( from PIL import ImageDraw from imaginairy.api.video_sample import generate_video + from imaginairy.schema import LazyLoadingImage from imaginairy.utils import get_next_filenumber, prompt_normalized from imaginairy.utils.animations import make_bounce_animation from imaginairy.utils.img_utils import pillow_fit_image_within @@ -116,9 +117,11 @@ def imagine_image_files( if image_type == return_filename_type: result_filenames.append(filepath) if videogen: + # neeeds to be updated. try: + images = [LazyLoadingImage(filepath=filepath)] generate_video( - input_path=filepath, + input_images=images, ) except FileNotFoundError as e: logger.error(str(e)) diff --git a/imaginairy/api/video_sample.py b/imaginairy/api/video_sample.py index 0d3e17a..439fb5a 100644 --- a/imaginairy/api/video_sample.py +++ b/imaginairy/api/video_sample.py @@ -4,11 +4,8 @@ import logging import math import os import random -import re import time -from glob import glob -from pathlib import Path -from typing import Any, Optional +from typing import Any, List, Optional import cv2 import numpy as np @@ -36,8 +33,7 @@ logger = logging.getLogger(__name__) def generate_video( - input_path: str, # Can either be image file or folder with image files - output_folder: str | None = None, + input_images: List[LazyLoadingImage], size=(1024, 576), num_frames: int = 6, num_steps: int = 30, @@ -50,13 +46,12 @@ def generate_video( decoding_t: int = 1, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. device: Optional[str] = None, repetitions=1, - output_format="webp", ): """ Generates a video from a single image or multiple images, conditioned on the provided input_path. Args: - input_path (str): Path to an image file or a directory containing image files. + input_images (List[LazyLoadingImage]): List of LazyLoading images to be transformed into videos output_folder (str | None, optional): Directory where the generated video will be saved. Defaults to "outputs/video/" if None. num_frames (int, optional): Number of frames in the generated video. Defaults to 6. @@ -101,8 +96,7 @@ def generate_video( num_frames = default(num_frames, video_model_config.defaults.get("frames", 12)) num_steps = default(num_steps, video_model_config.defaults.get("steps", 30)) - output_folder_str = default(output_folder, "outputs/video/") - del output_folder + video_config_path = f"{PKG_ROOT}/{video_model_config.architecture.config_path}" model, safety_filter = load_model( @@ -113,58 +107,35 @@ def generate_video( weights_url=video_model_config.weights_location, ) - if input_path.startswith("http"): - all_img_paths = [input_path] - else: - path = Path(input_path) - if path.is_file(): - if any(input_path.endswith(x) for x in ["jpg", "jpeg", "png"]): - all_img_paths = [input_path] - else: - raise ValueError("Path is not valid image file.") - elif path.is_dir(): - all_img_paths = sorted( - [ - str(f) - for f in path.iterdir() - if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] - ] - ) - if len(all_img_paths) == 0: - raise ValueError("Folder does not contain any images.") - else: - msg = f"Could not find file or folder at {input_path}" - raise FileNotFoundError(msg) - expected_size = (vid_width, vid_height) + all_samples = [] for _ in range(repetitions): - for input_path in all_img_paths: + for image in input_images: start_time = time.perf_counter() _seed = default(seed, random.randint(0, 1000000)) torch.manual_seed(_seed) logger.info( - f"Generating a {num_frames} frame video from {input_path}. Device:{device} seed:{_seed}" + f"Generating a {num_frames} frame video from {image}. Device:{device} seed:{_seed}" ) - if input_path.startswith("http"): - image = LazyLoadingImage(url=input_path).as_pillow() - else: - image = LazyLoadingImage(filepath=input_path).as_pillow() + + pil_image = image.as_pillow() + crop_coords = None if image.mode == "RGBA": - image = image.convert("RGB") + pil_image = image.convert("RGB") if image.size != expected_size: logger.info( f"Resizing image from {image.size} to {expected_size}. (w, h)" ) - image = pillow_fit_image_within( - image, max_height=expected_size[1], max_width=expected_size[0] + pil_image = pillow_fit_image_within( + pil_image, max_height=expected_size[1], max_width=expected_size[0] ) logger.debug(f"Image is now of size: {image.size}") background = Image.new("RGB", expected_size, "white") # Calculate the position to center the original image x = (background.width - image.width) // 2 y = (background.height - image.height) // 2 - background.paste(image, (x, y)) + background.paste(pil_image, (x, y)) # crop_coords = (x, y, x + image.width, y + image.height) # image = background @@ -173,17 +144,17 @@ def generate_video( if h % snap_to != 0 or w % snap_to != 0: width = w - w % snap_to height = h - h % snap_to - image = image.resize((width, height)) + pil_image = pil_image.resize((width, height)) logger.warning( f"Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!" ) - image = ToTensor()(image) - image = image * 2.0 - 1.0 + tensor_image = ToTensor()(pil_image) + tensor_image = tensor_image * 2.0 - 1.0 - image = image.unsqueeze(0).to(device) - H, W = image.shape[2:] - assert image.shape[1] == 3 + tensor_image = tensor_image.unsqueeze(0).to(device) + H, W = tensor_image.shape[2:] + assert tensor_image.shape[1] == 3 F = 8 C = 4 shape = (num_frames, C, H // F, W // F) @@ -210,8 +181,10 @@ def generate_video( value_dict["motion_bucket_id"] = motion_bucket_id value_dict["fps_id"] = fps_id value_dict["cond_aug"] = cond_aug - value_dict["cond_frames_without_noise"] = image - value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image) + value_dict["cond_frames_without_noise"] = tensor_image + value_dict["cond_frames"] = tensor_image + cond_aug * torch.randn_like( + tensor_image + ) with torch.no_grad(), platform_appropriate_autocast(): reload_model(model.conditioner, device=device) @@ -275,21 +248,16 @@ def generate_video( left, upper, right, lower = crop_coords samples = samples[:, :, upper:lower, left:right] - os.makedirs(output_folder_str, exist_ok=True) - base_count = len(glob(os.path.join(output_folder_str, "*.*"))) + 1 - source_slug = make_safe_filename(input_path) - video_filename = f"{base_count:06d}_{model_name}_{_seed}_{fps_id}fps_{source_slug}.{output_format}" - video_path = os.path.join(output_folder_str, video_filename) - samples = safety_filter(samples) - # save_video(samples, video_path, output_fps) - save_video_bounce(samples, video_path, output_fps) + all_samples.append(samples) duration = time.perf_counter() - start_time logger.info( - f"Video of {num_frames} frames generated in {duration:.2f} seconds and saved to {video_path}\n" + f"Video of {num_frames} frames generated in {duration:.2f} seconds\n" ) + return all_samples, output_fps + def save_video(samples: torch.Tensor, video_filename: str, output_fps: int): """ @@ -458,18 +426,3 @@ def pillow_fit_image_within( if (w, h) != image.size: image = image.resize((w, h), resample=Image.Resampling.LANCZOS) return image - - -def make_safe_filename(input_string): - stripped_url = re.sub(r"^https?://[^/]+/", "", input_string) - - # Remove directory path if present - base_name = os.path.basename(stripped_url) - - # Remove file extension - name_without_extension = os.path.splitext(base_name)[0] - - # Keep only alphanumeric characters and dashes - safe_name = re.sub(r"[^a-zA-Z0-9\-]", "", name_without_extension) - - return safe_name diff --git a/imaginairy/cli/videogen.py b/imaginairy/cli/videogen.py index 9f3a45a..05ec894 100644 --- a/imaginairy/cli/videogen.py +++ b/imaginairy/cli/videogen.py @@ -85,29 +85,103 @@ def videogen_cmd( aimg videogen --start-image assets/rocket-wide.png """ + import os + from glob import glob + from imaginairy.api.video_sample import generate_video + from imaginairy.utils import default from imaginairy.utils.log_utils import configure_logging configure_logging() output_fps = output_fps or fps + + all_images = [] + try: - generate_video( - input_path=start_image, - num_frames=num_frames, - num_steps=steps, - model_name=model, - fps_id=fps, - size=size, - output_fps=output_fps, - output_format=output_format, - motion_bucket_id=motion_amount, - cond_aug=cond_aug, - seed=seed, - decoding_t=decoding_t, - output_folder=output_folder, - repetitions=repeats, - ) + all_images.extend(load_images(start_image)) except FileNotFoundError as e: logger.error(str(e)) exit(1) + + output_folder_str = default(output_folder, "outputs/video/") + + os.makedirs(output_folder_str, exist_ok=True) + + samples, output_fps = generate_video( + input_images=all_images, + num_frames=num_frames, + num_steps=steps, + model_name=model, + fps_id=fps, + size=size, + output_fps=output_fps, + motion_bucket_id=motion_amount, + cond_aug=cond_aug, + seed=seed, + decoding_t=decoding_t, + repetitions=repeats, + ) + + for sample in samples: + base_count = len(glob(os.path.join(output_folder_str, "*.*"))) + 1 + source_slug = make_safe_filename(sample) + video_filename = ( + f"{base_count:06d}_{model}_{seed}_{fps}fps_{source_slug}.{output_format}" + ) + video_path = os.path.join(output_folder_str, video_filename) + + from imaginairy.api.video_sample import save_video_bounce + + save_video_bounce(samples, video_path, output_fps) + + +def load_images(start_image): + from pathlib import Path + + from imaginairy.schema import LazyLoadingImage + + if start_image.startswith("http"): + image = LazyLoadingImage(url=start_image).as_pillow() + return [image] + else: + path = Path(start_image) + if path.is_file(): + if any(start_image.endswith(x) for x in ["jpg", "jpeg", "png"]): + return [LazyLoadingImage(filepath=start_image).as_pillow()] + else: + raise ValueError("Path is not a valid image file.") + elif path.is_dir(): + all_img_paths = sorted( + [ + str(f) + for f in path.iterdir() + if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] + ] + ) + if len(all_img_paths) == 0: + raise ValueError("Folder does not contain any images.") + return [ + LazyLoadingImage(filepath=image).as_pillow() for image in all_img_paths + ] + else: + msg = f"Could not find file or folder at {start_image}" + raise FileNotFoundError(msg) + + +def make_safe_filename(input_string): + import os + import re + + stripped_url = re.sub(r"^https?://[^/]+/", "", input_string) + + # Remove directory path if present + base_name = os.path.basename(stripped_url) + + # Remove file extension + name_without_extension = os.path.splitext(base_name)[0] + + # Keep only alphanumeric characters and dashes + safe_name = re.sub(r"[^a-zA-Z0-9\-]", "", name_without_extension) + + return safe_name