@ -26,7 +26,9 @@ from imaginairy.utils import (
instantiate_from_config ,
instantiate_from_config ,
platform_appropriate_autocast ,
platform_appropriate_autocast ,
)
)
from imaginairy . utils . animations import make_bounce_animation
from imaginairy . utils . model_manager import get_cached_url_path
from imaginairy . utils . model_manager import get_cached_url_path
from imaginairy . utils . named_resolutions import normalize_image_size
from imaginairy . utils . paths import PKG_ROOT
from imaginairy . utils . paths import PKG_ROOT
logger = logging . getLogger ( __name__ )
logger = logging . getLogger ( __name__ )
@ -35,6 +37,7 @@ logger = logging.getLogger(__name__)
def generate_video (
def generate_video (
input_path : str , # Can either be image file or folder with image files
input_path : str , # Can either be image file or folder with image files
output_folder : str | None = None ,
output_folder : str | None = None ,
size = ( 1024 , 576 ) ,
num_frames : int = 6 ,
num_frames : int = 6 ,
num_steps : int = 30 ,
num_steps : int = 30 ,
model_name : str = " svd-xt " ,
model_name : str = " svd-xt " ,
@ -46,6 +49,7 @@ def generate_video(
decoding_t : int = 1 , # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
decoding_t : int = 1 , # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
device : Optional [ str ] = None ,
device : Optional [ str ] = None ,
repetitions = 1 ,
repetitions = 1 ,
output_format = " webp " ,
) :
) :
"""
"""
Generates a video from a single image or multiple images , conditioned on the provided input_path .
Generates a video from a single image or multiple images , conditioned on the provided input_path .
@ -71,7 +75,7 @@ def generate_video(
None : The function saves the generated video ( s ) to the specified output folder .
None : The function saves the generated video ( s ) to the specified output folder .
"""
"""
device = default ( device , get_device )
device = default ( device , get_device )
vid_width , vid_height = normalize_image_size ( size )
if device == " mps " :
if device == " mps " :
msg = " Apple Silicon MPS (M1, M2, etc) is not currently supported for video generation. Switching to cpu generation. "
msg = " Apple Silicon MPS (M1, M2, etc) is not currently supported for video generation. Switching to cpu generation. "
logger . warning ( msg )
logger . warning ( msg )
@ -88,7 +92,6 @@ def generate_video(
logger . warning ( msg )
logger . warning ( msg )
start_time = time . perf_counter ( )
start_time = time . perf_counter ( )
seed = default ( seed , random . randint ( 0 , 1000000 ) )
output_fps = default ( output_fps , fps_id )
output_fps = default ( output_fps , fps_id )
video_model_config = config . MODEL_WEIGHT_CONFIG_LOOKUP . get ( model_name , None )
video_model_config = config . MODEL_WEIGHT_CONFIG_LOOKUP . get ( model_name , None )
@ -102,9 +105,6 @@ def generate_video(
del output_folder
del output_folder
video_config_path = f " { PKG_ROOT } / { video_model_config . architecture . config_path } "
video_config_path = f " { PKG_ROOT } / { video_model_config . architecture . config_path } "
logger . info (
f " Generating a { num_frames } frame video from { input_path } . Device: { device } seed: { seed } "
)
model , safety_filter = load_model (
model , safety_filter = load_model (
config = video_config_path ,
config = video_config_path ,
device = " cpu " ,
device = " cpu " ,
@ -112,7 +112,6 @@ def generate_video(
num_steps = num_steps ,
num_steps = num_steps ,
weights_url = video_model_config . weights_location ,
weights_url = video_model_config . weights_location ,
)
)
torch . manual_seed ( seed )
if input_path . startswith ( " http " ) :
if input_path . startswith ( " http " ) :
all_img_paths = [ input_path ]
all_img_paths = [ input_path ]
@ -137,9 +136,14 @@ def generate_video(
msg = f " Could not find file or folder at { input_path } "
msg = f " Could not find file or folder at { input_path } "
raise FileNotFoundError ( msg )
raise FileNotFoundError ( msg )
expected_size = ( 1024 , 576 )
expected_size = ( vid_width , vid_height )
for _ in range ( repetitions ) :
for _ in range ( repetitions ) :
for input_path in all_img_paths :
for input_path in all_img_paths :
_seed = default ( seed , random . randint ( 0 , 1000000 ) )
torch . manual_seed ( _seed )
logger . info (
f " Generating a { num_frames } frame video from { input_path } . Device: { device } seed: { _seed } "
)
if input_path . startswith ( " http " ) :
if input_path . startswith ( " http " ) :
image = LazyLoadingImage ( url = input_path ) . as_pillow ( )
image = LazyLoadingImage ( url = input_path ) . as_pillow ( )
else :
else :
@ -207,7 +211,6 @@ def generate_video(
value_dict [ " cond_aug " ] = cond_aug
value_dict [ " cond_aug " ] = cond_aug
value_dict [ " cond_frames_without_noise " ] = image
value_dict [ " cond_frames_without_noise " ] = image
value_dict [ " cond_frames " ] = image + cond_aug * torch . randn_like ( image )
value_dict [ " cond_frames " ] = image + cond_aug * torch . randn_like ( image )
value_dict [ " cond_aug " ] = cond_aug
with torch . no_grad ( ) , platform_appropriate_autocast ( ) :
with torch . no_grad ( ) , platform_appropriate_autocast ( ) :
reload_model ( model . conditioner , device = device )
reload_model ( model . conditioner , device = device )
@ -272,30 +275,14 @@ def generate_video(
samples = samples [ : , : , upper : lower , left : right ]
samples = samples [ : , : , upper : lower , left : right ]
os . makedirs ( output_folder_str , exist_ok = True )
os . makedirs ( output_folder_str , exist_ok = True )
base_count = len ( glob ( os . path . join ( output_folder_str , " *. mp4 " ) ) ) + 1
base_count = len ( glob ( os . path . join ( output_folder_str , " *. * " ) ) ) + 1
source_slug = make_safe_filename ( input_path )
source_slug = make_safe_filename ( input_path )
video_filename = f " { base_count : 06d } _ { model_name } _ { seed} _ { fps_id } fps_ { source_slug } . mp4 "
video_filename = f " { base_count : 06d } _ { model_name } _ { _ seed} _ { fps_id } fps_ { source_slug } . { output_format } "
video_path = os . path . join ( output_folder_str , video_filename )
video_path = os . path . join ( output_folder_str , video_filename )
writer = cv2 . VideoWriter (
video_path ,
cv2 . VideoWriter_fourcc ( * " MP4V " ) , # type: ignore
output_fps ,
( samples . shape [ - 1 ] , samples . shape [ - 2 ] ) ,
)
samples = safety_filter ( samples )
samples = safety_filter ( samples )
vid = (
# save_video(samples, video_path, output_fps)
( rearrange ( samples , " t c h w -> t h w c " ) * 255 )
save_video_bounce ( samples , video_path , output_fps )
. cpu ( )
. numpy ( )
. astype ( np . uint8 )
)
for frame in vid :
frame = cv2 . cvtColor ( frame , cv2 . COLOR_RGB2BGR )
writer . write ( frame )
writer . release ( )
video_path_h264 = video_path [ : - 4 ] + " _h264.mp4 "
os . system ( f " ffmpeg -i { video_path } -c:v libx264 { video_path_h264 } " )
duration = time . perf_counter ( ) - start_time
duration = time . perf_counter ( ) - start_time
logger . info (
logger . info (
@ -303,6 +290,46 @@ def generate_video(
)
)
def save_video ( samples : torch . Tensor , video_filename : str , output_fps : int ) :
"""
Saves a video from given tensor samples .
Args :
samples ( torch . Tensor ) : Tensor containing video frame data .
video_filename ( str ) : The full path and filename where the video will be saved .
output_fps ( int ) : Frames per second for the output video .
safety_filter ( Callable [ [ torch . Tensor ] , torch . Tensor ] ) : A function to apply a safety filter to the samples .
Returns :
str : The path to the saved video .
"""
vid = ( torch . permute ( samples , ( 0 , 2 , 3 , 1 ) ) * 255 ) . cpu ( ) . numpy ( ) . astype ( np . uint8 )
writer = cv2 . VideoWriter (
video_filename ,
cv2 . VideoWriter_fourcc ( * " MP4V " ) , # type: ignore
output_fps ,
( samples . shape [ - 1 ] , samples . shape [ - 2 ] ) ,
)
for frame in vid :
frame = cv2 . cvtColor ( frame , cv2 . COLOR_RGB2BGR )
writer . write ( frame )
writer . release ( )
video_path_h264 = video_filename [ : - 4 ] + " _h264.mp4 "
os . system ( f " ffmpeg -i { video_filename } -c:v libx264 { video_path_h264 } " )
def save_video_bounce ( samples : torch . Tensor , video_filename : str , output_fps : int ) :
frames_np = (
( torch . permute ( samples , ( 0 , 2 , 3 , 1 ) ) * 255 ) . cpu ( ) . numpy ( ) . astype ( np . uint8 )
)
make_bounce_animation (
imgs = [ Image . fromarray ( frame ) for frame in frames_np ] ,
outpath = video_filename ,
end_pause_duration_ms = 750 ,
)
def get_unique_embedder_keys_from_conditioner ( conditioner ) :
def get_unique_embedder_keys_from_conditioner ( conditioner ) :
return list ( { x . input_key for x in conditioner . embedders } )
return list ( { x . input_key for x in conditioner . embedders } )