feature: sliced image encoding for SD1Autoencoder

pull/408/head
Bryce 6 months ago committed by Bryce Drennan
parent b61d06651c
commit 1b15d6dcd4

@ -120,6 +120,34 @@ class SD1AutoencoderSliced(SD1Autoencoder):
max_chunk_size = 2048
min_chunk_size = 64
def encode(self, x: Tensor) -> Tensor:
return self.sliced_encode(x)
def sliced_encode(self, x: Tensor, chunk_size: int = 128 * 8) -> Tensor:
"""
Encodes the image in slices (for lower memory usage).
"""
b, c, h, w = x.size()
final_tensor = torch.zeros(
[1, 4, math.ceil(h / 8), math.ceil(w / 8)], device=x.device
)
overlap_pct = 0.5
for x_img in x.split(1):
chunks = tile_image(
x_img, tile_size=chunk_size, overlap_percent=overlap_pct
)
encoded_chunks = [super(SD1Autoencoder, self).encode(ic) for ic in chunks]
final_tensor = rebuild_image(
encoded_chunks,
base_img=final_tensor,
tile_size=chunk_size // 8,
overlap_percent=overlap_pct,
)
return final_tensor
def decode(self, x):
while self.__class__.max_chunk_size > self.__class__.min_chunk_size:
if self.max_chunk_size**2 > x.shape[2] * x.shape[3]:

Loading…
Cancel
Save