|
|
|
@ -22,7 +22,6 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
|
|
|
|
)
|
|
|
|
|
from torch import Tensor, nn
|
|
|
|
|
from torch.nn import functional as F
|
|
|
|
|
from torch.nn.modules.utils import _pair
|
|
|
|
|
|
|
|
|
|
from imaginairy.utils.feather_tile import rebuild_image, tile_image
|
|
|
|
|
from imaginairy.weight_management.conversion import cast_weights
|
|
|
|
@ -35,17 +34,17 @@ TileModeType = Literal["", "x", "y", "xy"]
|
|
|
|
|
def _tile_mode_conv2d_conv_forward(
|
|
|
|
|
self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor # noqa
|
|
|
|
|
):
|
|
|
|
|
if self.padding_modeX == self.padding_modeY:
|
|
|
|
|
self.padding_mode = self.padding_modeX
|
|
|
|
|
if self.padding_mode_x == self.padding_mode_y:
|
|
|
|
|
self.padding_mode = self.padding_mode_x
|
|
|
|
|
return self._orig_conv_forward(input, weight, bias)
|
|
|
|
|
|
|
|
|
|
w1 = F.pad(input, self.paddingX, mode=self.padding_modeX)
|
|
|
|
|
w1 = F.pad(input, self.padding_x, mode=self.padding_modeX)
|
|
|
|
|
del input
|
|
|
|
|
|
|
|
|
|
w2 = F.pad(w1, self.paddingY, mode=self.padding_modeY)
|
|
|
|
|
w2 = F.pad(w1, self.padding_y, mode=self.padding_modeY)
|
|
|
|
|
del w1
|
|
|
|
|
|
|
|
|
|
return F.conv2d(w2, weight, bias, self.stride, _pair(0), self.dilation, self.groups)
|
|
|
|
|
return F.conv2d(w2, weight, bias, self.stride, (0, 0), self.dilation, self.groups)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TileModeMixin(nn.Module):
|
|
|
|
@ -57,34 +56,21 @@ class TileModeMixin(nn.Module):
|
|
|
|
|
tile_mode: One of "", "x", "y", "xy". If "x", the image will be tiled horizontally. If "y", the image will be
|
|
|
|
|
tiled vertically. If "xy", the image will be tiled both horizontally and vertically.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
tile_x = "x" in tile_mode
|
|
|
|
|
tile_y = "y" in tile_mode
|
|
|
|
|
padding_mode_x = "circular" if "x" in tile_mode else "constant"
|
|
|
|
|
padding_mode_y = "circular" if "y" in tile_mode else "constant"
|
|
|
|
|
for m in self.modules():
|
|
|
|
|
if isinstance(m, nn.Conv2d):
|
|
|
|
|
if not hasattr(m, "_orig_conv_forward"):
|
|
|
|
|
# patch with a function that can handle tiling in a single direction
|
|
|
|
|
m._initial_padding_mode = m.padding_mode # type: ignore
|
|
|
|
|
m._orig_conv_forward = m._conv_forward # type: ignore
|
|
|
|
|
m._conv_forward = _tile_mode_conv2d_conv_forward.__get__( # type: ignore
|
|
|
|
|
m, nn.Conv2d
|
|
|
|
|
)
|
|
|
|
|
m.padding_modeX = "circular" if tile_x else "constant" # type: ignore
|
|
|
|
|
m.padding_modeY = "circular" if tile_y else "constant" # type: ignore
|
|
|
|
|
if m.padding_modeY == m.padding_modeX:
|
|
|
|
|
m.padding_mode = m.padding_modeX
|
|
|
|
|
m.paddingX = (
|
|
|
|
|
m._reversed_padding_repeated_twice[0],
|
|
|
|
|
m._reversed_padding_repeated_twice[1],
|
|
|
|
|
0,
|
|
|
|
|
0,
|
|
|
|
|
) # type: ignore
|
|
|
|
|
m.paddingY = (
|
|
|
|
|
0,
|
|
|
|
|
0,
|
|
|
|
|
m._reversed_padding_repeated_twice[2],
|
|
|
|
|
m._reversed_padding_repeated_twice[3],
|
|
|
|
|
) # type: ignore
|
|
|
|
|
if not isinstance(m, nn.Conv2d):
|
|
|
|
|
continue
|
|
|
|
|
if not hasattr(m, "_orig_conv_forward"):
|
|
|
|
|
# patch with a function that can handle tiling in a single direction
|
|
|
|
|
m._initial_padding_mode = m.padding_mode # type: ignore
|
|
|
|
|
m._orig_conv_forward = m._conv_forward # type: ignore
|
|
|
|
|
m._conv_forward = _tile_mode_conv2d_conv_forward.__get__(m, nn.Conv2d) # type: ignore
|
|
|
|
|
m.padding_mode_x = padding_mode_x # type: ignore
|
|
|
|
|
m.padding_mode_y = padding_mode_y # type: ignore
|
|
|
|
|
rprt: list[int] = m._reversed_padding_repeated_twice
|
|
|
|
|
m.padding_x = (rprt[0], rprt[1], 0, 0) # type: ignore
|
|
|
|
|
m.padding_y = (0, 0, rprt[2], rprt[3]) # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1):
|
|
|
|
|