diff --git a/imaginairy/api/video_sample.py b/imaginairy/api/video_sample.py index 7e17b33..77e4607 100644 --- a/imaginairy/api/video_sample.py +++ b/imaginairy/api/video_sample.py @@ -19,6 +19,7 @@ from PIL import Image from torchvision.transforms import ToTensor from imaginairy import config +from imaginairy.enhancers.video_interpolation.rife.interpolate import interpolate_images from imaginairy.schema import LazyLoadingImage from imaginairy.utils import ( default, @@ -91,7 +92,6 @@ def generate_video( ) logger.warning(msg) - start_time = time.perf_counter() output_fps = default(output_fps, fps_id) video_model_config = config.MODEL_WEIGHT_CONFIG_LOOKUP.get(model_name, None) @@ -139,6 +139,7 @@ def generate_video( expected_size = (vid_width, vid_height) for _ in range(repetitions): for input_path in all_img_paths: + start_time = time.perf_counter() _seed = default(seed, random.randint(0, 1000000)) torch.manual_seed(_seed) logger.info( @@ -318,15 +319,32 @@ def save_video(samples: torch.Tensor, video_filename: str, output_fps: int): os.system(f"ffmpeg -i {video_filename} -c:v libx264 {video_path_h264}") -def save_video_bounce(samples: torch.Tensor, video_filename: str, output_fps: int): +def save_video_bounce( + samples: torch.Tensor, video_filename: str, output_fps: int, interpolate_fps=60 +): frames_np = ( (torch.permute(samples, (0, 2, 3, 1)) * 255).cpu().numpy().astype(np.uint8) ) - + transition_duration = len(frames_np) / float(output_fps) + frames_pil = [Image.fromarray(frame) for frame in frames_np] + if interpolate_fps: + # bring it up to at least 60 fps + fps_multiplier = int(math.ceil(interpolate_fps / output_fps)) + frames_pil = interpolate_images(frames_pil, fps_multiplier=fps_multiplier) + + transition_duration_ms = transition_duration * 1000 + logger.info( + f"Interpolated from {len(frames_np)} to {len(frames_pil)} frames ({fps_multiplier} multiplier)" + ) + logger.info( + f"Making bounce animation with transition duration {transition_duration_ms:.1f}ms" + ) make_bounce_animation( - imgs=[Image.fromarray(frame) for frame in frames_np], + imgs=frames_pil, outpath=video_filename, - end_pause_duration_ms=750, + transition_duration_ms=transition_duration_ms, + end_pause_duration_ms=100, + max_fps=60, ) diff --git a/imaginairy/enhancers/video_interpolation/__init__.py b/imaginairy/enhancers/video_interpolation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/imaginairy/enhancers/video_interpolation/rife/IFNet_HDv3.py b/imaginairy/enhancers/video_interpolation/rife/IFNet_HDv3.py new file mode 100755 index 0000000..fd87b9e --- /dev/null +++ b/imaginairy/enhancers/video_interpolation/rife/IFNet_HDv3.py @@ -0,0 +1,227 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .warplayer import warp + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True, + ), + nn.LeakyReLU(0.2, True), + ) + + +def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.2, True), + ) + + +class Head(nn.Module): + def __init__(self): + super().__init__() + self.cnn0 = 
nn.Conv2d(3, 32, 3, 2, 1) + self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1) + self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1) + self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x, feat=False): + x0 = self.cnn0(x) + x = self.relu(x0) + x1 = self.cnn1(x) + x = self.relu(x1) + x2 = self.cnn2(x) + x = self.relu(x2) + x3 = self.cnn3(x) + if feat: + return [x0, x1, x2, x3] + return x3 + + +class ResConv(nn.Module): + def __init__(self, c, dilation=1): + super().__init__() + self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(self.conv(x) * self.beta + x) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super().__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c // 2, 3, 2, 1), + conv(c // 2, c, 3, 2, 1), + ) + self.convblock = nn.Sequential( + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ) + self.lastconv = nn.Sequential( + nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2) + ) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate( + x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False + ) + if flow is not None: + flow = ( + F.interpolate( + flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False + ) + * 1.0 + / scale + ) + x = torch.cat((x, flow), 1) + feat = self.conv0(x) + feat = self.convblock(feat) + tmp = self.lastconv(feat) + tmp = F.interpolate( + tmp, scale_factor=scale, mode="bilinear", align_corners=False + ) + flow = tmp[:, :4] * scale + mask = tmp[:, 4:5] + return flow, mask + + +class IFNet(nn.Module): + def __init__(self): + super().__init__() + self.block0 = IFBlock(7 + 16, c=192) + self.block1 = IFBlock(8 + 4 + 16, c=128) + self.block2 = IFBlock(8 + 4 + 16, c=96) + self.block3 = IFBlock(8 + 4 + 16, c=64) + self.encode = Head() + # self.contextnet = Contextnet() + # self.unet = Unet() + + def forward( + self, + x, + timestep=0.5, + scale_list=[8, 4, 2, 1], + training=False, + fastmode=True, + ensemble=False, + ): + if training is False: + channel = x.shape[1] // 2 + img0 = x[:, :channel] + img1 = x[:, channel:] + if not torch.is_tensor(timestep): + timestep = (x[:, :1].clone() * 0 + 1) * timestep + else: + timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3]) + f0 = self.encode(img0[:, :3]) + f1 = self.encode(img1[:, :3]) + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = None + mask = None + block = [self.block0, self.block1, self.block2, self.block3] + for i in range(4): + if flow is None: + flow, mask = block[i]( + torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), + None, + scale=scale_list[i], + ) + if ensemble: + f_, m_ = block[i]( + torch.cat((img1[:, :3], img0[:, :3], f1, f0, 1 - timestep), 1), + None, + scale=scale_list[i], + ) + flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 + mask = (mask + (-m_)) / 2 + else: + wf0 = warp(f0, flow[:, :2]) + wf1 = warp(f1, flow[:, 2:4]) + fd, m0 = block[i]( + torch.cat( + ( + warped_img0[:, :3], + warped_img1[:, :3], + wf0, + wf1, + timestep, + mask, + ), + 1, + ), + flow, + scale=scale_list[i], + ) + if ensemble: + f_, m_ = block[i]( + torch.cat( + ( + warped_img1[:, :3], + warped_img0[:, :3], + wf1, + wf0, + 1 - timestep, + -mask, + ), + 1, + ), + torch.cat((flow[:, 2:4], flow[:, :2]), 1), + 
scale=scale_list[i], + ) + fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 + mask = (m0 + (-m_)) / 2 + else: + mask = m0 + flow = flow + fd + mask_list.append(mask) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged.append((warped_img0, warped_img1)) + mask = torch.sigmoid(mask) + merged[3] = warped_img0 * mask + warped_img1 * (1 - mask) + if not fastmode: + print("contextnet is removed") + """ + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, :3] * 2 - 1 + merged[3] = torch.clamp(merged[3] + res, 0, 1) + """ + return flow_list, mask_list[3], merged diff --git a/imaginairy/enhancers/video_interpolation/rife/RIFE_HDv3.py b/imaginairy/enhancers/video_interpolation/rife/RIFE_HDv3.py new file mode 100755 index 0000000..1086e4c --- /dev/null +++ b/imaginairy/enhancers/video_interpolation/rife/RIFE_HDv3.py @@ -0,0 +1,50 @@ +import torch + +from .IFNet_HDv3 import IFNet + + +class Model: + def __init__(self): + self.flownet = IFNet() + self.version: float + + def eval(self): + self.flownet.eval() + + def load_model(self, path, version: float): + from safetensors import safe_open + + tensors = {} + with safe_open(path, framework="pt") as f: # type: ignore + for key in f.keys(): # noqa + tensors[key] = f.get_tensor(key) + self.flownet.load_state_dict(tensors, assign=True) + self.version = version + + def load_model_old(self, path, rank=0): + def convert(param): + if rank == -1: + return { + k.replace("module.", ""): v + for k, v in param.items() + if "module." in k + } + else: + return param + + if rank <= 0: + if torch.cuda.is_available(): + self.flownet.load_state_dict( + convert(torch.load(f"{path}/flownet.pkl")), False + ) + else: + self.flownet.load_state_dict( + convert(torch.load(f"{path}/flownet.pkl", map_location="cpu")), + False, + ) + + def inference(self, img0, img1, timestep=0.5, scale=1.0): + imgs = torch.cat((img0, img1), 1) + scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale] + flow, mask, merged = self.flownet(imgs, timestep, scale_list) + return merged[3] diff --git a/imaginairy/enhancers/video_interpolation/rife/__init__.py b/imaginairy/enhancers/video_interpolation/rife/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/imaginairy/enhancers/video_interpolation/rife/interpolate.py b/imaginairy/enhancers/video_interpolation/rife/interpolate.py new file mode 100644 index 0000000..ef95ef4 --- /dev/null +++ b/imaginairy/enhancers/video_interpolation/rife/interpolate.py @@ -0,0 +1,410 @@ +import _thread +import logging +import os +import shutil +import time +from functools import lru_cache +from queue import Queue +from typing import List + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.nn import functional as F +from tqdm import tqdm + +from imaginairy.utils import get_device +from imaginairy.utils.model_manager import get_cached_url_path + +from .msssim import ssim_matlab +from .RIFE_HDv3 import Model + +logger = logging.getLogger(__name__) + + +def transfer_audio(sourceVideo, targetVideo): + tempAudioFileName = "./temp/audio.mkv" + + # split audio from original video file and store in "temp" directory + if True: + # clear old "temp" directory if it exits + if os.path.isdir("temp"): + # remove temp directory + shutil.rmtree("temp") + # create new "temp" directory + os.makedirs("temp") + # extract audio from video + 
os.system(f'ffmpeg -y -i "{sourceVideo}" -c:a copy -vn {tempAudioFileName}') + + targetNoAudio = ( + os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1] + ) + os.rename(targetVideo, targetNoAudio) + # combine audio file and new video file + os.system( + f'ffmpeg -y -i "{targetNoAudio}" -i {tempAudioFileName} -c copy "{targetVideo}"' + ) + + if ( + os.path.getsize(targetVideo) == 0 + ): # if ffmpeg failed to merge the video and audio together try converting the audio to aac + tempAudioFileName = "./temp/audio.m4a" + os.system( + f'ffmpeg -y -i "{sourceVideo}" -c:a aac -b:a 160k -vn {tempAudioFileName}' + ) + os.system( + 'ffmpeg -y -i "{}" -i {} -c copy "{}"'.format( + targetNoAudio, tempAudioFileName, targetVideo + ) + ) + if ( + os.path.getsize(targetVideo) == 0 + ): # if aac is not supported by selected format + os.rename(targetNoAudio, targetVideo) + print("Audio transfer failed. Interpolated video will have no audio") + else: + print( + "Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead." + ) + + # remove audio-less video + os.remove(targetNoAudio) + else: + os.remove(targetNoAudio) + + # remove temp directory + shutil.rmtree("temp") + + +RIFE_WEIGHTS_URL = "https://huggingface.co/imaginairy/rife-interpolation/resolve/26442e52cc30b88c5cb490702647b8de9aaee8a7/rife-flownet-4.13.2.safetensors" + + +@lru_cache(maxsize=1) +def load_rife_model(model_path=None, version=4.13, device=None): + if model_path is None: + model_path = RIFE_WEIGHTS_URL + model_path = get_cached_url_path(model_path) + device = device if device else get_device() + model = Model() + model.load_model(model_path, version=version) + model.eval() + model.flownet.to(device) + return model + + +def make_inference(I0, I1, n, *, model, scale): + if model.version >= 3.9: + res = [] + for i in range(n): + res.append(model.inference(I0, I1, (i + 1) * 1.0 / (n + 1), scale)) + return res + else: + middle = model.inference(I0, I1, scale) + if n == 1: + return [middle] + first_half = make_inference(I0, middle, n=n // 2, model=model, scale=scale) + second_half = make_inference(middle, I1, n=n // 2, model=model, scale=scale) + if n % 2: + return [*first_half, middle, *second_half] + else: + return [*first_half, *second_half] + + +def interpolate_video_file( + video_path: str | None = None, + images_source_path: str | None = None, + scale: float = 1.0, + vid_out_name: str | None = None, + target_fps: float | None = None, + fps_multiplier: int = 2, + model_weights_path: str | None = None, + fp16: bool = False, + montage: bool = False, + png_out: bool = False, + output_extension: str = "mp4", + device=None, +): + assert video_path is not None or images_source_path is not None + assert scale in [0.25, 0.5, 1.0, 2.0, 4.0] + device = device if device else get_device() + + if images_source_path is not None: + png_out = True + + torch.set_grad_enabled(False) + if torch.cuda.is_available(): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + if fp16: + torch.set_default_tensor_type(torch.cuda.HalfTensor) # type: ignore + + model = load_rife_model(model_weights_path, version=4.13) + logger.info(f"Loaded RIFE from {model_weights_path}") + + if video_path is not None: + import skvideo.io + + videoCapture = cv2.VideoCapture(video_path) + fps = videoCapture.get(cv2.CAP_PROP_FPS) + tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) + videoCapture.release() + if target_fps is None: + fpsNotAssigned = True + target_fps = fps * fps_multiplier + else: + fpsNotAssigned = 
False + videogen = skvideo.io.vreader(video_path) + lastframe = next(videogen) + fourcc = cv2.VideoWriter_fourcc("m", "p", "4", "v") # type: ignore + video_path_wo_ext, ext = os.path.splitext(video_path) + print( + "{}.{}, {} frames in total, {}FPS to {}FPS".format( + video_path_wo_ext, output_extension, tot_frame, fps, target_fps + ) + ) + if png_out is False and fpsNotAssigned is True: + print("The audio will be merged after interpolation process") + else: + print("Will not merge audio because using png or fps flag!") + else: + assert images_source_path is not None + videogen = [] + for f in os.listdir(images_source_path): + if "png" in f: + videogen.append(f) + tot_frame = len(videogen) + videogen.sort(key=lambda x: int(x[:-4])) + lastframe = cv2.imread( + os.path.join(images_source_path, videogen[0]), cv2.IMREAD_UNCHANGED + )[:, :, ::-1].copy() + videogen = videogen[1:] + h, w, _ = lastframe.shape + + vid_out = None + if png_out: + if not os.path.exists("vid_out"): + os.mkdir("vid_out") + else: + if vid_out_name is None: + assert video_path_wo_ext is not None + assert target_fps is not None + vid_out_name = f"{video_path_wo_ext}_{fps_multiplier}X_{int(np.round(target_fps))}fps.{output_extension}" + vid_out = cv2.VideoWriter(vid_out_name, fourcc, target_fps, (w, h)) # type: ignore + + def clear_write_buffer(png_out, write_buffer): + cnt = 0 + while True: + item = write_buffer.get() + if item is None: + break + if png_out: + cv2.imwrite(f"vid_out/{cnt:0>7d}.png", item[:, :, ::-1]) + cnt += 1 + else: + vid_out.write(item[:, :, ::-1]) + + def build_read_buffer(img, montage, read_buffer, videogen): + try: + for frame in videogen: + if img is not None: + frame = cv2.imread(os.path.join(img, frame), cv2.IMREAD_UNCHANGED)[ + :, :, ::-1 + ].copy() + if montage: + frame = frame[:, left : left + w] + read_buffer.put(frame) + except Exception as e: # noqa + print(f"skipping frame due to error: {e}") + read_buffer.put(None) + + def pad_image(img): + if fp16: + return F.pad(img, padding).half() + else: + return F.pad(img, padding) + + if montage: + left = w // 4 + w = w // 2 + tmp = max(128, int(128 / scale)) + ph = ((h - 1) // tmp + 1) * tmp + pw = ((w - 1) // tmp + 1) * tmp + padding = (0, pw - w, 0, ph - h) + pbar = tqdm(total=tot_frame) + if montage: + lastframe = lastframe[:, left : left + w] + write_buffer: Queue = Queue(maxsize=500) + read_buffer: Queue = Queue(maxsize=500) + _thread.start_new_thread( + build_read_buffer, (images_source_path, montage, read_buffer, videogen) + ) + _thread.start_new_thread(clear_write_buffer, (png_out, write_buffer)) + + I1 = ( + torch.from_numpy(np.transpose(lastframe, (2, 0, 1))) + .to(device, non_blocking=True) + .unsqueeze(0) + .float() + / 255.0 + ) + I1 = pad_image(I1) + temp = None # save lastframe when processing static frame + + while True: + if temp is not None: + frame = temp + temp = None + else: + frame = read_buffer.get() + if frame is None: + break + I0 = I1 + I1 = ( + torch.from_numpy(np.transpose(frame, (2, 0, 1))) + .to(device, non_blocking=True) + .unsqueeze(0) + .float() + / 255.0 + ) + I1 = pad_image(I1) + I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False) + I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False) + ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) + + break_flag = False + if ssim > 0.996: + frame = read_buffer.get() # read a new frame + if frame is None: + break_flag = True + frame = lastframe + else: + temp = frame + I1 = ( + torch.from_numpy(np.transpose(frame, (2, 0, 1))) 
+ .to(device, non_blocking=True) + .unsqueeze(0) + .float() + / 255.0 + ) + I1 = pad_image(I1) + I1 = model.inference(I0, I1, scale) + I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False) + ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) + frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w] + + if ssim < 0.2: + output = [] + for i in range(fps_multiplier - 1): + output.append(I0) + """ + output = [] + step = 1 / fps_multiplier + alpha = 0 + for i in range(fps_multiplier - 1): + alpha += step + beta = 1-alpha + output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.) + """ + else: + output = make_inference( + I0, I1, fps_multiplier - 1, model=model, scale=scale + ) + + if montage: + write_buffer.put(np.concatenate((lastframe, lastframe), 1)) + for mid in output: + mid = (mid[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0) + write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1)) + else: + write_buffer.put(lastframe) + for mid in output: + mid = (mid[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0) + write_buffer.put(mid[:h, :w]) + pbar.update(1) + lastframe = frame + if break_flag: + break + + if montage: + write_buffer.put(np.concatenate((lastframe, lastframe), 1)) + else: + write_buffer.put(lastframe) + + while not write_buffer.empty(): + time.sleep(0.1) + pbar.close() + if vid_out is not None: + vid_out.release() + assert vid_out_name is not None + + # move audio to new video file if appropriate + if png_out is False and fpsNotAssigned is True and video_path is not None: + try: + transfer_audio(video_path, vid_out_name) + except Exception as e: # noqa + logger.info( + f"Audio transfer failed. Interpolated video will have no audio. {e}" + ) + targetNoAudio = ( + os.path.splitext(vid_out_name)[0] + + "_noaudio" + + os.path.splitext(vid_out_name)[1] + ) + os.rename(targetNoAudio, vid_out_name) + + +def pad_image(img, scale): + tmp = max(128, int(128 / scale)) + ph, pw = ( + ((img.shape[1] - 1) // tmp + 1) * tmp, + ((img.shape[2] - 1) // tmp + 1) * tmp, + ) + padding = (0, pw - img.shape[2], 0, ph - img.shape[1]) + return F.pad(img, padding) + + +def interpolate_images( + image_list, + scale=1.0, + fps_multiplier=2, + model_weights_path=None, + device=None, +) -> List[Image.Image]: + assert scale in [0.25, 0.5, 1.0, 2.0, 4.0] + torch.set_grad_enabled(False) + device = device if device else get_device() + model = load_rife_model(model_weights_path, version=4.13) + + interpolated_images = [] + for i in range(len(image_list) - 1): + I0 = image_to_tensor(image_list[i], device) + I1 = image_to_tensor(image_list[i + 1], device) + # I0, I1 = pad_image(I0, scale), pad_image(I1, scale) + + interpolated = make_inference( + I0, I1, n=fps_multiplier - 1, model=model, scale=scale + ) + interpolated_images.append(image_list[i]) + for img in interpolated: + img = (img[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0) + interpolated_images.append(Image.fromarray(img)) + + interpolated_images.append(image_list[-1]) + return interpolated_images + + +def image_to_tensor(image, device): + """ + Converts a PIL image to a PyTorch tensor. + + Args: + - image (PIL.Image): The image to convert. + - device (torch.device): The device to use (CPU or CUDA). + + Returns: + - torch.Tensor: The image converted to a PyTorch tensor. 
+ """ + tensor = torch.from_numpy(np.array(image).transpose((2, 0, 1))) + tensor = tensor.to(device, non_blocking=True).unsqueeze(0).float() / 255.0 + return tensor diff --git a/imaginairy/enhancers/video_interpolation/rife/msssim.py b/imaginairy/enhancers/video_interpolation/rife/msssim.py new file mode 100644 index 0000000..2d6d9fe --- /dev/null +++ b/imaginairy/enhancers/video_interpolation/rife/msssim.py @@ -0,0 +1,298 @@ +from math import exp + +import torch +from torch.nn import functional as F + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def gaussian(window_size, sigma): + gauss = torch.Tensor( + [ + exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) + for x in range(window_size) + ] + ) + return gauss / gauss.sum() + + +def create_window(window_size, channel=1): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = ( + _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device) + ) + window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() + return window + + +def create_window_3d(window_size, channel=1): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()) + _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) + window = ( + _3D_window.expand(1, channel, window_size, window_size, window_size) + .contiguous() + .to(device) + ) + return window + + +def ssim( + img1, + img2, + window_size=11, + window=None, + size_average=True, + full=False, + val_range=None, +): + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). + if val_range is None: + max_val = 255 if torch.max(img1) > 128 else 1 + min_val = -1 if torch.min(img1) < -0.5 else 0 + L = max_val - min_val + else: + L = val_range + + padd = 0 + (_, channel, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + window = create_window(real_size, channel=channel).to(img1.device) + + # mu1 = F.conv2d(img1, window, padding=padd, groups=channel) + # mu2 = F.conv2d(img2, window, padding=padd, groups=channel) + mu1 = F.conv2d( + F.pad(img1, (5, 5, 5, 5), mode="replicate"), + window, + padding=padd, + groups=channel, + ) + mu2 = F.conv2d( + F.pad(img2, (5, 5, 5, 5), mode="replicate"), + window, + padding=padd, + groups=channel, + ) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = ( + F.conv2d( + F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), + window, + padding=padd, + groups=channel, + ) + - mu1_sq + ) + sigma2_sq = ( + F.conv2d( + F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), + window, + padding=padd, + groups=channel, + ) + - mu2_sq + ) + sigma12 = ( + F.conv2d( + F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), + window, + padding=padd, + groups=channel, + ) + - mu1_mu2 + ) + + C1 = (0.01 * L) ** 2 + C2 = (0.03 * L) ** 2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + cs = torch.mean(v1 / v2) # contrast sensitivity + + ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1) + + if full: + return ret, cs + return ret + + +def ssim_matlab( + img1, + img2, + window_size=11, + window=None, + size_average=True, + full=False, + val_range=None, +): + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 
+ if val_range is None: + max_val = 255 if torch.max(img1) > 128 else 1 + min_val = -1 if torch.min(img1) < -0.5 else 0 + L = max_val - min_val + else: + L = val_range + + padd = 0 + (_, _, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + window = create_window_3d(real_size, channel=1).to(img1.device) + # Channel is set to 1 since we consider color images as volumetric images + + img1 = img1.unsqueeze(1) + img2 = img2.unsqueeze(1) + + mu1 = F.conv3d( + F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), + window, + padding=padd, + groups=1, + ) + mu2 = F.conv3d( + F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), + window, + padding=padd, + groups=1, + ) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = ( + F.conv3d( + F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), + window, + padding=padd, + groups=1, + ) + - mu1_sq + ) + sigma2_sq = ( + F.conv3d( + F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), + window, + padding=padd, + groups=1, + ) + - mu2_sq + ) + sigma12 = ( + F.conv3d( + F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), + window, + padding=padd, + groups=1, + ) + - mu1_mu2 + ) + + C1 = (0.01 * L) ** 2 + C2 = (0.03 * L) ** 2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + cs = torch.mean(v1 / v2) # contrast sensitivity + + ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1) + + if full: + return ret, cs + return ret + + +def msssim( + img1, img2, window_size=11, size_average=True, val_range=None, normalize=False +): + device = img1.device + weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) + levels = weights.size()[0] + mssim = [] + mcs = [] + for _ in range(levels): + sim, cs = ssim( + img1, + img2, + window_size=window_size, + size_average=size_average, + full=True, + val_range=val_range, + ) + mssim.append(sim) + mcs.append(cs) + + img1 = F.avg_pool2d(img1, (2, 2)) + img2 = F.avg_pool2d(img2, (2, 2)) + + mssim = torch.stack(mssim) + mcs = torch.stack(mcs) + + # Normalize (to avoid NaNs during training unstable models, not compliant with original definition) + if normalize: + mssim = (mssim + 1) / 2 + mcs = (mcs + 1) / 2 + + pow1 = mcs**weights + pow2 = mssim**weights + # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ + output = torch.prod(pow1[:-1] * pow2[-1]) + return output + + +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, val_range=None): + super().__init__() + self.window_size = window_size + self.size_average = size_average + self.val_range = val_range + + # Assume 3 channel for SSIM + self.channel = 3 + self.window = create_window(window_size, channel=self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.dtype == img1.dtype: + window = self.window + else: + window = ( + create_window(self.window_size, channel) + .to(img1.device) + .type(img1.dtype) + ) + self.window = window + self.channel = channel + + _ssim = ssim( + img1, + img2, + window=window, + window_size=self.window_size, + size_average=self.size_average, + ) + dssim = (1 - _ssim) / 2 + return dssim + + +class MSSSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True, channel=3): + super().__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = channel + + def 
forward(self, img1, img2): + return msssim( + img1, img2, window_size=self.window_size, size_average=self.size_average + ) diff --git a/imaginairy/enhancers/video_interpolation/rife/warplayer.py b/imaginairy/enhancers/video_interpolation/rife/warplayer.py new file mode 100644 index 0000000..5762c15 --- /dev/null +++ b/imaginairy/enhancers/video_interpolation/rife/warplayer.py @@ -0,0 +1,39 @@ +import torch + +from . import msssim + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +backwarp_tenGrid = {} + + +def warp(tenInput, tenFlow): + k = (str(msssim.device), str(tenFlow.size())) + if k not in backwarp_tenGrid: + tenHorizontal = ( + torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device) + .view(1, 1, 1, tenFlow.shape[3]) + .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + ) + tenVertical = ( + torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device) + .view(1, 1, tenFlow.shape[2], 1) + .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + ) + backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(device) + + tenFlow = torch.cat( + [ + tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0), + ], + 1, + ) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + return torch.nn.functional.grid_sample( + input=tenInput, + grid=g, + mode="bilinear", + padding_mode="border", + align_corners=True, + ) diff --git a/imaginairy/utils/animations.py b/imaginairy/utils/animations.py index 361a102..3ebcb12 100644 --- a/imaginairy/utils/animations.py +++ b/imaginairy/utils/animations.py @@ -1,6 +1,7 @@ """Functions for creating animations from images.""" +import logging import os.path -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Sequence import cv2 import torch @@ -19,8 +20,11 @@ if TYPE_CHECKING: from imaginairy.utils.img_utils import LazyLoadingImage +logger = logging.getLogger(__name__) + + def make_bounce_animation( - imgs: "List[Image.Image | LazyLoadingImage | torch.Tensor]", + imgs: "Sequence[Image.Image | LazyLoadingImage | torch.Tensor]", outpath: str, transition_duration_ms=500, start_pause_duration_ms=1000, @@ -32,7 +36,7 @@ def make_bounce_animation( last_img = imgs[-1] max_frames = int(round(transition_duration_ms / 1000 * max_fps)) - min_duration = int(1000 / 20) + min_duration = int(1000 / max_fps) if middle_imgs: progress_duration = int(round(transition_duration_ms / len(middle_imgs))) else: @@ -53,7 +57,9 @@ def make_bounce_animation( + [end_pause_duration_ms] + [progress_duration] * len(middle_imgs) ) - + logger.info( + f"Making animation with {len(converted_frames)} frames and {progress_duration:.1f}ms per transition frame." + ) make_animation(imgs=converted_frames, outpath=outpath, frame_duration_ms=durations) @@ -150,8 +156,8 @@ def select_images_by_duration_at_fps(images, durations_ms, fps=30): for i, image in enumerate(images): duration = durations_ms[i] / 1000 num_frames = int(round(duration * fps)) - print( - f"Showing image {i} for {num_frames} frames for {durations_ms[i]}ms at {fps} fps." - ) + # print( + # f"Showing image {i} for {num_frames} frames for {durations_ms[i]}ms at {fps} fps." + # ) for j in range(num_frames): yield image
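
Usage note (not part of the patch): a minimal sketch of the new `interpolate_images` helper introduced in `imaginairy/enhancers/video_interpolation/rife/interpolate.py` above. The function name, its parameters, and the weight download behaviour come from this diff; the placeholder frames, the 1024x576 size, and `fps_multiplier=4` are illustrative assumptions. The first call fetches the RIFE weights referenced by `RIFE_WEIGHTS_URL`, so it needs network access (and ideally a GPU).

    from PIL import Image

    from imaginairy.enhancers.video_interpolation.rife.interpolate import (
        interpolate_images,
    )

    # Two placeholder frames; in save_video_bounce() these come from the video
    # model's decoded output. interpolate_images() skips padding (pad_image is
    # commented out in the diff), so stick to convolution-friendly sizes such as
    # the default 1024x576 video output.
    frames = [
        Image.new("RGB", (1024, 576), color=(16, 16, 16)),
        Image.new("RGB", (1024, 576), color=(240, 240, 240)),
    ]

    # fps_multiplier=4 synthesizes 3 in-between frames per input pair
    # (timesteps 0.25 / 0.5 / 0.75). This mirrors how save_video_bounce derives
    # its multiplier as ceil(interpolate_fps / output_fps), e.g. ceil(60 / 15) = 4.
    smooth = interpolate_images(frames, fps_multiplier=4)
    print(f"{len(frames)} frames -> {len(smooth)} frames")  # 2 -> 5

For whole files on disk, `interpolate_video_file()` in the same module is the analogous entry point; it takes `video_path` or `images_source_path` plus `fps_multiplier`, and writes an interpolated video (optionally re-attaching the source audio).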