feature: video interpolation (#448)

- uses the RIFE algorithm to interpolate frames
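
As a quick orientation for reviewers, here is a minimal usage sketch of the new `interpolate_images` helper this PR adds (the import path matches the diff below; the frame filenames are hypothetical):

```python
# Minimal sketch, not part of the commit: double the effective frame rate of a
# short sequence of same-sized RGB PIL frames with the new RIFE-based helper.
from PIL import Image

from imaginairy.enhancers.video_interpolation.rife.interpolate import (
    interpolate_images,
)

# Hypothetical input frames; any list of equally sized RGB PIL images works.
frames = [
    Image.open(p).convert("RGB")
    for p in ("frame_000.png", "frame_001.png", "frame_002.png")
]

# fps_multiplier=2 synthesizes one in-between frame per adjacent pair, so
# N input frames become roughly 2*N - 1 output frames.
smooth_frames = interpolate_images(frames, fps_multiplier=2)
```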
Bryce Drennan, commit 907e80d1f2 (parent bb2dd45cf2)

@@ -19,6 +19,7 @@ from PIL import Image
from torchvision.transforms import ToTensor
from imaginairy import config
from imaginairy.enhancers.video_interpolation.rife.interpolate import interpolate_images
from imaginairy.schema import LazyLoadingImage
from imaginairy.utils import (
default,
@@ -91,7 +92,6 @@ def generate_video(
)
logger.warning(msg)
start_time = time.perf_counter()
output_fps = default(output_fps, fps_id)
video_model_config = config.MODEL_WEIGHT_CONFIG_LOOKUP.get(model_name, None)
@@ -139,6 +139,7 @@ def generate_video(
expected_size = (vid_width, vid_height)
for _ in range(repetitions):
for input_path in all_img_paths:
start_time = time.perf_counter()
_seed = default(seed, random.randint(0, 1000000))
torch.manual_seed(_seed)
logger.info(
@@ -318,15 +319,32 @@ def save_video(samples: torch.Tensor, video_filename: str, output_fps: int):
os.system(f"ffmpeg -i {video_filename} -c:v libx264 {video_path_h264}")
def save_video_bounce(samples: torch.Tensor, video_filename: str, output_fps: int):
def save_video_bounce(
samples: torch.Tensor, video_filename: str, output_fps: int, interpolate_fps=60
):
frames_np = (
(torch.permute(samples, (0, 2, 3, 1)) * 255).cpu().numpy().astype(np.uint8)
)
transition_duration = len(frames_np) / float(output_fps)
frames_pil = [Image.fromarray(frame) for frame in frames_np]
if interpolate_fps:
# bring it up to at least 60 fps
fps_multiplier = int(math.ceil(interpolate_fps / output_fps))
frames_pil = interpolate_images(frames_pil, fps_multiplier=fps_multiplier)
transition_duration_ms = transition_duration * 1000
logger.info(
f"Interpolated from {len(frames_np)} to {len(frames_pil)} frames ({fps_multiplier} multiplier)"
)
logger.info(
f"Making bounce animation with transition duration {transition_duration_ms:.1f}ms"
)
make_bounce_animation(
imgs=[Image.fromarray(frame) for frame in frames_np],
imgs=frames_pil,
outpath=video_filename,
end_pause_duration_ms=750,
transition_duration_ms=transition_duration_ms,
end_pause_duration_ms=100,
max_fps=60,
)
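
For context on the interpolation arithmetic in `save_video_bounce` above: the multiplier is just a ceiling ratio of the target and source frame rates. A small worked example (the specific numbers are illustrative, not from the diff):

```python
import math

output_fps = 14        # source fps of the generated clip (example value)
interpolate_fps = 60   # target smoothness used by save_video_bounce
fps_multiplier = math.ceil(interpolate_fps / output_fps)  # -> 5

# With 25 source frames, interpolation yields 25 + 24 * (5 - 1) = 121 frames,
# while transition_duration stays 25 / 14 ≈ 1.79 s because it is computed
# from the original frame count and fps before interpolation.
```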

imaginairy/enhancers/video_interpolation/rife/IFNet_HDv3.py (new file)
@@ -0,0 +1,227 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .warplayer import warp
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=True,
),
nn.LeakyReLU(0.2, True),
)
def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=False,
),
nn.BatchNorm2d(out_planes),
nn.LeakyReLU(0.2, True),
)
class Head(nn.Module):
def __init__(self):
super().__init__()
self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1)
self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1)
self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1)
self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1)
self.relu = nn.LeakyReLU(0.2, True)
def forward(self, x, feat=False):
x0 = self.cnn0(x)
x = self.relu(x0)
x1 = self.cnn1(x)
x = self.relu(x1)
x2 = self.cnn2(x)
x = self.relu(x2)
x3 = self.cnn3(x)
if feat:
return [x0, x1, x2, x3]
return x3
class ResConv(nn.Module):
def __init__(self, c, dilation=1):
super().__init__()
self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
self.relu = nn.LeakyReLU(0.2, True)
def forward(self, x):
return self.relu(self.conv(x) * self.beta + x)
class IFBlock(nn.Module):
def __init__(self, in_planes, c=64):
super().__init__()
self.conv0 = nn.Sequential(
conv(in_planes, c // 2, 3, 2, 1),
conv(c // 2, c, 3, 2, 1),
)
self.convblock = nn.Sequential(
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
)
self.lastconv = nn.Sequential(
nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2)
)
def forward(self, x, flow=None, scale=1):
x = F.interpolate(
x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
)
if flow is not None:
flow = (
F.interpolate(
flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
)
* 1.0
/ scale
)
x = torch.cat((x, flow), 1)
feat = self.conv0(x)
feat = self.convblock(feat)
tmp = self.lastconv(feat)
tmp = F.interpolate(
tmp, scale_factor=scale, mode="bilinear", align_corners=False
)
flow = tmp[:, :4] * scale
mask = tmp[:, 4:5]
return flow, mask
class IFNet(nn.Module):
def __init__(self):
super().__init__()
self.block0 = IFBlock(7 + 16, c=192)
self.block1 = IFBlock(8 + 4 + 16, c=128)
self.block2 = IFBlock(8 + 4 + 16, c=96)
self.block3 = IFBlock(8 + 4 + 16, c=64)
self.encode = Head()
# self.contextnet = Contextnet()
# self.unet = Unet()
def forward(
self,
x,
timestep=0.5,
scale_list=[8, 4, 2, 1],
training=False,
fastmode=True,
ensemble=False,
):
if training is False:
channel = x.shape[1] // 2
img0 = x[:, :channel]
img1 = x[:, channel:]
if not torch.is_tensor(timestep):
timestep = (x[:, :1].clone() * 0 + 1) * timestep
else:
timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
f0 = self.encode(img0[:, :3])
f1 = self.encode(img1[:, :3])
flow_list = []
merged = []
mask_list = []
warped_img0 = img0
warped_img1 = img1
flow = None
mask = None
block = [self.block0, self.block1, self.block2, self.block3]
for i in range(4):
if flow is None:
flow, mask = block[i](
torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1),
None,
scale=scale_list[i],
)
if ensemble:
f_, m_ = block[i](
torch.cat((img1[:, :3], img0[:, :3], f1, f0, 1 - timestep), 1),
None,
scale=scale_list[i],
)
flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
mask = (mask + (-m_)) / 2
else:
wf0 = warp(f0, flow[:, :2])
wf1 = warp(f1, flow[:, 2:4])
fd, m0 = block[i](
torch.cat(
(
warped_img0[:, :3],
warped_img1[:, :3],
wf0,
wf1,
timestep,
mask,
),
1,
),
flow,
scale=scale_list[i],
)
if ensemble:
f_, m_ = block[i](
torch.cat(
(
warped_img1[:, :3],
warped_img0[:, :3],
wf1,
wf0,
1 - timestep,
-mask,
),
1,
),
torch.cat((flow[:, 2:4], flow[:, :2]), 1),
scale=scale_list[i],
)
fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
mask = (m0 + (-m_)) / 2
else:
mask = m0
flow = flow + fd
mask_list.append(mask)
flow_list.append(flow)
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
merged.append((warped_img0, warped_img1))
mask = torch.sigmoid(mask)
merged[3] = warped_img0 * mask + warped_img1 * (1 - mask)
if not fastmode:
print("contextnet is removed")
"""
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
res = tmp[:, :3] * 2 - 1
merged[3] = torch.clamp(merged[3] + res, 0, 1)
"""
return flow_list, mask_list[3], merged

imaginairy/enhancers/video_interpolation/rife/RIFE_HDv3.py (new file)
@@ -0,0 +1,50 @@
import torch
from .IFNet_HDv3 import IFNet
class Model:
def __init__(self):
self.flownet = IFNet()
self.version: float
def eval(self):
self.flownet.eval()
def load_model(self, path, version: float):
from safetensors import safe_open
tensors = {}
with safe_open(path, framework="pt") as f: # type: ignore
for key in f.keys(): # noqa
tensors[key] = f.get_tensor(key)
self.flownet.load_state_dict(tensors, assign=True)
self.version = version
def load_model_old(self, path, rank=0):
def convert(param):
if rank == -1:
return {
k.replace("module.", ""): v
for k, v in param.items()
if "module." in k
}
else:
return param
if rank <= 0:
if torch.cuda.is_available():
self.flownet.load_state_dict(
convert(torch.load(f"{path}/flownet.pkl")), False
)
else:
self.flownet.load_state_dict(
convert(torch.load(f"{path}/flownet.pkl", map_location="cpu")),
False,
)
def inference(self, img0, img1, timestep=0.5, scale=1.0):
imgs = torch.cat((img0, img1), 1)
scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
flow, mask, merged = self.flownet(imgs, timestep, scale_list)
return merged[3]

imaginairy/enhancers/video_interpolation/rife/interpolate.py (new file)
@@ -0,0 +1,410 @@
import _thread
import logging
import os
import shutil
import time
from functools import lru_cache
from queue import Queue
from typing import List
import cv2
import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F
from tqdm import tqdm
from imaginairy.utils import get_device
from imaginairy.utils.model_manager import get_cached_url_path
from .msssim import ssim_matlab
from .RIFE_HDv3 import Model
logger = logging.getLogger(__name__)
def transfer_audio(sourceVideo, targetVideo):
tempAudioFileName = "./temp/audio.mkv"
# split audio from original video file and store in "temp" directory
if True:
# clear old "temp" directory if it exists
if os.path.isdir("temp"):
# remove temp directory
shutil.rmtree("temp")
# create new "temp" directory
os.makedirs("temp")
# extract audio from video
os.system(f'ffmpeg -y -i "{sourceVideo}" -c:a copy -vn {tempAudioFileName}')
targetNoAudio = (
os.path.splitext(targetVideo)[0] + "_noaudio" + os.path.splitext(targetVideo)[1]
)
os.rename(targetVideo, targetNoAudio)
# combine audio file and new video file
os.system(
f'ffmpeg -y -i "{targetNoAudio}" -i {tempAudioFileName} -c copy "{targetVideo}"'
)
if (
os.path.getsize(targetVideo) == 0
): # if ffmpeg failed to merge the video and audio together try converting the audio to aac
tempAudioFileName = "./temp/audio.m4a"
os.system(
f'ffmpeg -y -i "{sourceVideo}" -c:a aac -b:a 160k -vn {tempAudioFileName}'
)
os.system(
'ffmpeg -y -i "{}" -i {} -c copy "{}"'.format(
targetNoAudio, tempAudioFileName, targetVideo
)
)
if (
os.path.getsize(targetVideo) == 0
): # if aac is not supported by selected format
os.rename(targetNoAudio, targetVideo)
print("Audio transfer failed. Interpolated video will have no audio")
else:
print(
"Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead."
)
# remove audio-less video
os.remove(targetNoAudio)
else:
os.remove(targetNoAudio)
# remove temp directory
shutil.rmtree("temp")
RIFE_WEIGHTS_URL = "https://huggingface.co/imaginairy/rife-interpolation/resolve/26442e52cc30b88c5cb490702647b8de9aaee8a7/rife-flownet-4.13.2.safetensors"
@lru_cache(maxsize=1)
def load_rife_model(model_path=None, version=4.13, device=None):
if model_path is None:
model_path = RIFE_WEIGHTS_URL
model_path = get_cached_url_path(model_path)
device = device if device else get_device()
model = Model()
model.load_model(model_path, version=version)
model.eval()
model.flownet.to(device)
return model
def make_inference(I0, I1, n, *, model, scale):
if model.version >= 3.9:
res = []
for i in range(n):
res.append(model.inference(I0, I1, (i + 1) * 1.0 / (n + 1), scale))
return res
else:
middle = model.inference(I0, I1, scale)
if n == 1:
return [middle]
first_half = make_inference(I0, middle, n=n // 2, model=model, scale=scale)
second_half = make_inference(middle, I1, n=n // 2, model=model, scale=scale)
if n % 2:
return [*first_half, middle, *second_half]
else:
return [*first_half, *second_half]
def interpolate_video_file(
video_path: str | None = None,
images_source_path: str | None = None,
scale: float = 1.0,
vid_out_name: str | None = None,
target_fps: float | None = None,
fps_multiplier: int = 2,
model_weights_path: str | None = None,
fp16: bool = False,
montage: bool = False,
png_out: bool = False,
output_extension: str = "mp4",
device=None,
):
assert video_path is not None or images_source_path is not None
assert scale in [0.25, 0.5, 1.0, 2.0, 4.0]
device = device if device else get_device()
if images_source_path is not None:
png_out = True
torch.set_grad_enabled(False)
if torch.cuda.is_available():
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
if fp16:
torch.set_default_tensor_type(torch.cuda.HalfTensor) # type: ignore
model = load_rife_model(model_weights_path, version=4.13)
logger.info(f"Loaded RIFE from {model_weights_path}")
if video_path is not None:
import skvideo.io
videoCapture = cv2.VideoCapture(video_path)
fps = videoCapture.get(cv2.CAP_PROP_FPS)
tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
videoCapture.release()
if target_fps is None:
fpsNotAssigned = True
target_fps = fps * fps_multiplier
else:
fpsNotAssigned = False
videogen = skvideo.io.vreader(video_path)
lastframe = next(videogen)
fourcc = cv2.VideoWriter_fourcc("m", "p", "4", "v") # type: ignore
video_path_wo_ext, ext = os.path.splitext(video_path)
print(
"{}.{}, {} frames in total, {}FPS to {}FPS".format(
video_path_wo_ext, output_extension, tot_frame, fps, target_fps
)
)
if png_out is False and fpsNotAssigned is True:
print("The audio will be merged after interpolation process")
else:
print("Will not merge audio because using png or fps flag!")
else:
assert images_source_path is not None
videogen = []
for f in os.listdir(images_source_path):
if "png" in f:
videogen.append(f)
tot_frame = len(videogen)
videogen.sort(key=lambda x: int(x[:-4]))
lastframe = cv2.imread(
os.path.join(images_source_path, videogen[0]), cv2.IMREAD_UNCHANGED
)[:, :, ::-1].copy()
videogen = videogen[1:]
h, w, _ = lastframe.shape
vid_out = None
if png_out:
if not os.path.exists("vid_out"):
os.mkdir("vid_out")
else:
if vid_out_name is None:
assert video_path_wo_ext is not None
assert target_fps is not None
vid_out_name = f"{video_path_wo_ext}_{fps_multiplier}X_{int(np.round(target_fps))}fps.{output_extension}"
vid_out = cv2.VideoWriter(vid_out_name, fourcc, target_fps, (w, h)) # type: ignore
def clear_write_buffer(png_out, write_buffer):
cnt = 0
while True:
item = write_buffer.get()
if item is None:
break
if png_out:
cv2.imwrite(f"vid_out/{cnt:0>7d}.png", item[:, :, ::-1])
cnt += 1
else:
vid_out.write(item[:, :, ::-1])
def build_read_buffer(img, montage, read_buffer, videogen):
try:
for frame in videogen:
if img is not None:
frame = cv2.imread(os.path.join(img, frame), cv2.IMREAD_UNCHANGED)[
:, :, ::-1
].copy()
if montage:
frame = frame[:, left : left + w]
read_buffer.put(frame)
except Exception as e: # noqa
print(f"skipping frame due to error: {e}")
read_buffer.put(None)
def pad_image(img):
if fp16:
return F.pad(img, padding).half()
else:
return F.pad(img, padding)
if montage:
left = w // 4
w = w // 2
tmp = max(128, int(128 / scale))
ph = ((h - 1) // tmp + 1) * tmp
pw = ((w - 1) // tmp + 1) * tmp
padding = (0, pw - w, 0, ph - h)
pbar = tqdm(total=tot_frame)
if montage:
lastframe = lastframe[:, left : left + w]
write_buffer: Queue = Queue(maxsize=500)
read_buffer: Queue = Queue(maxsize=500)
_thread.start_new_thread(
build_read_buffer, (images_source_path, montage, read_buffer, videogen)
)
_thread.start_new_thread(clear_write_buffer, (png_out, write_buffer))
I1 = (
torch.from_numpy(np.transpose(lastframe, (2, 0, 1)))
.to(device, non_blocking=True)
.unsqueeze(0)
.float()
/ 255.0
)
I1 = pad_image(I1)
temp = None # save lastframe when processing static frame
while True:
if temp is not None:
frame = temp
temp = None
else:
frame = read_buffer.get()
if frame is None:
break
I0 = I1
I1 = (
torch.from_numpy(np.transpose(frame, (2, 0, 1)))
.to(device, non_blocking=True)
.unsqueeze(0)
.float()
/ 255.0
)
I1 = pad_image(I1)
I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
break_flag = False
if ssim > 0.996:
frame = read_buffer.get() # read a new frame
if frame is None:
break_flag = True
frame = lastframe
else:
temp = frame
I1 = (
torch.from_numpy(np.transpose(frame, (2, 0, 1)))
.to(device, non_blocking=True)
.unsqueeze(0)
.float()
/ 255.0
)
I1 = pad_image(I1)
I1 = model.inference(I0, I1, scale)
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
if ssim < 0.2:
output = []
for i in range(fps_multiplier - 1):
output.append(I0)
"""
output = []
step = 1 / fps_multiplier
alpha = 0
for i in range(fps_multiplier - 1):
alpha += step
beta = 1-alpha
output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
"""
else:
output = make_inference(
I0, I1, fps_multiplier - 1, model=model, scale=scale
)
if montage:
write_buffer.put(np.concatenate((lastframe, lastframe), 1))
for mid in output:
mid = (mid[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0)
write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1))
else:
write_buffer.put(lastframe)
for mid in output:
mid = (mid[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0)
write_buffer.put(mid[:h, :w])
pbar.update(1)
lastframe = frame
if break_flag:
break
if montage:
write_buffer.put(np.concatenate((lastframe, lastframe), 1))
else:
write_buffer.put(lastframe)
while not write_buffer.empty():
time.sleep(0.1)
pbar.close()
if vid_out is not None:
vid_out.release()
assert vid_out_name is not None
# move audio to new video file if appropriate
if png_out is False and fpsNotAssigned is True and video_path is not None:
try:
transfer_audio(video_path, vid_out_name)
except Exception as e: # noqa
logger.info(
f"Audio transfer failed. Interpolated video will have no audio. {e}"
)
targetNoAudio = (
os.path.splitext(vid_out_name)[0]
+ "_noaudio"
+ os.path.splitext(vid_out_name)[1]
)
os.rename(targetNoAudio, vid_out_name)
def pad_image(img, scale):
tmp = max(128, int(128 / scale))
ph, pw = (
((img.shape[1] - 1) // tmp + 1) * tmp,
((img.shape[2] - 1) // tmp + 1) * tmp,
)
padding = (0, pw - img.shape[2], 0, ph - img.shape[1])
return F.pad(img, padding)
def interpolate_images(
image_list,
scale=1.0,
fps_multiplier=2,
model_weights_path=None,
device=None,
) -> List[Image.Image]:
assert scale in [0.25, 0.5, 1.0, 2.0, 4.0]
torch.set_grad_enabled(False)
device = device if device else get_device()
model = load_rife_model(model_weights_path, version=4.13)
interpolated_images = []
for i in range(len(image_list) - 1):
I0 = image_to_tensor(image_list[i], device)
I1 = image_to_tensor(image_list[i + 1], device)
# I0, I1 = pad_image(I0, scale), pad_image(I1, scale)
interpolated = make_inference(
I0, I1, n=fps_multiplier - 1, model=model, scale=scale
)
interpolated_images.append(image_list[i])
for img in interpolated:
img = (img[0] * 255.0).byte().cpu().numpy().transpose(1, 2, 0)
interpolated_images.append(Image.fromarray(img))
interpolated_images.append(image_list[-1])
return interpolated_images
def image_to_tensor(image, device):
"""
Converts a PIL image to a PyTorch tensor.
Args:
- image (PIL.Image): The image to convert.
- device (torch.device): The device to use (CPU or CUDA).
Returns:
- torch.Tensor: The image converted to a PyTorch tensor.
"""
tensor = torch.from_numpy(np.array(image).transpose((2, 0, 1)))
tensor = tensor.to(device, non_blocking=True).unsqueeze(0).float() / 255.0
return tensor

imaginairy/enhancers/video_interpolation/rife/msssim.py (new file)
@@ -0,0 +1,298 @@
from math import exp
import torch
from torch.nn import functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def gaussian(window_size, sigma):
gauss = torch.Tensor(
[
exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
for x in range(window_size)
]
)
return gauss / gauss.sum()
def create_window(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = (
_1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
)
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
return window
def create_window_3d(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t())
_3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
window = (
_3D_window.expand(1, channel, window_size, window_size, window_size)
.contiguous()
.to(device)
)
return window
def ssim(
img1,
img2,
window_size=11,
window=None,
size_average=True,
full=False,
val_range=None,
):
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
if val_range is None:
max_val = 255 if torch.max(img1) > 128 else 1
min_val = -1 if torch.min(img1) < -0.5 else 0
L = max_val - min_val
else:
L = val_range
padd = 0
(_, channel, height, width) = img1.size()
if window is None:
real_size = min(window_size, height, width)
window = create_window(real_size, channel=channel).to(img1.device)
# mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
# mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
mu1 = F.conv2d(
F.pad(img1, (5, 5, 5, 5), mode="replicate"),
window,
padding=padd,
groups=channel,
)
mu2 = F.conv2d(
F.pad(img2, (5, 5, 5, 5), mode="replicate"),
window,
padding=padd,
groups=channel,
)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = (
F.conv2d(
F.pad(img1 * img1, (5, 5, 5, 5), "replicate"),
window,
padding=padd,
groups=channel,
)
- mu1_sq
)
sigma2_sq = (
F.conv2d(
F.pad(img2 * img2, (5, 5, 5, 5), "replicate"),
window,
padding=padd,
groups=channel,
)
- mu2_sq
)
sigma12 = (
F.conv2d(
F.pad(img1 * img2, (5, 5, 5, 5), "replicate"),
window,
padding=padd,
groups=channel,
)
- mu1_mu2
)
C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2
v1 = 2.0 * sigma12 + C2
v2 = sigma1_sq + sigma2_sq + C2
cs = torch.mean(v1 / v2) # contrast sensitivity
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1)
if full:
return ret, cs
return ret
def ssim_matlab(
img1,
img2,
window_size=11,
window=None,
size_average=True,
full=False,
val_range=None,
):
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
if val_range is None:
max_val = 255 if torch.max(img1) > 128 else 1
min_val = -1 if torch.min(img1) < -0.5 else 0
L = max_val - min_val
else:
L = val_range
padd = 0
(_, _, height, width) = img1.size()
if window is None:
real_size = min(window_size, height, width)
window = create_window_3d(real_size, channel=1).to(img1.device)
# Channel is set to 1 since we consider color images as volumetric images
img1 = img1.unsqueeze(1)
img2 = img2.unsqueeze(1)
mu1 = F.conv3d(
F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"),
window,
padding=padd,
groups=1,
)
mu2 = F.conv3d(
F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"),
window,
padding=padd,
groups=1,
)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = (
F.conv3d(
F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"),
window,
padding=padd,
groups=1,
)
- mu1_sq
)
sigma2_sq = (
F.conv3d(
F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"),
window,
padding=padd,
groups=1,
)
- mu2_sq
)
sigma12 = (
F.conv3d(
F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"),
window,
padding=padd,
groups=1,
)
- mu1_mu2
)
C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2
v1 = 2.0 * sigma12 + C2
v2 = sigma1_sq + sigma2_sq + C2
cs = torch.mean(v1 / v2) # contrast sensitivity
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1)
if full:
return ret, cs
return ret
def msssim(
img1, img2, window_size=11, size_average=True, val_range=None, normalize=False
):
device = img1.device
weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
levels = weights.size()[0]
mssim = []
mcs = []
for _ in range(levels):
sim, cs = ssim(
img1,
img2,
window_size=window_size,
size_average=size_average,
full=True,
val_range=val_range,
)
mssim.append(sim)
mcs.append(cs)
img1 = F.avg_pool2d(img1, (2, 2))
img2 = F.avg_pool2d(img2, (2, 2))
mssim = torch.stack(mssim)
mcs = torch.stack(mcs)
# Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
if normalize:
mssim = (mssim + 1) / 2
mcs = (mcs + 1) / 2
pow1 = mcs**weights
pow2 = mssim**weights
# From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
output = torch.prod(pow1[:-1] * pow2[-1])
return output
class SSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True, val_range=None):
super().__init__()
self.window_size = window_size
self.size_average = size_average
self.val_range = val_range
# Assume 3 channel for SSIM
self.channel = 3
self.window = create_window(window_size, channel=self.channel)
def forward(self, img1, img2):
(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.dtype == img1.dtype:
window = self.window
else:
window = (
create_window(self.window_size, channel)
.to(img1.device)
.type(img1.dtype)
)
self.window = window
self.channel = channel
_ssim = ssim(
img1,
img2,
window=window,
window_size=self.window_size,
size_average=self.size_average,
)
dssim = (1 - _ssim) / 2
return dssim
class MSSSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True, channel=3):
super().__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = channel
def forward(self, img1, img2):
return msssim(
img1, img2, window_size=self.window_size, size_average=self.size_average
)

imaginairy/enhancers/video_interpolation/rife/warplayer.py (new file)
@@ -0,0 +1,39 @@
import torch
from . import msssim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
backwarp_tenGrid = {}
def warp(tenInput, tenFlow):
k = (str(msssim.device), str(tenFlow.size()))
if k not in backwarp_tenGrid:
tenHorizontal = (
torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device)
.view(1, 1, 1, tenFlow.shape[3])
.expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
)
tenVertical = (
torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device)
.view(1, 1, tenFlow.shape[2], 1)
.expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
)
backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(device)
tenFlow = torch.cat(
[
tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
],
1,
)
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
return torch.nn.functional.grid_sample(
input=tenInput,
grid=g,
mode="bilinear",
padding_mode="border",
align_corners=True,
)

@@ -1,6 +1,7 @@
"""Functions for creating animations from images."""
import logging
import os.path
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Sequence
import cv2
import torch
@@ -19,8 +20,11 @@ if TYPE_CHECKING:
from imaginairy.utils.img_utils import LazyLoadingImage
logger = logging.getLogger(__name__)
def make_bounce_animation(
imgs: "List[Image.Image | LazyLoadingImage | torch.Tensor]",
imgs: "Sequence[Image.Image | LazyLoadingImage | torch.Tensor]",
outpath: str,
transition_duration_ms=500,
start_pause_duration_ms=1000,
@@ -32,7 +36,7 @@ def make_bounce_animation(
last_img = imgs[-1]
max_frames = int(round(transition_duration_ms / 1000 * max_fps))
min_duration = int(1000 / 20)
min_duration = int(1000 / max_fps)
if middle_imgs:
progress_duration = int(round(transition_duration_ms / len(middle_imgs)))
else:
@@ -53,7 +57,9 @@
+ [end_pause_duration_ms]
+ [progress_duration] * len(middle_imgs)
)
logger.info(
f"Making animation with {len(converted_frames)} frames and {progress_duration:.1f}ms per transition frame."
)
make_animation(imgs=converted_frames, outpath=outpath, frame_duration_ms=durations)
@@ -150,8 +156,8 @@ def select_images_by_duration_at_fps(images, durations_ms, fps=30):
for i, image in enumerate(images):
duration = durations_ms[i] / 1000
num_frames = int(round(duration * fps))
print(
f"Showing image {i} for {num_frames} frames for {durations_ms[i]}ms at {fps} fps."
)
# print(
# f"Showing image {i} for {num_frames} frames for {durations_ms[i]}ms at {fps} fps."
# )
for j in range(num_frames):
yield image
