parent
bb2dd45cf2
commit
907e80d1f2
@ -0,0 +1,227 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .warplayer import warp
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=True,
|
||||
),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
)
|
||||
|
||||
|
||||
def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(out_planes),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
)
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1)
|
||||
self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1)
|
||||
self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1)
|
||||
self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1)
|
||||
self.relu = nn.LeakyReLU(0.2, True)
|
||||
|
||||
def forward(self, x, feat=False):
|
||||
x0 = self.cnn0(x)
|
||||
x = self.relu(x0)
|
||||
x1 = self.cnn1(x)
|
||||
x = self.relu(x1)
|
||||
x2 = self.cnn2(x)
|
||||
x = self.relu(x2)
|
||||
x3 = self.cnn3(x)
|
||||
if feat:
|
||||
return [x0, x1, x2, x3]
|
||||
return x3
|
||||
|
||||
|
||||
class ResConv(nn.Module):
|
||||
def __init__(self, c, dilation=1):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
|
||||
self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
|
||||
self.relu = nn.LeakyReLU(0.2, True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.relu(self.conv(x) * self.beta + x)
|
||||
|
||||
|
||||
class IFBlock(nn.Module):
|
||||
def __init__(self, in_planes, c=64):
|
||||
super().__init__()
|
||||
self.conv0 = nn.Sequential(
|
||||
conv(in_planes, c // 2, 3, 2, 1),
|
||||
conv(c // 2, c, 3, 2, 1),
|
||||
)
|
||||
self.convblock = nn.Sequential(
|
||||
ResConv(c),
|
||||
ResConv(c),
|
||||
ResConv(c),
|
||||
ResConv(c),
|
||||
ResConv(c),
|
||||
ResConv(c),
|
||||
ResConv(c),
|
||||
ResConv(c),
|
||||
)
|
||||
self.lastconv = nn.Sequential(
|
||||
nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2)
|
||||
)
|
||||
|
||||
def forward(self, x, flow=None, scale=1):
|
||||
x = F.interpolate(
|
||||
x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
|
||||
)
|
||||
if flow is not None:
|
||||
flow = (
|
||||
F.interpolate(
|
||||
flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
|
||||
)
|
||||
* 1.0
|
||||
/ scale
|
||||
)
|
||||
x = torch.cat((x, flow), 1)
|
||||
feat = self.conv0(x)
|
||||
feat = self.convblock(feat)
|
||||
tmp = self.lastconv(feat)
|
||||
tmp = F.interpolate(
|
||||
tmp, scale_factor=scale, mode="bilinear", align_corners=False
|
||||
)
|
||||
flow = tmp[:, :4] * scale
|
||||
mask = tmp[:, 4:5]
|
||||
return flow, mask
|
||||
|
||||
|
||||
class IFNet(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.block0 = IFBlock(7 + 16, c=192)
|
||||
self.block1 = IFBlock(8 + 4 + 16, c=128)
|
||||
self.block2 = IFBlock(8 + 4 + 16, c=96)
|
||||
self.block3 = IFBlock(8 + 4 + 16, c=64)
|
||||
self.encode = Head()
|
||||
# self.contextnet = Contextnet()
|
||||
# self.unet = Unet()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timestep=0.5,
|
||||
scale_list=[8, 4, 2, 1],
|
||||
training=False,
|
||||
fastmode=True,
|
||||
ensemble=False,
|
||||
):
|
||||
if training is False:
|
||||
channel = x.shape[1] // 2
|
||||
img0 = x[:, :channel]
|
||||
img1 = x[:, channel:]
|
||||
if not torch.is_tensor(timestep):
|
||||
timestep = (x[:, :1].clone() * 0 + 1) * timestep
|
||||
else:
|
||||
timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
|
||||
f0 = self.encode(img0[:, :3])
|
||||
f1 = self.encode(img1[:, :3])
|
||||
flow_list = []
|
||||
merged = []
|
||||
mask_list = []
|
||||
warped_img0 = img0
|
||||
warped_img1 = img1
|
||||
flow = None
|
||||
mask = None
|
||||
block = [self.block0, self.block1, self.block2, self.block3]
|
||||
for i in range(4):
|
||||
if flow is None:
|
||||
flow, mask = block[i](
|
||||
torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1),
|
||||
None,
|
||||
scale=scale_list[i],
|
||||
)
|
||||
if ensemble:
|
||||
f_, m_ = block[i](
|
||||
torch.cat((img1[:, :3], img0[:, :3], f1, f0, 1 - timestep), 1),
|
||||
None,
|
||||
scale=scale_list[i],
|
||||
)
|
||||
flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
|
||||
mask = (mask + (-m_)) / 2
|
||||
else:
|
||||
wf0 = warp(f0, flow[:, :2])
|
||||
wf1 = warp(f1, flow[:, 2:4])
|
||||
fd, m0 = block[i](
|
||||
torch.cat(
|
||||
(
|
||||
warped_img0[:, :3],
|
||||
warped_img1[:, :3],
|
||||
wf0,
|
||||
wf1,
|
||||
timestep,
|
||||
mask,
|
||||
),
|
||||
1,
|
||||
),
|
||||
flow,
|
||||
scale=scale_list[i],
|
||||
)
|
||||
if ensemble:
|
||||
f_, m_ = block[i](
|
||||
torch.cat(
|
||||
(
|
||||
warped_img1[:, :3],
|
||||
warped_img0[:, :3],
|
||||
wf1,
|
||||
wf0,
|
||||
1 - timestep,
|
||||
-mask,
|
||||
),
|
||||
1,
|
||||
),
|
||||
torch.cat((flow[:, 2:4], flow[:, :2]), 1),
|
||||
scale=scale_list[i],
|
||||
)
|
||||
fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
|
||||
mask = (m0 + (-m_)) / 2
|
||||
else:
|
||||
mask = m0
|
||||
flow = flow + fd
|
||||
mask_list.append(mask)
|
||||
flow_list.append(flow)
|
||||
warped_img0 = warp(img0, flow[:, :2])
|
||||
warped_img1 = warp(img1, flow[:, 2:4])
|
||||
merged.append((warped_img0, warped_img1))
|
||||
mask = torch.sigmoid(mask)
|
||||
merged[3] = warped_img0 * mask + warped_img1 * (1 - mask)
|
||||
if not fastmode:
|
||||
print("contextnet is removed")
|
||||
"""
|
||||
c0 = self.contextnet(img0, flow[:, :2])
|
||||
c1 = self.contextnet(img1, flow[:, 2:4])
|
||||
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
|
||||
res = tmp[:, :3] * 2 - 1
|
||||
merged[3] = torch.clamp(merged[3] + res, 0, 1)
|
||||
"""
|
||||
return flow_list, mask_list[3], merged
|
@ -0,0 +1,50 @@
|
||||
import torch
|
||||
|
||||
from .IFNet_HDv3 import IFNet
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self):
|
||||
self.flownet = IFNet()
|
||||
self.version: float
|
||||
|
||||
def eval(self):
|
||||
self.flownet.eval()
|
||||
|
||||
def load_model(self, path, version: float):
|
||||
from safetensors import safe_open
|
||||
|
||||
tensors = {}
|
||||
with safe_open(path, framework="pt") as f: # type: ignore
|
||||
for key in f.keys(): # noqa
|
||||
tensors[key] = f.get_tensor(key)
|
||||
self.flownet.load_state_dict(tensors, assign=True)
|
||||
self.version = version
|
||||
|
||||
def load_model_old(self, path, rank=0):
|
||||
def convert(param):
|
||||
if rank == -1:
|
||||
return {
|
||||
k.replace("module.", ""): v
|
||||
for k, v in param.items()
|
||||
if "module." in k
|
||||
}
|
||||
else:
|
||||
return param
|
||||
|
||||
if rank <= 0:
|
||||
if torch.cuda.is_available():
|
||||
self.flownet.load_state_dict(
|
||||
convert(torch.load(f"{path}/flownet.pkl")), False
|
||||
)
|
||||
else:
|
||||
self.flownet.load_state_dict(
|
||||
convert(torch.load(f"{path}/flownet.pkl", map_location="cpu")),
|
||||
False,
|
||||
)
|
||||
|
||||
def inference(self, img0, img1, timestep=0.5, scale=1.0):
|
||||
imgs = torch.cat((img0, img1), 1)
|
||||
scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
|
||||
flow, mask, merged = self.flownet(imgs, timestep, scale_list)
|
||||
return merged[3]
|
@ -0,0 +1,410 @@
|
||||
import _thread
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from queue import Queue
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.nn import functional as F
|
||||
from tqdm import tqdm
|
||||
|
||||
from imaginairy.utils import get_device
|
||||
from imaginairy.utils.model_manager import get_cached_url_path
|
||||
|
||||
from .msssim import ssim_matlab
|
||||
from .RIFE_HDv3 import Model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def transfer_audio(sourceVideo, targetVideo):
|
||||
tempAudioFileName = "./temp/audio.mkv"
|
||||
|
||||
# split audio from original video file and store in "temp" directory
|
||||
if True:
|
||||
# clear old "temp" directory if it exits
|
||||
if os.path.isdir("temp"):
|
||||
# remove temp directory
|
||||
shutil.rmtree("temp")
|
||||
# create new "temp" directory
|
||||
os.makedirs("temp")
|
||||
# extract audio from video
|
||||
os.system(f'ffmpeg -y -i "{sourceVideo}" -c:a copy -vn {tempAudioFileName}')
|
||||
|
||||
targetNoAudio = (
|
||||
os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
|
||||
)
|
||||
os.rename(targetVideo, targetNoAudio)
|
||||
# combine audio file and new video file
|
||||
os.system(
|
||||
f'ffmpeg -y -i "{targetNoAudio}" -i {tempAudioFileName} -c copy "{targetVideo}"'
|
||||
)
|
||||
|
||||
if (
|
||||
os.path.getsize(targetVideo) == 0
|
||||
): # if ffmpeg failed to merge the video and audio together try converting the audio to aac
|
||||
tempAudioFileName = "./temp/audio.m4a"
|
||||
os.system(
|
||||
f'ffmpeg -y -i "{sourceVideo}" -c:a aac -b:a 160k -vn {tempAudioFileName}'
|
||||
)
|
||||
os.system(
|
||||
'ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(
|
||||
targetNoAudio, tempAudioFileName, targetVideo
|
||||
)
|
||||
)
|
||||
if (
|
||||
os.path.getsize(targetVideo) == 0
|
||||
): # if aac is not supported by selected format
|
||||
os.rename(targetNoAudio, targetVideo)
|
||||
print("Audio transfer failed. Interpolated video will have no audio")
|
||||
else:
|
||||
print(
|
||||
"Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead."
|
||||
)
|
||||
|
||||
# remove audio-less video
|
||||
os.remove(targetNoAudio)
|
||||
else:
|
||||
os.remove(targetNoAudio)
|
||||
|
||||
# remove temp directory
|
||||
shutil.rmtree("temp")
|
||||
|
||||
|
||||
RIFE_WEIGHTS_URL = "https://huggingface.co/imaginairy/rife-interpolation/resolve/26442e52cc30b88c5cb490702647b8de9aaee8a7/rife-flownet-4.13.2.safetensors"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def load_rife_model(model_path=None, version=4.13, device=None):
|
||||
if model_path is None:
|
||||
model_path = RIFE_WEIGHTS_URL
|
||||
model_path = get_cached_url_path(model_path)
|
||||
device = device if device else get_device()
|
||||
model = Model()
|
||||
model.load_model(model_path, version=version)
|
||||
model.eval()
|
||||
model.flownet.to(device)
|
||||
return model
|
||||
|
||||
|
||||
def make_inference(I0, I1, n, *, model, scale):
|
||||
if model.version >= 3.9:
|
||||
res = []
|
||||
for i in range(n):
|
||||
res.append(model.inference(I0, I1, (i + 1) * 1.0 / (n + 1), scale))
|
||||
return res
|
||||
else:
|
||||
middle = model.inference(I0, I1, scale)
|
||||
if n == 1:
|
||||
return [middle]
|
||||
first_half = make_inference(I0, middle, n=n // 2, model=model, scale=scale)
|
||||
second_half = make_inference(middle, I1, n=n // 2, model=model, scale=scale)
|
||||
if n % 2:
|
||||
return [*first_half, middle, *second_half]
|
||||
else:
|
||||
return [*first_half, *second_half]
|
||||
|
||||
|
||||
def interpolate_video_file(
|
||||
video_path: str | None = None,
|
||||
images_source_path: str | None = None,
|
||||
scale: float = 1.0,
|
||||
vid_out_name: str | None = None,
|
||||
target_fps: float | None = None,
|
||||
fps_multiplier: int = 2,
|
||||
model_weights_path: str | None = None,
|
||||
fp16: bool = False,
|
||||
montage: bool = False,
|
||||
png_out: bool = False,
|
||||
output_extension: str = "mp4",
|
||||
device=None,
|
||||
):
|
||||
assert video_path is not None or images_source_path is not None
|
||||
assert scale in [0.25, 0.5, 1.0, 2.0, 4.0]
|
||||
device = device if device else get_device()
|
||||
|
||||
if images_source_path is not None:
|
||||
png_out = True
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
if torch.cuda.is_available():
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
if fp16:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor) # type: ignore
|
||||
|
||||
model = load_rife_model(model_weights_path, version=4.13)
|
||||
logger.info(f"Loaded RIFE from {model_weights_path}")
|
||||
|
||||
if video_path is not None:
|
||||
import skvideo.io
|
||||
|
||||
videoCapture = cv2.VideoCapture(video_path)
|
||||
fps = videoCapture.get(cv2.CAP_PROP_FPS)
|
||||
tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
videoCapture.release()
|
||||
if target_fps is None:
|
||||
fpsNotAssigned = True
|
||||
target_fps = fps * fps_multiplier
|
||||
else:
|
||||
fpsNotAssigned = False
|
||||
videogen = skvideo.io.vreader(video_path)
|
||||
lastframe = next(videogen)
|
||||
fourcc = cv2.VideoWriter_fourcc("m", "p", "4", "v") # type: ignore
|
||||
video_path_wo_ext, ext = os.path.splitext(video_path)
|
||||
print(
|
||||
"{}.{}, {} frames in total, {}FPS to {}FPS".format(
|
||||
video_path_wo_ext, output_extension, tot_frame, fps, target_fps
|
||||
)
|
||||
)
|
||||
if png_out is False and fpsNotAssigned is True:
|
||||
print("The audio will be merged after interpolation process")
|
||||
else:
|
||||
print("Will not merge audio because using png or fps flag!")
|
||||
else:
|
||||
assert images_source_path is not None
|
||||
videogen = []
|
||||
for f in os.listdir(images_source_path):
|
||||
if "png" in f:
|
||||
videogen.append(f)
|
||||
tot_frame = len(videogen)
|
||||
videogen.sort(key=lambda x: int(x[:-4]))
|
||||
lastframe = cv2.imread(
|
||||
os.path.join(images_source_path, videogen[0]), cv2.IMREAD_UNCHANGED
|
||||
)[:, :, ::-1].copy()
|
||||
videogen = videogen[1:]
|
||||
h, w, _ = lastframe.shape
|
||||
|
||||
vid_out = None
|
||||
if png_out:
|
||||
if not os.path.exists("vid_out"):
|
||||
os.mkdir("vid_out")
|
||||
else:
|
||||
if vid_out_name is None:
|
||||
assert video_path_wo_ext is not None
|
||||
assert target_fps is not None
|
||||
vid_out_name = f"{video_path_wo_ext}_{fps_multiplier}X_{int(np.round(target_fps))}fps.{output_extension}"
|
||||
vid_out = cv2.VideoWriter(vid_out_name, fourcc, target_fps, (w, h)) # type: ignore
|
||||
|
||||
def clear_write_buffer(png_out, write_buffer):
|
||||
cnt = 0
|
||||
while True:
|
||||
item = write_buffer.get()
|
||||
if item is None:
|
||||
break
|
||||
if png_out:
|
||||
cv2.imwrite(f"vid_out/{cnt:0>7d}.png", item[:, :, ::-1])
|
||||
cnt += 1
|
||||
else:
|
||||
vid_out.write(item[:, :, ::-1])
|
||||
|
||||
def build_read_buffer(img, montage, read_buffer, videogen):
|
||||
try:
|
||||
for frame in videogen:
|
||||
if img is not None:
|
||||
frame = cv2.imread(os.path.join(img, frame), cv2.IMREAD_UNCHANGED)[
|
||||
:, :, ::-1
|
||||
].copy()
|
||||
if montage:
|
||||
frame = frame[:, left : left + w]
|
||||
read_buffer.put(frame)
|
||||
except Exception as e: # noqa
|
||||
print(f"skipping frame due to error: {e}")
|
||||
read_buffer.put(None)
|
||||
|
||||
def pad_image(img):
|
||||
if fp16:
|
||||
return F.pad(img, padding).half()
|
||||
else:
|
||||
return F.pad(img, padding)
|
||||
|
||||
if montage:
|
||||
left = w // 4
|
||||
w = w // 2
|
||||
tmp = max(128, int(128 / scale))
|
||||
ph = ((h - 1) // tmp + 1) * tmp
|
||||
pw = ((w - 1) // tmp + 1) * tmp
|
||||
padding = (0, pw - w, 0, ph - h)
|
||||
pbar = tqdm(total=tot_frame)
|
||||
if montage:
|
||||
lastframe = lastframe[:, left : left + w]
|
||||
write_buffer: Queue = Queue(maxsize=500)
|
||||
read_buffer: Queue = Queue(maxsize=500)
|
||||
_thread.start_new_thread(
|
||||
build_read_buffer, (images_source_path, montage, read_buffer, videogen)
|
||||
)
|
||||
_thread.start_new_thread(clear_write_buffer, (png_out, write_buffer))
|
||||
|
||||
I1 = (
|
||||
torch.from_numpy(np.transpose(lastframe, (2, 0, 1)))
|
||||
.to(device, non_blocking=True)
|
||||
.unsqueeze(0)
|
||||
.float()
|
||||
/ 255.0
|
||||
)
|
||||
I1 = pad_image(I1)
|
||||
temp = None # save lastframe when processing static frame
|
||||
|
||||
while True:
|
||||
if temp is not None:
|
||||
frame = temp
|
||||
temp = None
|
||||
else:
|
||||
frame = read_buffer.get()
|
||||
if frame is None:
|
||||
break
|
||||
I0 = I1
|
||||
I1 = (
|
||||
torch.from_numpy(np.transpose(frame, (2, 0, 1)))
|
||||
.to(device, non_blocking=True)
|
||||
.unsqueeze(0)
|
||||
.float()
|
||||
/ 255.0
|
||||
)
|
||||
I1 = pad_image(I1)
|
||||
I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
|
||||
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
||||
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
||||
|
||||
break_flag = False
|
||||
if ssim > 0.996:
|
||||
frame = read_buffer.get() # read a new frame
|
||||
if frame is None:
|
||||
break_flag = True
|
||||
frame = lastframe
|
||||
else:
|
||||
temp = frame
|
||||
I1 = (
|
||||
torch.from_numpy(np.transpose(frame, (2, 0, 1)))
|
||||
.to(device, non_blocking=True)
|
||||
.unsqueeze(0)
|
||||
.float()
|
||||
/ 255.0
|
||||
)
|
||||
I1 = pad_image(I1)
|
||||
I1 = model.inference(I0, I1, scale)
|
||||
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
||||
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
||||
frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
|
||||
|
||||
if ssim < 0.2:
|
||||
output = []
|
||||
for i in range(fps_multiplier - 1):
|
||||
output.append(I0)
|
||||
"""
|
||||
output = []
|
||||
step = 1 / fps_multiplier
|
||||
alpha = 0
|
||||
for i in range(fps_multiplier - 1):
|
||||
alpha += step
|
||||
beta = 1-alpha
|
||||
output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
|
||||
"""
|
||||
else:
|
||||
output = make_inference(
|
||||
I0, I1, fps_multiplier - 1, model=model, scale=scale
|
||||
)
|
||||
|
||||
if montage:
|
||||
write_buffer.put(np.concatenate((lastframe, lastframe), 1))
|
||||
for mid in output:
|
||||
mid = (mid[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0)
|
||||
write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1))
|
||||
else:
|
||||
write_buffer.put(lastframe)
|
||||
for mid in output:
|
||||
mid = (mid[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0)
|
||||
write_buffer.put(mid[:h, :w])
|
||||
pbar.update(1)
|
||||
lastframe = frame
|
||||
if break_flag:
|
||||
break
|
||||
|
||||
if montage:
|
||||
write_buffer.put(np.concatenate((lastframe, lastframe), 1))
|
||||
else:
|
||||
write_buffer.put(lastframe)
|
||||
|
||||
while not write_buffer.empty():
|
||||
time.sleep(0.1)
|
||||
pbar.close()
|
||||
if vid_out is not None:
|
||||
vid_out.release()
|
||||
assert vid_out_name is not None
|
||||
|
||||
# move audio to new video file if appropriate
|
||||
if png_out is False and fpsNotAssigned is True and video_path is not None:
|
||||
try:
|
||||
transfer_audio(video_path, vid_out_name)
|
||||
except Exception as e: # noqa
|
||||
logger.info(
|
||||
f"Audio transfer failed. Interpolated video will have no audio. {e}"
|
||||
)
|
||||
targetNoAudio = (
|
||||
os.path.splitext(vid_out_name)[0]
|
||||
+ "_noaudio"
|
||||
+ os.path.splitext(vid_out_name)[1]
|
||||
)
|
||||
os.rename(targetNoAudio, vid_out_name)
|
||||
|
||||
|
||||
def pad_image(img, scale):
|
||||
tmp = max(128, int(128 / scale))
|
||||
ph, pw = (
|
||||
((img.shape[1] - 1) // tmp + 1) * tmp,
|
||||
((img.shape[2] - 1) // tmp + 1) * tmp,
|
||||
)
|
||||
padding = (0, pw - img.shape[2], 0, ph - img.shape[1])
|
||||
return F.pad(img, padding)
|
||||
|
||||
|
||||
def interpolate_images(
|
||||
image_list,
|
||||
scale=1.0,
|
||||
fps_multiplier=2,
|
||||
model_weights_path=None,
|
||||
device=None,
|
||||
) -> List[Image.Image]:
|
||||
assert scale in [0.25, 0.5, 1.0, 2.0, 4.0]
|
||||
torch.set_grad_enabled(False)
|
||||
device = device if device else get_device()
|
||||
model = load_rife_model(model_weights_path, version=4.13)
|
||||
|
||||
interpolated_images = []
|
||||
for i in range(len(image_list) - 1):
|
||||
I0 = image_to_tensor(image_list[i], device)
|
||||
I1 = image_to_tensor(image_list[i + 1], device)
|
||||
# I0, I1 = pad_image(I0, scale), pad_image(I1, scale)
|
||||
|
||||
interpolated = make_inference(
|
||||
I0, I1, n=fps_multiplier - 1, model=model, scale=scale
|
||||
)
|
||||
interpolated_images.append(image_list[i])
|
||||
for img in interpolated:
|
||||
img = (img[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0)
|
||||
interpolated_images.append(Image.fromarray(img))
|
||||
|
||||
interpolated_images.append(image_list[-1])
|
||||
return interpolated_images
|
||||
|
||||
|
||||
def image_to_tensor(image, device):
|
||||
"""
|
||||
Converts a PIL image to a PyTorch tensor.
|
||||
|
||||
Args:
|
||||
- image (PIL.Image): The image to convert.
|
||||
- device (torch.device): The device to use (CPU or CUDA).
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The image converted to a PyTorch tensor.
|
||||
"""
|
||||
tensor = torch.from_numpy(np.array(image).transpose((2, 0, 1)))
|
||||
tensor = tensor.to(device, non_blocking=True).unsqueeze(0).float() / 255.0
|
||||
return tensor
|
@ -0,0 +1,298 @@
|
||||
from math import exp
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
def gaussian(window_size, sigma):
|
||||
gauss = torch.Tensor(
|
||||
[
|
||||
exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
|
||||
for x in range(window_size)
|
||||
]
|
||||
)
|
||||
return gauss / gauss.sum()
|
||||
|
||||
|
||||
def create_window(window_size, channel=1):
|
||||
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
||||
_2D_window = (
|
||||
_1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
|
||||
)
|
||||
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
|
||||
return window
|
||||
|
||||
|
||||
def create_window_3d(window_size, channel=1):
|
||||
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
||||
_2D_window = _1D_window.mm(_1D_window.t())
|
||||
_3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
|
||||
window = (
|
||||
_3D_window.expand(1, channel, window_size, window_size, window_size)
|
||||
.contiguous()
|
||||
.to(device)
|
||||
)
|
||||
return window
|
||||
|
||||
|
||||
def ssim(
|
||||
img1,
|
||||
img2,
|
||||
window_size=11,
|
||||
window=None,
|
||||
size_average=True,
|
||||
full=False,
|
||||
val_range=None,
|
||||
):
|
||||
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
|
||||
if val_range is None:
|
||||
max_val = 255 if torch.max(img1) > 128 else 1
|
||||
min_val = -1 if torch.min(img1) < -0.5 else 0
|
||||
L = max_val - min_val
|
||||
else:
|
||||
L = val_range
|
||||
|
||||
padd = 0
|
||||
(_, channel, height, width) = img1.size()
|
||||
if window is None:
|
||||
real_size = min(window_size, height, width)
|
||||
window = create_window(real_size, channel=channel).to(img1.device)
|
||||
|
||||
# mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
|
||||
# mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
|
||||
mu1 = F.conv2d(
|
||||
F.pad(img1, (5, 5, 5, 5), mode="replicate"),
|
||||
window,
|
||||
padding=padd,
|
||||
groups=channel,
|
||||
)
|
||||
mu2 = F.conv2d(
|
||||
F.pad(img2, (5, 5, 5, 5), mode="replicate"),
|
||||
window,
|
||||
padding=padd,
|
||||
groups=channel,
|
||||
)
|
||||
|
||||
mu1_sq = mu1.pow(2)
|
||||
mu2_sq = mu2.pow(2)
|
||||
mu1_mu2 = mu1 * mu2
|
||||
|
||||
sigma1_sq = (
|
||||
F.conv2d(
|
||||
F.pad(img1 * img1, (5, 5, 5, 5), "replicate"),
|
||||
window,
|
||||
padding=padd,
|
||||
groups=channel,
|
||||
)
|
||||
- mu1_sq
|
||||
)
|
||||
sigma2_sq = (
|
||||
F.conv2d(
|
||||
F.pad(img2 * img2, (5, 5, 5, 5), "replicate"),
|
||||
window,
|
||||
padding=padd,
|
||||
groups=channel,
|
||||
)
|
||||
- mu2_sq
|
||||
)
|
||||
sigma12 = (
|
||||
F.conv2d(
|
||||
F.pad(img1 * img2, (5, 5, 5, 5), "replicate"),
|
||||
window,
|
||||
padding=padd,
|
||||
groups=channel,
|
||||
)
|
||||
- mu1_mu2
|
||||
)
|
||||
|
||||
C1 = (0.01 * L) ** 2
|
||||
C2 = (0.03 * L) ** 2
|
||||
|
||||
v1 = 2.0 * sigma12 + C2
|
||||
v2 = sigma1_sq + sigma2_sq + C2
|
||||
cs = torch.mean(v1 / v2) # contrast sensitivity
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
|
||||
|
||||
ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1)
|
||||
|
||||
if full:
|
||||
return ret, cs
|
||||
return ret
|
||||
|
||||
|
||||
def ssim_matlab(
|
||||
img1,
|
||||
img2,
|
||||
window_size=11,
|
||||
window=None,
|
||||
size_average=True,
|
||||
full=False,
|
||||
val_range=None,
|
||||
):
|
||||
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
|
||||
if val_range is None:
|
||||
max_val = 255 if torch.max(img1) > 128 else 1
|
||||
min_val = -1 if torch.min(img1) < -0.5 else 0
|
||||
L = max_val - min_val
|
||||
else:
|
||||
L = val_range
|
||||
|
||||
padd = 0
|
||||
(_, _, height, width) = img1.size()
|
||||
if window is None:
|
||||
real_size = min(window_size, height, width)
|
||||
window = create_window_3d(real_size, channel=1).to(img1.device)
|
||||
# Channel is set to 1 since we consider color images as volumetric images
|
||||
|
||||
img1 = img1.unsqueeze(1)
|
||||
img2 = img2.unsqueeze(1)
|
||||
|
||||
mu1 = F.conv3d(
|
||||
F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"),
|
||||
window,
|
||||
padding=padd,
|
||||
groups=1,
|
||||
)
|
||||
mu2 = F.conv3d(
|
||||
F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"),
|
||||
window,
|
||||
padding=padd,
|
||||
groups=1,
|
||||
)
|
||||
|
||||
mu1_sq = mu1.pow(2)
|
||||
mu2_sq = mu2.pow(2)
|
||||
mu1_mu2 = mu1 * mu2
|
||||
|
||||
sigma1_sq = (
|
||||
F.conv3d(
|
||||
F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"),
|
||||
window,
|
||||
padding=padd,
|
||||
groups=1,
|
||||
)
|
||||
- mu1_sq
|
||||
)
|
||||
sigma2_sq = (
|
||||
F.conv3d(
|
||||
F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"),
|
||||
window,
|
||||
padding=padd,
|
||||
groups=1,
|
||||
)
|
||||
- mu2_sq
|
||||
)
|
||||
sigma12 = (
|
||||
F.conv3d(
|
||||
F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"),
|
||||
window,
|
||||
padding=padd,
|
||||
groups=1,
|
||||
)
|
||||
- mu1_mu2
|
||||
)
|
||||
|
||||
C1 = (0.01 * L) ** 2
|
||||
C2 = (0.03 * L) ** 2
|
||||
|
||||
v1 = 2.0 * sigma12 + C2
|
||||
v2 = sigma1_sq + sigma2_sq + C2
|
||||
cs = torch.mean(v1 / v2) # contrast sensitivity
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
|
||||
|
||||
ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1)
|
||||
|
||||
if full:
|
||||
return ret, cs
|
||||
return ret
|
||||
|
||||
|
||||
def msssim(
|
||||
img1, img2, window_size=11, size_average=True, val_range=None, normalize=False
|
||||
):
|
||||
device = img1.device
|
||||
weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
|
||||
levels = weights.size()[0]
|
||||
mssim = []
|
||||
mcs = []
|
||||
for _ in range(levels):
|
||||
sim, cs = ssim(
|
||||
img1,
|
||||
img2,
|
||||
window_size=window_size,
|
||||
size_average=size_average,
|
||||
full=True,
|
||||
val_range=val_range,
|
||||
)
|
||||
mssim.append(sim)
|
||||
mcs.append(cs)
|
||||
|
||||
img1 = F.avg_pool2d(img1, (2, 2))
|
||||
img2 = F.avg_pool2d(img2, (2, 2))
|
||||
|
||||
mssim = torch.stack(mssim)
|
||||
mcs = torch.stack(mcs)
|
||||
|
||||
# Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
|
||||
if normalize:
|
||||
mssim = (mssim + 1) / 2
|
||||
mcs = (mcs + 1) / 2
|
||||
|
||||
pow1 = mcs**weights
|
||||
pow2 = mssim**weights
|
||||
# From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
|
||||
output = torch.prod(pow1[:-1] * pow2[-1])
|
||||
return output
|
||||
|
||||
|
||||
class SSIM(torch.nn.Module):
|
||||
def __init__(self, window_size=11, size_average=True, val_range=None):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.size_average = size_average
|
||||
self.val_range = val_range
|
||||
|
||||
# Assume 3 channel for SSIM
|
||||
self.channel = 3
|
||||
self.window = create_window(window_size, channel=self.channel)
|
||||
|
||||
def forward(self, img1, img2):
|
||||
(_, channel, _, _) = img1.size()
|
||||
|
||||
if channel == self.channel and self.window.dtype == img1.dtype:
|
||||
window = self.window
|
||||
else:
|
||||
window = (
|
||||
create_window(self.window_size, channel)
|
||||
.to(img1.device)
|
||||
.type(img1.dtype)
|
||||
)
|
||||
self.window = window
|
||||
self.channel = channel
|
||||
|
||||
_ssim = ssim(
|
||||
img1,
|
||||
img2,
|
||||
window=window,
|
||||
window_size=self.window_size,
|
||||
size_average=self.size_average,
|
||||
)
|
||||
dssim = (1 - _ssim) / 2
|
||||
return dssim
|
||||
|
||||
|
||||
class MSSSIM(torch.nn.Module):
|
||||
def __init__(self, window_size=11, size_average=True, channel=3):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.size_average = size_average
|
||||
self.channel = channel
|
||||
|
||||
def forward(self, img1, img2):
|
||||
return msssim(
|
||||
img1, img2, window_size=self.window_size, size_average=self.size_average
|
||||
)
|
@ -0,0 +1,39 @@
|
||||
import torch
|
||||
|
||||
from . import msssim
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
backwarp_tenGrid = {}
|
||||
|
||||
|
||||
def warp(tenInput, tenFlow):
|
||||
k = (str(msssim.device), str(tenFlow.size()))
|
||||
if k not in backwarp_tenGrid:
|
||||
tenHorizontal = (
|
||||
torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device)
|
||||
.view(1, 1, 1, tenFlow.shape[3])
|
||||
.expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
|
||||
)
|
||||
tenVertical = (
|
||||
torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device)
|
||||
.view(1, 1, tenFlow.shape[2], 1)
|
||||
.expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
|
||||
)
|
||||
backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(device)
|
||||
|
||||
tenFlow = torch.cat(
|
||||
[
|
||||
tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
|
||||
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
|
||||
],
|
||||
1,
|
||||
)
|
||||
|
||||
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
|
||||
return torch.nn.functional.grid_sample(
|
||||
input=tenInput,
|
||||
grid=g,
|
||||
mode="bilinear",
|
||||
padding_mode="border",
|
||||
align_corners=True,
|
||||
)
|
Loading…
Reference in New Issue