|
|
|
@ -15,6 +15,8 @@ import pytorch_lightning as pl
|
|
|
|
|
import torch
|
|
|
|
|
from einops import rearrange, repeat
|
|
|
|
|
from torch import nn
|
|
|
|
|
from torch.nn import functional as F
|
|
|
|
|
from torch.nn.modules.utils import _pair
|
|
|
|
|
from torchvision.utils import make_grid
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
@ -650,6 +652,17 @@ class DDPM(pl.LightningModule):
|
|
|
|
|
return opt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _TileModeConv2DConvForward(
|
|
|
|
|
self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor # noqa
|
|
|
|
|
):
|
|
|
|
|
working = F.pad(input, self.paddingX, mode=self.padding_modeX)
|
|
|
|
|
working = F.pad(working, self.paddingY, mode=self.padding_modeY)
|
|
|
|
|
|
|
|
|
|
return F.conv2d(
|
|
|
|
|
working, weight, bias, self.stride, _pair(0), self.dilation, self.groups
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LatentDiffusion(DDPM):
|
|
|
|
|
"""main class"""
|
|
|
|
|
|
|
|
|
@ -706,16 +719,34 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
|
|
|
|
|
# store initial padding mode so we can switch to 'circular'
|
|
|
|
|
# when we want tiled images
|
|
|
|
|
# replace conv_forward with function that can do tiling in one direction
|
|
|
|
|
for m in self.modules():
|
|
|
|
|
if isinstance(m, nn.Conv2d):
|
|
|
|
|
m._initial_padding_mode = m.padding_mode
|
|
|
|
|
m._conv_forward = _TileModeConv2DConvForward.__get__( # noqa
|
|
|
|
|
m, nn.Conv2d
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def tile_mode(self, enabled):
|
|
|
|
|
def tile_mode(self, tile_mode):
|
|
|
|
|
"""For creating seamless tiles"""
|
|
|
|
|
tile_mode = tile_mode or ""
|
|
|
|
|
tile_x = "x" in tile_mode
|
|
|
|
|
tile_y = "y" in tile_mode
|
|
|
|
|
for m in self.modules():
|
|
|
|
|
if isinstance(m, nn.Conv2d):
|
|
|
|
|
m.padding_mode = (
|
|
|
|
|
"circular" if enabled else m._initial_padding_mode # noqa
|
|
|
|
|
m.padding_modeX = "circular" if tile_x else "constant"
|
|
|
|
|
m.padding_modeY = "circular" if tile_y else "constant"
|
|
|
|
|
m.paddingX = (
|
|
|
|
|
m._reversed_padding_repeated_twice[0], # noqa
|
|
|
|
|
m._reversed_padding_repeated_twice[1], # noqa
|
|
|
|
|
0,
|
|
|
|
|
0,
|
|
|
|
|
)
|
|
|
|
|
m.paddingY = (
|
|
|
|
|
0,
|
|
|
|
|
0,
|
|
|
|
|
m._reversed_padding_repeated_twice[2], # noqa
|
|
|
|
|
m._reversed_padding_repeated_twice[3], # noqa
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def make_cond_schedule(
|
|
|
|
|