diff --git a/imaginairy/api/video_sample.py b/imaginairy/api/video_sample.py
index 8c63e8f..7e17b33 100644
--- a/imaginairy/api/video_sample.py
+++ b/imaginairy/api/video_sample.py
@@ -26,7 +26,9 @@ from imaginairy.utils import (
     instantiate_from_config,
     platform_appropriate_autocast,
 )
+from imaginairy.utils.animations import make_bounce_animation
 from imaginairy.utils.model_manager import get_cached_url_path
+from imaginairy.utils.named_resolutions import normalize_image_size
 from imaginairy.utils.paths import PKG_ROOT
 
 logger = logging.getLogger(__name__)
@@ -35,6 +37,7 @@ logger = logging.getLogger(__name__)
 def generate_video(
     input_path: str,  # Can either be image file or folder with image files
     output_folder: str | None = None,
+    size=(1024, 576),
     num_frames: int = 6,
     num_steps: int = 30,
     model_name: str = "svd-xt",
@@ -46,6 +49,7 @@ def generate_video(
     decoding_t: int = 1,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
     device: Optional[str] = None,
     repetitions=1,
+    output_format="webp",
 ):
     """
     Generates a video from a single image or multiple images, conditioned on the provided input_path.
@@ -71,7 +75,7 @@ def generate_video(
         None: The function saves the generated video(s) to the specified output folder.
     """
     device = default(device, get_device)
-
+    vid_width, vid_height = normalize_image_size(size)
     if device == "mps":
         msg = "Apple Silicon MPS (M1, M2, etc) is not currently supported for video generation. Switching to cpu generation."
         logger.warning(msg)
@@ -88,7 +92,6 @@ def generate_video(
         logger.warning(msg)
 
     start_time = time.perf_counter()
-    seed = default(seed, random.randint(0, 1000000))
     output_fps = default(output_fps, fps_id)
 
     video_model_config = config.MODEL_WEIGHT_CONFIG_LOOKUP.get(model_name, None)
@@ -102,9 +105,6 @@ def generate_video(
     del output_folder
 
     video_config_path = f"{PKG_ROOT}/{video_model_config.architecture.config_path}"
-    logger.info(
-        f"Generating a {num_frames} frame video from {input_path}. Device:{device} seed:{seed}"
-    )
     model, safety_filter = load_model(
         config=video_config_path,
         device="cpu",
@@ -112,7 +112,6 @@ def generate_video(
         num_steps=num_steps,
         weights_url=video_model_config.weights_location,
     )
-    torch.manual_seed(seed)
 
     if input_path.startswith("http"):
         all_img_paths = [input_path]
@@ -137,9 +136,14 @@ def generate_video(
             msg = f"Could not find file or folder at {input_path}"
             raise FileNotFoundError(msg)
 
-    expected_size = (1024, 576)
+    expected_size = (vid_width, vid_height)
     for _ in range(repetitions):
         for input_path in all_img_paths:
+            _seed = default(seed, random.randint(0, 1000000))
+            torch.manual_seed(_seed)
+            logger.info(
+                f"Generating a {num_frames} frame video from {input_path}. Device:{device} seed:{_seed}"
+            )
             if input_path.startswith("http"):
                 image = LazyLoadingImage(url=input_path).as_pillow()
             else:
@@ -207,7 +211,6 @@ def generate_video(
             value_dict["cond_aug"] = cond_aug
             value_dict["cond_frames_without_noise"] = image
             value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
-            value_dict["cond_aug"] = cond_aug
 
             with torch.no_grad(), platform_appropriate_autocast():
                 reload_model(model.conditioner, device=device)
@@ -272,30 +275,14 @@ def generate_video(
                     samples = samples[:, :, upper:lower, left:right]
 
                     os.makedirs(output_folder_str, exist_ok=True)
-                    base_count = len(glob(os.path.join(output_folder_str, "*.mp4"))) + 1
+                    base_count = len(glob(os.path.join(output_folder_str, "*.*"))) + 1
                     source_slug = make_safe_filename(input_path)
-                    video_filename = f"{base_count:06d}_{model_name}_{seed}_{fps_id}fps_{source_slug}.mp4"
+                    video_filename = f"{base_count:06d}_{model_name}_{_seed}_{fps_id}fps_{source_slug}.{output_format}"
                     video_path = os.path.join(output_folder_str, video_filename)
-                    writer = cv2.VideoWriter(
-                        video_path,
-                        cv2.VideoWriter_fourcc(*"MP4V"),  # type: ignore
-                        output_fps,
-                        (samples.shape[-1], samples.shape[-2]),
-                    )
 
                     samples = safety_filter(samples)
-                    vid = (
-                        (rearrange(samples, "t c h w -> t h w c") * 255)
-                        .cpu()
-                        .numpy()
-                        .astype(np.uint8)
-                    )
-                    for frame in vid:
-                        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-                        writer.write(frame)
-                    writer.release()
-                    video_path_h264 = video_path[:-4] + "_h264.mp4"
-                    os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}")
+                    # save_video(samples, video_path, output_fps)
+                    save_video_bounce(samples, video_path, output_fps)
 
     duration = time.perf_counter() - start_time
     logger.info(
@@ -303,6 +290,46 @@ def generate_video(
     )
 
 
+def save_video(samples: torch.Tensor, video_filename: str, output_fps: int):
+    """
+    Saves a video from given tensor samples.
+
+    Args:
+        samples (torch.Tensor): Tensor containing video frame data.
+        video_filename (str): The full path and filename where the video will be saved.
+        output_fps (int): Frames per second for the output video.
+        safety_filter (Callable[[torch.Tensor], torch.Tensor]): A function to apply a safety filter to the samples.
+
+    Returns:
+        str: The path to the saved video.
+    """
+    vid = (torch.permute(samples, (0, 2, 3, 1)) * 255).cpu().numpy().astype(np.uint8)
+    writer = cv2.VideoWriter(
+        video_filename,
+        cv2.VideoWriter_fourcc(*"MP4V"),  # type: ignore
+        output_fps,
+        (samples.shape[-1], samples.shape[-2]),
+    )
+    for frame in vid:
+        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+        writer.write(frame)
+    writer.release()
+    video_path_h264 = video_filename[:-4] + "_h264.mp4"
+    os.system(f"ffmpeg -i {video_filename} -c:v libx264 {video_path_h264}")
+
+
+def save_video_bounce(samples: torch.Tensor, video_filename: str, output_fps: int):
+    frames_np = (
+        (torch.permute(samples, (0, 2, 3, 1)) * 255).cpu().numpy().astype(np.uint8)
+    )
+
+    make_bounce_animation(
+        imgs=[Image.fromarray(frame) for frame in frames_np],
+        outpath=video_filename,
+        end_pause_duration_ms=750,
+    )
+
+
 def get_unique_embedder_keys_from_conditioner(conditioner):
     return list({x.input_key for x in conditioner.embedders})
 
diff --git a/imaginairy/cli/videogen.py b/imaginairy/cli/videogen.py
index 2885aa9..9f3a45a 100644
--- a/imaginairy/cli/videogen.py
+++ b/imaginairy/cli/videogen.py
@@ -25,7 +25,20 @@ logger = logging.getLogger(__name__)
 @click.option(
     "--fps", default=6, type=int, help="FPS for the AI to target when generating video"
 )
+@click.option(
+    "--size",
+    default="1024,576",
+    show_default=True,
+    type=str,
+    help="Video dimensions. Can be a named size, single integer, or WIDTHxHEIGHT pair. Should be multiple of 8. Examples: SVD, 512x512, 4k, UHD, 8k, 512, 1080p",
+)
 @click.option("--output-fps", default=None, type=int, help="FPS for the output video")
+@click.option(
+    "--output-format",
+    default="webp",
+    help="Output video format",
+    type=click.Choice(["webp", "mp4", "gif"]),
+)
 @click.option(
     "--motion-amount",
     default=127,
@@ -54,7 +67,9 @@ def videogen_cmd(
     steps,
     model,
     fps,
+    size,
     output_fps,
+    output_format,
     motion_amount,
     repeats,
     cond_aug,
@@ -83,7 +98,9 @@ def videogen_cmd(
         num_steps=steps,
         model_name=model,
         fps_id=fps,
+        size=size,
         output_fps=output_fps,
+        output_format=output_format,
         motion_bucket_id=motion_amount,
         cond_aug=cond_aug,
         seed=seed,
diff --git a/imaginairy/utils/animations.py b/imaginairy/utils/animations.py
index a6b1aeb..361a102 100644
--- a/imaginairy/utils/animations.py
+++ b/imaginairy/utils/animations.py
@@ -1,5 +1,6 @@
 """Functions for creating animations from images."""
 import os.path
+from typing import TYPE_CHECKING, List
 
 import cv2
 import torch
@@ -12,18 +13,24 @@ from imaginairy.utils.img_utils import (
     pillow_img_to_opencv_img,
 )
 
+if TYPE_CHECKING:
+    from PIL import Image
+
+    from imaginairy.utils.img_utils import LazyLoadingImage
+
 
 def make_bounce_animation(
-    imgs,
-    outpath,
+    imgs: "List[Image.Image | LazyLoadingImage | torch.Tensor]",
+    outpath: str,
     transition_duration_ms=500,
     start_pause_duration_ms=1000,
     end_pause_duration_ms=2000,
+    max_fps=20,
 ):
     first_img = imgs[0]
-    last_img = imgs[-1]
     middle_imgs = imgs[1:-1]
-    max_fps = 20
+    last_img = imgs[-1]
+
     max_frames = int(round(transition_duration_ms / 1000 * max_fps))
     min_duration = int(1000 / 20)
     if middle_imgs:
@@ -37,20 +44,8 @@ def make_bounce_animation(
     frames = [first_img, *middle_imgs, last_img, *list(reversed(middle_imgs))]
 
     # convert from latents
-    converted_frames = []
-
-    for frame in frames:
-        if isinstance(frame, torch.Tensor):
-            frame = model_latents_to_pillow_imgs(frame)[0]
-        converted_frames.append(frame)
-    frames = converted_frames
-    max_size = max([frame.size for frame in frames])
-    converted_frames = []
-    for frame in frames:
-        if frame.size != max_size:
-            frame = frame.resize(max_size)
-        converted_frames.append(frame)
-    frames = converted_frames
+    converted_frames = _ensure_pillow_images(frames)
+    converted_frames = _ensure_images_same_size(converted_frames)
 
     durations = (
         [start_pause_duration_ms]
@@ -59,7 +54,29 @@ def make_bounce_animation(
         + [progress_duration] * len(middle_imgs)
     )
 
-    make_animation(imgs=frames, outpath=outpath, frame_duration_ms=durations)
+    make_animation(imgs=converted_frames, outpath=outpath, frame_duration_ms=durations)
+
+
+def _ensure_pillow_images(
+    imgs: "List[Image.Image | LazyLoadingImage | torch.Tensor]",
+) -> "List[Image.Image]":
+    converted_frames: "List[Image.Image]" = []
+    for frame in imgs:
+        if isinstance(frame, torch.Tensor):
+            converted_frames.append(model_latents_to_pillow_imgs(frame)[0])
+        else:
+            converted_frames.append(frame)  # type: ignore
+    return converted_frames
+
+
+def _ensure_images_same_size(imgs: "List[Image.Image]") -> "List[Image.Image]":
+    max_size = max([frame.size for frame in imgs])
+    converted_frames = []
+    for frame in imgs:
+        if frame.size != max_size:
+            frame = frame.resize(max_size)
+        converted_frames.append(frame)
+    return converted_frames
 
 
 def make_slideshow_animation(
@@ -79,7 +96,9 @@
     make_animation(imgs=converted_frames, outpath=outpath, frame_duration_ms=durations)
 
 
-def make_animation(imgs, outpath, frame_duration_ms=100, captions=None):
+def make_animation(
+    imgs, outpath, frame_duration_ms: int | List[int] = 100, captions=None
+):
     imgs = imgpaths_to_imgs(imgs)
     ext = os.path.splitext(outpath)[1].lower().strip(".")
 
@@ -89,7 +108,7 @@
     for img, caption in zip(imgs, captions):
         add_caption_to_image(img, caption)
 
-    if ext == "gif":
+    if ext == "gif" or ext == "webp":
         make_gif_animation(
             imgs=imgs, outpath=outpath, frame_duration_ms=frame_duration_ms
         )
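
Usage note (not part of the patch): a minimal sketch of how the new size and output_format parameters added above might be called from Python. The input image path and output folder are hypothetical, and it assumes normalize_image_size accepts a "WIDTHxHEIGHT" string as described in the --size help text.

    from imaginairy.api.video_sample import generate_video

    # Render a 512x512 clip and save it as an animated webp "bounce" loop.
    generate_video(
        input_path="my-photo.jpg",         # hypothetical input image
        output_folder="outputs/video/",    # hypothetical output location
        size="512x512",                    # parsed by normalize_image_size()
        output_format="webp",              # webp, mp4, or gif, per the CLI choices
    )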