|
|
|
@ -120,6 +120,34 @@ class SD1AutoencoderSliced(SD1Autoencoder):
|
|
|
|
|
max_chunk_size = 2048
|
|
|
|
|
min_chunk_size = 64
|
|
|
|
|
|
|
|
|
|
def encode(self, x: Tensor) -> Tensor:
|
|
|
|
|
return self.sliced_encode(x)
|
|
|
|
|
|
|
|
|
|
def sliced_encode(self, x: Tensor, chunk_size: int = 128 * 8) -> Tensor:
|
|
|
|
|
"""
|
|
|
|
|
Encodes the image in slices (for lower memory usage).
|
|
|
|
|
"""
|
|
|
|
|
b, c, h, w = x.size()
|
|
|
|
|
final_tensor = torch.zeros(
|
|
|
|
|
[1, 4, math.ceil(h / 8), math.ceil(w / 8)], device=x.device
|
|
|
|
|
)
|
|
|
|
|
overlap_pct = 0.5
|
|
|
|
|
|
|
|
|
|
for x_img in x.split(1):
|
|
|
|
|
chunks = tile_image(
|
|
|
|
|
x_img, tile_size=chunk_size, overlap_percent=overlap_pct
|
|
|
|
|
)
|
|
|
|
|
encoded_chunks = [super(SD1Autoencoder, self).encode(ic) for ic in chunks]
|
|
|
|
|
|
|
|
|
|
final_tensor = rebuild_image(
|
|
|
|
|
encoded_chunks,
|
|
|
|
|
base_img=final_tensor,
|
|
|
|
|
tile_size=chunk_size // 8,
|
|
|
|
|
overlap_percent=overlap_pct,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return final_tensor
|
|
|
|
|
|
|
|
|
|
def decode(self, x):
|
|
|
|
|
while self.__class__.max_chunk_size > self.__class__.min_chunk_size:
|
|
|
|
|
if self.max_chunk_size**2 > x.shape[2] * x.shape[3]:
|
|
|
|
|