perf: tiled encoding of images

Images produced by the previous fold/unfold/split_input_params approach didn't look good.
Branch: pull/259/head
Authored by Bryce, committed by Bryce Drennan
Parent: 2aef6089e0
Commit: c3a88c44cd

@@ -300,6 +300,7 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
 ## ChangeLog
+- perf: tiled encoding of images (removes memory bottleneck)
 - perf: use Silu for performance improvement over nonlinearity
 - perf: `xformers` added as a dependency for linux and windows. Gives a nice speed boost.
 - perf: sliced attention now runs on MacOS. A typo prevented that from happening previously.
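The "removes memory bottleneck" claim follows from the encoder's activation memory growing with the number of input pixels: encoding in fixed-size tiles caps that growth no matter how large the source image is. A back-of-the-envelope illustration (pure pixel-count arithmetic, not a measurement of the actual encoder):

    # Input-size comparison: whole image vs. one default-sized tile.
    # Encoder activation memory scales roughly with these pixel counts.
    full_image = 3 * 4096 * 4096   # a 4096x4096 RGB image, encoded all at once
    one_tile = 3 * 1024 * 1024     # one chunk_size = 128 * 8 = 1024 px tile
    print(full_image / one_tile)   # 16.0 -> ~16x less input per forward pass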

@@ -7,7 +7,7 @@ import torch
 def mask_tile(tile, overlap, std_overlap, side="bottom"):
-    h, w = tile.size(2), tile.size(3)
+    b, c, h, w = tile.shape
     top_overlap, bottom_overlap, right_overlap, left_overlap = overlap
     (
         std_top_overlap,
@@ -25,7 +25,7 @@ def mask_tile(tile, overlap, std_overlap, side="bottom"):
             lin_mask_left = (
                 torch.cat([zeros_mask, lin_mask_left], 0)
                 .repeat(h, 1)
-                .repeat(3, 1, 1)
+                .repeat(c, 1, 1)
                 .unsqueeze(0)
             )
@@ -33,7 +33,7 @@ def mask_tile(tile, overlap, std_overlap, side="bottom"):
         lin_mask_right = (
             torch.linspace(1, 0, right_overlap, device=tile.device)
             .repeat(h, 1)
-            .repeat(3, 1, 1)
+            .repeat(c, 1, 1)
             .unsqueeze(0)
         )
     if "top" in side:
@@ -41,14 +41,14 @@ def mask_tile(tile, overlap, std_overlap, side="bottom"):
         if top_overlap > std_top_overlap:
             zeros_mask = torch.zeros(top_overlap - std_top_overlap, device=tile.device)
             lin_mask_top = torch.cat([zeros_mask, lin_mask_top], 0)
-        lin_mask_top = lin_mask_top.repeat(w, 1).rot90(3).repeat(3, 1, 1).unsqueeze(0)
+        lin_mask_top = lin_mask_top.repeat(w, 1).rot90(3).repeat(c, 1, 1).unsqueeze(0)
     if "bottom" in side:
         lin_mask_bottom = (
             torch.linspace(1, 0, std_bottom_overlap, device=tile.device)
             .repeat(w, 1)
             .rot90(3)
-            .repeat(3, 1, 1)
+            .repeat(c, 1, 1)
             .unsqueeze(0)
         )
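Each of the mask_tile changes above swaps a hardcoded channel count of 3 for the tile's own channel count c taken from tile.shape. The blend masks were originally built for 3-channel RGB tiles, but with tiled encoding the same blending presumably also has to run on 4-channel VAE latent tiles during rebuild_image, so the mask must broadcast over whatever c actually is. A simplified, hedged illustration (left_blend_mask is a stand-in for the idea, not the repo's mask_tile):

    import torch

    def left_blend_mask(tile, left_overlap):
        # Fade the tile in from its left edge over `left_overlap` columns,
        # broadcasting the 0->1 ramp over however many channels the tile has.
        b, c, h, w = tile.shape
        ramp = torch.linspace(0, 1, left_overlap, device=tile.device)
        mask = torch.ones(b, c, h, w, device=tile.device)
        mask[:, :, :, :left_overlap] = ramp.repeat(h, 1).repeat(c, 1, 1).unsqueeze(0)
        return tile * mask

    rgb_tile = torch.randn(1, 3, 256, 256)     # pixel-space tile
    latent_tile = torch.randn(1, 4, 128, 128)  # VAE latent tile (4 channels)
    print(left_blend_mask(rgb_tile, 64).shape)     # torch.Size([1, 3, 256, 256])
    print(left_blend_mask(latent_tile, 64).shape)  # torch.Size([1, 4, 128, 128])

With the old .repeat(3, 1, 1), the 4-channel latent case would raise a shape-mismatch error.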

@@ -88,10 +88,40 @@ class AutoencoderKL(pl.LightningModule):
             self.model_ema(self)

     def encode(self, x):
-        h = self.encoder(x)
-        moments = self.quant_conv(h)
-        posterior = DiagonalGaussianDistribution(moments)
-        return posterior
+        return self.encode_sliced(x)
+        # h = self.encoder(x)
+        # moments = self.quant_conv(h)
+        # posterior = DiagonalGaussianDistribution(moments)
+        # return posterior.sample()
+
+    def encode_sliced(self, x, chunk_size=128 * 8):
+        """
+        encodes the image in slices.
+        """
+        b, c, h, w = x.size()
+        final_tensor = torch.zeros(
+            [1, 4, math.ceil(h / 8), math.ceil(w / 8)], device=x.device
+        )
+        for x_img in x.split(1):
+            encoded_chunks = []
+            overlap_pct = 0.5
+            chunks = tile_image(
+                x_img, tile_size=chunk_size, overlap_percent=overlap_pct
+            )
+            for img_chunk in chunks:
+                h = self.encoder(img_chunk)
+                moments = self.quant_conv(h)
+                posterior = DiagonalGaussianDistribution(moments)
+                encoded_chunks.append(posterior.sample())
+            final_tensor = rebuild_image(
+                encoded_chunks,
+                base_img=final_tensor,
+                tile_size=chunk_size // 8,
+                overlap_percent=overlap_pct,
+            )
+        return final_tensor

     def decode(self, z):
         try:
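encode_sliced cuts each image in the batch into overlapping chunk_size-pixel tiles (1024 px by default, 50% overlap), runs each tile through encoder and quant_conv, samples its DiagonalGaussianDistribution, and blends the resulting 4-channel latents back onto a canvas at 1/8 resolution via rebuild_image (hence tile_size=chunk_size // 8). Note that encode() now returns a sampled latent tensor instead of the posterior object, matching the commented-out return posterior.sample() line. A self-contained sketch of the shape bookkeeping, with a dummy per-tile encoder standing in for the real VAE (overlap blending omitted; the real handling lives in tile_image, rebuild_image, and mask_tile):

    import math
    import torch
    import torch.nn.functional as F

    def encode_sliced_sketch(x, encode_tile, chunk_size=128 * 8, overlap_pct=0.5):
        # Tile at pixel resolution, encode each tile to a 4-channel latent at
        # 1/8 scale, then paste the latents onto a full-size latent canvas.
        b, c, h, w = x.shape
        canvas = torch.zeros(1, 4, math.ceil(h / 8), math.ceil(w / 8))
        stride = max(1, int(chunk_size * (1 - overlap_pct)))
        for top in range(0, max(h - chunk_size, 0) + 1, stride):
            for left in range(0, max(w - chunk_size, 0) + 1, stride):
                tile = x[:, :, top:top + chunk_size, left:left + chunk_size]
                z = encode_tile(tile)  # (1, 4, chunk_size // 8, chunk_size // 8)
                canvas[:, :, top // 8:top // 8 + z.shape[2],
                       left // 8:left // 8 + z.shape[3]] = z
        return canvas

    # Dummy "encoder": 8x average pool, padded out to 4 channels (shapes only).
    fake_encode = lambda t: F.avg_pool2d(t, 8).repeat(1, 2, 1, 1)[:, :4]
    out = encode_sliced_sketch(torch.randn(1, 3, 2048, 2048), fake_encode)
    print(out.shape)  # torch.Size([1, 4, 256, 256]), i.e. (h / 8, w / 8)

Simply overwriting overlapping regions, as the sketch does, would leave visible seams; the real code linearly blends tiles across the 50% overlap using the masks built in mask_tile.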
