@ -2,6 +2,7 @@ import logging
import math
import os
import random
import re
import time
from glob import glob
from pathlib import Path
@ -12,6 +13,7 @@ import numpy as np
import torch
from einops import rearrange , repeat
from omegaconf import OmegaConf
from PIL import Image
from torchvision . transforms import ToTensor
from imaginairy import LazyLoadingImage , config
@ -40,6 +42,7 @@ def generate_video(
decoding_t : int = 1 , # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
device : Optional [ str ] = None ,
output_folder : Optional [ str ] = None ,
repetitions = 1 ,
) :
"""
Simple script to generate a single sample conditioned on an image ` input_path ` or multiple images , one for each
@ -50,8 +53,6 @@ def generate_video(
seed = default ( seed , random . randint ( 0 , 1000000 ) )
output_fps = default ( output_fps , fps_id )
logger . info ( f " Device: { device } seed: { seed } " )
torch . cuda . reset_peak_memory_stats ( )
video_model_config = config . video_models . get ( model_name , None )
if video_model_config is None :
@ -63,6 +64,9 @@ def generate_video(
output_folder = default ( output_folder , " outputs/video/ " )
video_config_path = f " { PKG_ROOT } / { video_model_config [ ' config_path ' ] } "
logger . info (
f " Generating { num_frames } frame video from { input_path } . Device: { device } seed: { seed } "
)
model , safety_filter = load_model (
config = video_config_path ,
device = " cpu " ,
@ -71,11 +75,11 @@ def generate_video(
weights_url = video_model_config [ " weights_url " ] ,
)
torch . manual_seed ( seed )
if input_path . startswith ( " http " ) :
input_images = [ LazyLoadingImage ( url = input_path ) ]
all_img_paths = [ input_path ]
else :
path = Path ( input_path )
all_img_paths = [ ]
if path . is_file ( ) :
if any ( input_path . endswith ( x ) for x in [ " jpg " , " jpeg " , " png " ] ) :
all_img_paths = [ input_path ]
@ -84,7 +88,7 @@ def generate_video(
elif path . is_dir ( ) :
all_img_paths = sorted (
[
f
str ( f )
for f in path . iterdir ( )
if f . is_file ( ) and f . suffix . lower ( ) in [ " .jpg " , " .jpeg " , " .png " ]
]
@ -93,134 +97,159 @@ def generate_video(
raise ValueError ( " Folder does not contain any images. " )
else :
raise ValueError
input_images = [ LazyLoadingImage ( filepath = str ( x ) ) for x in all_img_paths ]
for image in input_images :
image = image . as_pillow ( )
if image . mode == " RGBA " :
image = image . convert ( " RGB " )
w , h = image . size
if h % 64 != 0 or w % 64 != 0 :
width , height = ( x - x % 64 for x in ( w , h ) )
image = image . resize ( ( width , height ) )
logger . info (
f " Your image is of size { h } x { w } which is not divisible by 64. We are resizing to { height } x { width } ! "
)
expected_size = ( 1024 , 576 )
for _ in range ( repetitions ) :
for input_path in all_img_paths :
if input_path . startswith ( " http " ) :
image = LazyLoadingImage ( url = input_path )
else :
image = LazyLoadingImage ( filepath = input_path )
crop_coords = None
image = image . as_pillow ( )
if image . mode == " RGBA " :
image = image . convert ( " RGB " )
if image . size != expected_size :
logger . info (
f " Resizing image from { image . size } to { expected_size } . (w, h) "
)
image = pillow_fit_image_within (
image , max_height = expected_size [ 1 ] , max_width = expected_size [ 0 ]
)
logger . debug ( f " Image is now of size: { image . size } " )
background = Image . new ( " RGB " , expected_size , " white " )
# Calculate the position to center the original image
x = ( background . width - image . width ) / / 2
y = ( background . height - image . height ) / / 2
background . paste ( image , ( x , y ) )
crop_coords = ( x , y , x + image . width , y + image . height )
image = background
image = ToTensor ( ) ( image )
image = image * 2.0 - 1.0
image = image . unsqueeze ( 0 ) . to ( device )
H , W = image . shape [ 2 : ]
assert image . shape [ 1 ] == 3
F = 8
C = 4
shape = ( num_frames , C , H / / F , W / / F )
if expected_size != ( W , H ) :
logger . warning (
f " The { W , H } image you provided is not { expected_size } . This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`. "
)
if motion_bucket_id > 255 :
logger . warning (
" High motion bucket! This may lead to suboptimal performance. "
)
image = ToTensor ( ) ( image )
image = image * 2.0 - 1.0
image = image . unsqueeze ( 0 ) . to ( device )
H , W = image . shape [ 2 : ]
assert image . shape [ 1 ] == 3
F = 8
C = 4
shape = ( num_frames , C , H / / F , W / / F )
if ( H , W ) != ( 576 , 1024 ) :
logger . warning (
" The image you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`. "
)
if motion_bucket_id > 255 :
logger . warning (
" High motion bucket! This may lead to suboptimal performance. "
)
if fps_id < 5 :
logger . warning (
" Small fps value! This may lead to suboptimal performance. "
)
if fps_id < 5 :
logger . warning ( " Small fps value! This may lead to suboptimal performance. " )
if fps_id > 30 :
logger . warning ( " Large fps value! This may lead to suboptimal performance. " )
value_dict = { }
value_dict [ " motion_bucket_id " ] = motion_bucket_id
value_dict [ " fps_id " ] = fps_id
value_dict [ " cond_aug " ] = cond_aug
value_dict [ " cond_frames_without_noise " ] = image
value_dict [ " cond_frames " ] = image + cond_aug * torch . randn_like ( image )
value_dict [ " cond_aug " ] = cond_aug
with torch . no_grad ( ) , platform_appropriate_autocast ( ) :
reload_model ( model . conditioner )
batch , batch_uc = get_batch (
get_unique_embedder_keys_from_conditioner ( model . conditioner ) ,
value_dict ,
[ 1 , num_frames ] ,
T = num_frames ,
device = device ,
)
c , uc = model . conditioner . get_unconditional_conditioning (
batch ,
batch_uc = batch_uc ,
force_uc_zero_embeddings = [
" cond_frames " ,
" cond_frames_without_noise " ,
] ,
)
unload_model ( model . conditioner )
for k in [ " crossattn " , " concat " ] :
uc [ k ] = repeat ( uc [ k ] , " b ... -> b t ... " , t = num_frames )
uc [ k ] = rearrange ( uc [ k ] , " b t ... -> (b t) ... " , t = num_frames )
c [ k ] = repeat ( c [ k ] , " b ... -> b t ... " , t = num_frames )
c [ k ] = rearrange ( c [ k ] , " b t ... -> (b t) ... " , t = num_frames )
randn = torch . randn ( shape , device = device )
additional_model_inputs = { }
additional_model_inputs [ " image_only_indicator " ] = torch . zeros (
2 , num_frames
) . to ( device )
additional_model_inputs [ " num_video_frames " ] = batch [ " num_video_frames " ]
def denoiser ( _input , sigma , c ) :
_input = _input . half ( )
return model . denoiser (
model . model , _input , sigma , c , * * additional_model_inputs
if fps_id > 30 :
logger . warning (
" Large fps value! This may lead to suboptimal performance. "
)
reload_model ( model . denoiser )
reload_model ( model . model )
samples_z = model . sampler ( denoiser , randn , cond = c , uc = uc )
unload_model ( model . model )
unload_model ( model . denoiser )
reload_model ( model . first_stage_model )
model . en_and_decode_n_samples_a_time = decoding_t
samples_x = model . decode_first_stage ( samples_z )
samples = torch . clamp ( ( samples_x + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
unload_model ( model . first_stage_model )
os . makedirs ( output_folder , exist_ok = True )
base_count = len ( glob ( os . path . join ( output_folder , " *.mp4 " ) ) ) + 1
video_filename = f " { base_count : 06d } _ { model_name } _ { seed } .mp4 "
video_path = os . path . join ( output_folder , video_filename )
writer = cv2 . VideoWriter (
video_path ,
cv2 . VideoWriter_fourcc ( * " MP4V " ) ,
output_fps ,
( samples . shape [ - 1 ] , samples . shape [ - 2 ] ) ,
)
value_dict = { }
value_dict [ " motion_bucket_id " ] = motion_bucket_id
value_dict [ " fps_id " ] = fps_id
value_dict [ " cond_aug " ] = cond_aug
value_dict [ " cond_frames_without_noise " ] = image
value_dict [ " cond_frames " ] = image + cond_aug * torch . randn_like ( image )
value_dict [ " cond_aug " ] = cond_aug
with torch . no_grad ( ) , platform_appropriate_autocast ( ) :
reload_model ( model . conditioner )
batch , batch_uc = get_batch (
get_unique_embedder_keys_from_conditioner ( model . conditioner ) ,
value_dict ,
[ 1 , num_frames ] ,
T = num_frames ,
device = device ,
)
c , uc = model . conditioner . get_unconditional_conditioning (
batch ,
batch_uc = batch_uc ,
force_uc_zero_embeddings = [
" cond_frames " ,
" cond_frames_without_noise " ,
] ,
)
unload_model ( model . conditioner )
for k in [ " crossattn " , " concat " ] :
uc [ k ] = repeat ( uc [ k ] , " b ... -> b t ... " , t = num_frames )
uc [ k ] = rearrange ( uc [ k ] , " b t ... -> (b t) ... " , t = num_frames )
c [ k ] = repeat ( c [ k ] , " b ... -> b t ... " , t = num_frames )
c [ k ] = rearrange ( c [ k ] , " b t ... -> (b t) ... " , t = num_frames )
randn = torch . randn ( shape , device = device )
additional_model_inputs = { }
additional_model_inputs [ " image_only_indicator " ] = torch . zeros (
2 , num_frames
) . to ( device )
additional_model_inputs [ " num_video_frames " ] = batch [ " num_video_frames " ]
def denoiser ( _input , sigma , c ) :
_input = _input . half ( )
return model . denoiser (
model . model , _input , sigma , c , * * additional_model_inputs
)
reload_model ( model . denoiser )
reload_model ( model . model )
samples_z = model . sampler ( denoiser , randn , cond = c , uc = uc )
unload_model ( model . model )
unload_model ( model . denoiser )
reload_model ( model . first_stage_model )
model . en_and_decode_n_samples_a_time = decoding_t
samples_x = model . decode_first_stage ( samples_z )
samples = torch . clamp ( ( samples_x + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
unload_model ( model . first_stage_model )
if crop_coords :
left , upper , right , lower = crop_coords
samples = samples [ : , : , upper : lower , left : right ]
os . makedirs ( output_folder , exist_ok = True )
base_count = len ( glob ( os . path . join ( output_folder , " *.mp4 " ) ) ) + 1
source_slug = make_safe_filename ( input_path )
video_filename = f " { base_count : 06d } _ { model_name } _ { seed } _ { fps_id } fps_ { source_slug } .mp4 "
video_path = os . path . join ( output_folder , video_filename )
writer = cv2 . VideoWriter (
video_path ,
cv2 . VideoWriter_fourcc ( * " MP4V " ) ,
output_fps ,
( samples . shape [ - 1 ] , samples . shape [ - 2 ] ) ,
)
samples = safety_filter ( samples )
vid = (
( rearrange ( samples , " t c h w -> t h w c " ) * 255 )
. cpu ( )
. numpy ( )
. astype ( np . uint8 )
samples = safety_filter ( samples )
vid = (
( rearrange ( samples , " t c h w -> t h w c " ) * 255 )
. cpu ( )
. numpy ( )
. astype ( np . uint8 )
)
for frame in vid :
frame = cv2 . cvtColor ( frame , cv2 . COLOR_RGB2BGR )
writer . write ( frame )
writer . release ( )
video_path_h264 = video_path [ : - 4 ] + " _h264.mp4 "
os . system ( f " ffmpeg -i { video_path } -c:v libx264 { video_path_h264 } " )
if torch . cuda . is_available ( ) :
peak_memory_usage = torch . cuda . max_memory_allocated ( )
msg = f " Peak memory usage: { peak_memory_usage / ( 1024 * * 2 ) } MB "
logger . info ( msg )
duration = time . perf_counter ( ) - start_time
logger . info (
f " Video of { num_frames } frames generated in { duration : .2f } seconds and saved to { video_path } \n "
)
for frame in vid :
frame = cv2 . cvtColor ( frame , cv2 . COLOR_RGB2BGR )
writer . write ( frame )
writer . release ( )
if torch . cuda . is_available ( ) :
peak_memory_usage = torch . cuda . max_memory_allocated ( )
msg = f " Peak memory usage: { peak_memory_usage / ( 1024 * * 2 ) } MB "
logger . info ( msg )
duration = time . perf_counter ( ) - start_time
logger . info (
f " Video of { num_frames } frames generated in { duration : .2f } seconds and saved to { video_path } \n "
)
def get_unique_embedder_keys_from_conditioner ( conditioner ) :
@ -310,6 +339,45 @@ def reload_model(model):
model . to ( get_device ( ) )
def pillow_fit_image_within(
    image: Image.Image, max_height=512, max_width=512, convert="RGB", snap_size=8
):
    """
    Scale an image so it fits inside a bounding box, snapped to a pixel grid.

    The image is converted to mode ``convert`` first. It is then scaled by a
    single ratio (aspect ratio preserved, before snapping) when it is larger
    than the box in either dimension, or smaller than the box in both
    dimensions. Finally each side is rounded down to a multiple of
    ``snap_size``.

    Args:
        image: source Pillow image.
        max_height: height of the bounding box, in pixels.
        max_width: width of the bounding box, in pixels.
        convert: Pillow mode passed to ``image.convert``.
        snap_size: each output side is a multiple of this value.

    Returns:
        The converted image, resampled with LANCZOS if its size had to change.
    """
    image = image.convert(convert)
    w, h = image.size

    exceeds_box = w > max_width or h > max_height
    smaller_than_box = w < max_width and h < max_height
    if exceeds_box or smaller_than_box:
        # Shrinking and enlarging use the same ratio: the largest uniform
        # scale at which both sides still fit within the box.
        resize_ratio = min(max_width / w, max_height / h)
        if resize_ratio != 1:
            w, h = int(w * resize_ratio), int(h * resize_ratio)

    # Round each side down to a multiple of snap_size, but never below one
    # snap unit: an extreme aspect ratio (e.g. 2000x10 fitted into 1024x576)
    # would otherwise snap the short side to 0 and request a zero-sized resize.
    w = max(snap_size, w - w % snap_size)
    h = max(snap_size, h - h % snap_size)

    if (w, h) != image.size:
        image = image.resize((w, h), resample=Image.Resampling.LANCZOS)

    return image


def make_safe_filename(input_string):
    """
    Derive a short filesystem-safe slug from a file path or http(s) URL.

    The scheme and host are stripped from URLs, then the directory part and
    the file extension are removed, and finally every character other than
    ASCII letters, digits and dashes is dropped.

    Args:
        input_string: a local path or an http(s) URL.

    Returns:
        The slug; may be an empty string if nothing usable remains.
    """
    stripped_url, url_prefix_count = re.subn(r"^https?://[^/]+/", "", input_string)
    if url_prefix_count:
        # For URLs, discard the query string / fragment. Otherwise a dot or
        # slash inside them (e.g. "?v=1.2" or "?next=/a/b") is mistaken for
        # the extension separator or a directory separator below.
        stripped_url = re.split(r"[?#]", stripped_url, maxsplit=1)[0]

    # Remove directory path if present
    base_name = os.path.basename(stripped_url)

    # Remove file extension
    name_without_extension = os.path.splitext(base_name)[0]

    # Keep only alphanumeric characters and dashes
    return re.sub(r"[^a-zA-Z0-9\-]", "", name_without_extension)


if __name__ == " __main__ " :
# configure logging
logging . basicConfig (