mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-11-19 03:25:41 +00:00
316114e660
Wrote an openai script and custom prompt to generate them.
76 lines
2.9 KiB
Python
76 lines
2.9 KiB
Python
"""Classes for latent space perceptual loss"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from imaginairy.modules.sgm.autoencoding.lpips.loss.lpips import LPIPS
|
|
from imaginairy.utils import default, instantiate_from_config
|
|
|
|
|
|
class LatentLPIPS(nn.Module):
    """Latent-space L2 loss combined with LPIPS terms on decoded images.

    The loss is a weighted sum of up to three terms:

    * the mean squared error between target and predicted latents
      (``latent_weight``),
    * LPIPS between the decoded target latents and the decoded predicted
      latents (``perceptual_weight``),
    * LPIPS between the original input images and the decoded predicted
      latents (``perceptual_weight_on_inputs``).
    """

    def __init__(
        self,
        decoder_config,
        perceptual_weight=1.0,
        latent_weight=1.0,
        scale_input_to_tgt_size=False,
        scale_tgt_to_input_size=False,
        perceptual_weight_on_inputs=0.0,
    ):
        """Build the decoder and the LPIPS module and store the loss weights.

        Args:
            decoder_config: Config passed to ``instantiate_from_config`` to
                build the object whose ``decode`` method maps latents to images.
            perceptual_weight: Weight of the LPIPS term between decoded targets
                and decoded predictions. The term is skipped when ``<= 0``.
            latent_weight: Weight of the latent L2 term.
            scale_input_to_tgt_size: If true, resize ``image_inputs`` to the
                spatial size of the decoded predictions before the
                inputs-based LPIPS term.
            scale_tgt_to_input_size: If true (and ``scale_input_to_tgt_size``
                is false), resize the decoded predictions to the spatial size
                of ``image_inputs`` instead.
            perceptual_weight_on_inputs: Weight of the LPIPS term between
                ``image_inputs`` and decoded predictions. Skipped when ``<= 0``.
        """
        super().__init__()
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        self.scale_tgt_to_input_size = scale_tgt_to_input_size
        self.init_decoder(decoder_config)
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        self.latent_weight = latent_weight
        self.perceptual_weight_on_inputs = perceptual_weight_on_inputs

    def init_decoder(self, config):
        """Instantiate ``self.decoder`` from ``config`` and drop its encoder.

        Only ``decode`` is used by this loss, so an ``encoder`` attribute on
        the instantiated object is deleted when present.
        """
        self.decoder = instantiate_from_config(config)
        if hasattr(self.decoder, "encoder"):
            del self.decoder.encoder

    def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
        """Compute the combined loss and a dict of detached per-term values.

        Args:
            latent_inputs: Target latents.
            latent_predictions: Predicted latents, broadcast-compatible with
                ``latent_inputs``.
            image_inputs: Reference images for the inputs-based LPIPS term.
                Assumed to be laid out as (N, C, H, W): ``shape[2:]`` is used
                as the spatial size when resizing.
            split: Prefix for the keys of the returned log dict.

        Returns:
            ``(loss, log)`` where ``loss`` is a scalar tensor and ``log`` maps
            ``"{split}/latent_l2_loss"``, ``"{split}/perceptual_loss"`` and
            ``"{split}/perceptual_loss_on_inputs"`` (the latter two only when
            the corresponding term is enabled) to detached scalar tensors.
        """
        log = {}

        latent_l2 = ((latent_inputs - latent_predictions) ** 2).mean()
        log[f"{split}/latent_l2_loss"] = latent_l2.detach()
        # Always reduce and weight the latent term so the result is a scalar
        # honouring ``latent_weight`` even when the perceptual term is disabled.
        loss = self.latent_weight * latent_l2

        image_reconstructions = None
        if self.perceptual_weight > 0.0:
            image_reconstructions = self.decoder.decode(latent_predictions)
            image_targets = self.decoder.decode(latent_inputs)
            perceptual = self.perceptual_loss(
                image_targets.contiguous(), image_reconstructions.contiguous()
            ).mean()
            loss = loss + self.perceptual_weight * perceptual
            log[f"{split}/perceptual_loss"] = perceptual.detach()

        if self.perceptual_weight_on_inputs > 0.0:
            # Decode only if the branch above has not done so already; passing
            # the decode call as an argument would run the decoder eagerly and
            # throw the result away.
            if image_reconstructions is None:
                image_reconstructions = self.decoder.decode(latent_predictions)

            # ``scale_input_to_tgt_size`` takes precedence when both are set.
            if self.scale_input_to_tgt_size:
                image_inputs = torch.nn.functional.interpolate(
                    image_inputs,
                    image_reconstructions.shape[2:],
                    mode="bicubic",
                    antialias=True,
                )
            elif self.scale_tgt_to_input_size:
                image_reconstructions = torch.nn.functional.interpolate(
                    image_reconstructions,
                    image_inputs.shape[2:],
                    mode="bicubic",
                    antialias=True,
                )

            perceptual_on_inputs = self.perceptual_loss(
                image_inputs.contiguous(), image_reconstructions.contiguous()
            ).mean()
            loss = loss + self.perceptual_weight_on_inputs * perceptual_on_inputs
            log[f"{split}/perceptual_loss_on_inputs"] = perceptual_on_inputs.detach()

        return loss, log
|