feature: videogen improvements

- better filenames
- allow urls as image inputs
- better memory efficiency
- add timing information
pull/404/head
Bryce 6 months ago
parent ab75c49b15
commit c24ed1f33d

@ -190,6 +190,10 @@ class LazyLoadingImage:
self._load_img()
return self.save_image_as_base64(self._img) # type: ignore
def as_pillow(self):
    """Run ``self._load_img()`` and hand back whatever ``self._img`` holds.

    NOTE(review): the method name suggests the result is a PIL image --
    confirm against ``_load_img``, which is not visible here.
    """
    self._load_img()
    loaded = self._img
    return loaded
def __str__(self):
return self.as_base64()

@ -2,6 +2,7 @@ import logging
import math
import os
import random
import time
from glob import glob
from pathlib import Path
from typing import Optional
@ -11,10 +12,9 @@ import numpy as np
import torch
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from torchvision.transforms import ToTensor
from imaginairy import config
from imaginairy import LazyLoadingImage, config
from imaginairy.model_manager import get_cached_url_path
from imaginairy.paths import PKG_ROOT
from imaginairy.utils import (
@ -45,6 +45,7 @@ def generate_video(
Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
"""
start_time = time.perf_counter()
device = default(device, get_device)
seed = default(seed, random.randint(0, 1000000))
output_fps = default(output_fps, fps_id)
@ -64,48 +65,51 @@ def generate_video(
model, safety_filter = load_model(
config=video_config_path,
device=device,
device="cpu",
num_frames=num_frames,
num_steps=num_steps,
weights_url=video_model_config["weights_url"],
)
torch.manual_seed(seed)
path = Path(input_path)
all_img_paths = []
if path.is_file():
if any(input_path.endswith(x) for x in ["jpg", "jpeg", "png"]):
all_img_paths = [input_path]
else:
raise ValueError("Path is not valid image file.")
elif path.is_dir():
all_img_paths = sorted(
[
f
for f in path.iterdir()
if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
]
)
if len(all_img_paths) == 0:
raise ValueError("Folder does not contain any images.")
if input_path.startswith("http"):
input_images = [LazyLoadingImage(url=input_path)]
else:
raise ValueError
for input_img_path in all_img_paths:
with Image.open(input_img_path) as image:
if image.mode == "RGBA":
image = image.convert("RGB")
w, h = image.size
if h % 64 != 0 or w % 64 != 0:
width, height = (x - x % 64 for x in (w, h))
image = image.resize((width, height))
logger.info(
f"Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
)
path = Path(input_path)
all_img_paths = []
if path.is_file():
if any(input_path.endswith(x) for x in ["jpg", "jpeg", "png"]):
all_img_paths = [input_path]
else:
raise ValueError("Path is not valid image file.")
elif path.is_dir():
all_img_paths = sorted(
[
f
for f in path.iterdir()
if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
]
)
if len(all_img_paths) == 0:
raise ValueError("Folder does not contain any images.")
else:
raise ValueError
input_images = [LazyLoadingImage(filepath=str(x)) for x in all_img_paths]
for image in input_images:
image = image.as_pillow()
if image.mode == "RGBA":
image = image.convert("RGB")
w, h = image.size
if h % 64 != 0 or w % 64 != 0:
width, height = (x - x % 64 for x in (w, h))
image = image.resize((width, height))
logger.info(
f"Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
)
image = ToTensor()(image)
image = image * 2.0 - 1.0
image = ToTensor()(image)
image = image * 2.0 - 1.0
image = image.unsqueeze(0).to(device)
H, W = image.shape[2:]
@ -180,13 +184,17 @@ def generate_video(
samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
unload_model(model.model)
unload_model(model.denoiser)
reload_model(model.first_stage_model)
model.en_and_decode_n_samples_a_time = decoding_t
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
unload_model(model.first_stage_model)
os.makedirs(output_folder, exist_ok=True)
base_count = len(glob(os.path.join(output_folder, "*.mp4")))
video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) + 1
video_filename = f"{base_count:06d}_{model_name}_{seed}.mp4"
video_path = os.path.join(output_folder, video_filename)
writer = cv2.VideoWriter(
video_path,
cv2.VideoWriter_fourcc(*"MP4V"),
@ -209,7 +217,10 @@ def generate_video(
peak_memory_usage = torch.cuda.max_memory_allocated()
msg = f"Peak memory usage: {peak_memory_usage / (1024 ** 2)} MB"
logger.info(msg)
logger.info(f"Video saved to {video_path}\n")
duration = time.perf_counter() - start_time
logger.info(
f"Video of {num_frames} frames generated in {duration:.2f} seconds and saved to {video_path}\n"
)
def get_unique_embedder_keys_from_conditioner(conditioner):

Loading…
Cancel
Save