From c3a88c44cd93766e3f90fd3c2eb04016b6dfb128 Mon Sep 17 00:00:00 2001 From: Bryce Date: Thu, 16 Feb 2023 02:39:17 -0800 Subject: [PATCH] perf: tiled encoding of images The fold/unfold/split_input_params images didn't look good --- README.md | 1 + imaginairy/feather_tile.py | 10 ++++---- imaginairy/modules/autoencoder.py | 38 +++++++++++++++++++++++++++---- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 28e2b4e..b3a8b2f 100644 --- a/README.md +++ b/README.md @@ -300,6 +300,7 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface - ## ChangeLog +- perf: tiled encoding of images (removes memory bottleneck) - perf: use Silu for performance improvement over nonlinearity - perf: `xformers` added as a dependency for linux and windows. Gives a nice speed boost. - perf: sliced attention now runs on MacOS. A typo prevented that from happening previously. diff --git a/imaginairy/feather_tile.py b/imaginairy/feather_tile.py index 6ddb744..5f45a67 100644 --- a/imaginairy/feather_tile.py +++ b/imaginairy/feather_tile.py @@ -7,7 +7,7 @@ import torch def mask_tile(tile, overlap, std_overlap, side="bottom"): - h, w = tile.size(2), tile.size(3) + b, c, h, w = tile.shape top_overlap, bottom_overlap, right_overlap, left_overlap = overlap ( std_top_overlap, @@ -25,7 +25,7 @@ def mask_tile(tile, overlap, std_overlap, side="bottom"): lin_mask_left = ( torch.cat([zeros_mask, lin_mask_left], 0) .repeat(h, 1) - .repeat(3, 1, 1) + .repeat(c, 1, 1) .unsqueeze(0) ) @@ -33,7 +33,7 @@ def mask_tile(tile, overlap, std_overlap, side="bottom"): lin_mask_right = ( torch.linspace(1, 0, right_overlap, device=tile.device) .repeat(h, 1) - .repeat(3, 1, 1) + .repeat(c, 1, 1) .unsqueeze(0) ) if "top" in side: @@ -41,14 +41,14 @@ def mask_tile(tile, overlap, std_overlap, side="bottom"): if top_overlap > std_top_overlap: zeros_mask = torch.zeros(top_overlap - std_top_overlap, device=tile.device) lin_mask_top = torch.cat([zeros_mask, lin_mask_top], 0) - lin_mask_top = lin_mask_top.repeat(w, 1).rot90(3).repeat(3, 1, 1).unsqueeze(0) + lin_mask_top = lin_mask_top.repeat(w, 1).rot90(3).repeat(c, 1, 1).unsqueeze(0) if "bottom" in side: lin_mask_bottom = ( torch.linspace(1, 0, std_bottom_overlap, device=tile.device) .repeat(w, 1) .rot90(3) - .repeat(3, 1, 1) + .repeat(c, 1, 1) .unsqueeze(0) ) diff --git a/imaginairy/modules/autoencoder.py b/imaginairy/modules/autoencoder.py index 9d086cf..a33d8da 100644 --- a/imaginairy/modules/autoencoder.py +++ b/imaginairy/modules/autoencoder.py @@ -88,10 +88,40 @@ class AutoencoderKL(pl.LightningModule): self.model_ema(self) def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior + return self.encode_sliced(x) + # h = self.encoder(x) + # moments = self.quant_conv(h) + # posterior = DiagonalGaussianDistribution(moments) + # return posterior.sample() + + def encode_sliced(self, x, chunk_size=128 * 8): + """ + encodes the image in slices. + """ + b, c, h, w = x.size() + final_tensor = torch.zeros( + [1, 4, math.ceil(h / 8), math.ceil(w / 8)], device=x.device + ) + for x_img in x.split(1): + encoded_chunks = [] + overlap_pct = 0.5 + chunks = tile_image( + x_img, tile_size=chunk_size, overlap_percent=overlap_pct + ) + + for img_chunk in chunks: + h = self.encoder(img_chunk) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + encoded_chunks.append(posterior.sample()) + final_tensor = rebuild_image( + encoded_chunks, + base_img=final_tensor, + tile_size=chunk_size // 8, + overlap_percent=overlap_pct, + ) + + return final_tensor def decode(self, z): try: