performance: memory management improvements (#210)

- tile mode made more efficient, especially when not being used
- add script to iteratively make bigger images
pull/215/head
Bryce Drennan 1 year ago committed by GitHub
parent a338902ab5
commit 7bdde559cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

3
.gitignore vendored

@ -26,4 +26,5 @@ tests/vastai_cli.py
*.pyprof
**/.polyscope.ini
**/imgui.ini
**/.eggs
**/.eggs
/img_size_memory_usage.csv

@ -283,9 +283,17 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog
**8.1.0**
- feature: some memory optimizations and documentation
- feature: surprise-me improvements
- feature: image sizes can now be multiples of 8 instead of 64. Inputs will be silently rounded down.
- feature: cleaned up `aimg` shell logs
- feature: auto-regen for unsafe images
- fix: make blip filename windows compatible
- fix: make captioning work with alpha pngs
**8.0.5**
- fix: bypass huggingface cache retrieval bug

@ -8,6 +8,7 @@ import torch.nn
from einops import rearrange, repeat
from PIL import Image, ImageDraw, ImageOps
from pytorch_lightning import seed_everything
from torch.cuda import OutOfMemoryError
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption
@ -440,9 +441,17 @@ def _generate_single_image(
)
# from torch.nn.functional import interpolate
# samples = interpolate(samples, scale_factor=2, mode='nearest')
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
with lc.timing("decoding"):
try:
x_samples = model.decode_first_stage(samples)
except OutOfMemoryError:
model.cond_stage_model.to("cpu")
model.model.to("cpu")
x_samples = model.decode_first_stage(samples)
model.cond_stage_model.to(get_device())
model.model.to(get_device())
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples:
x_sample = x_sample.to(torch.float32)

@ -714,12 +714,18 @@ class DDPM(pl.LightningModule):
def _TileModeConv2DConvForward(
self, input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor # noqa
):
working = F.pad(input, self.paddingX, mode=self.padding_modeX)
working = F.pad(working, self.paddingY, mode=self.padding_modeY)
if self.padding_modeX == self.padding_modeY:
return F.conv2d(
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
)
w1 = F.pad(input, self.paddingX, mode=self.padding_modeX)
del input
w2 = F.pad(w1, self.paddingY, mode=self.padding_modeY)
del w1
return F.conv2d(
working, weight, bias, self.stride, _pair(0), self.dilation, self.groups
)
return F.conv2d(w2, weight, bias, self.stride, _pair(0), self.dilation, self.groups)
class LatentDiffusion(DDPM):
@ -798,6 +804,8 @@ class LatentDiffusion(DDPM):
if isinstance(m, nn.Conv2d):
m.padding_modeX = "circular" if tile_x else "constant"
m.padding_modeY = "circular" if tile_y else "constant"
if m.padding_modeY == m.padding_modeX:
m.padding_mode = m.padding_modeX
m.paddingX = (
m._reversed_padding_repeated_twice[0], # noqa
m._reversed_padding_repeated_twice[1], # noqa

@ -0,0 +1,40 @@
import torch
from torch.cuda import OutOfMemoryError
from imaginairy import ImaginePrompt, imagine_image_files
from imaginairy.utils import get_device
def assess_memory_usage(
    *,
    start_size: int = 1664,
    step: int = 128,
    outdir: str = "outputs",
    csv_path: str = "img_size_memory_usage.csv",
) -> None:
    """Measure peak CUDA memory while generating progressively larger images.

    Square images are generated starting at ``start_size`` pixels per side,
    growing by ``step`` each round, until generation raises
    ``OutOfMemoryError``. The peak allocated CUDA memory (in GiB) observed
    for each successful size is printed and finally written to ``csv_path``
    as ``img_size,max_used`` rows.

    Args:
        start_size: Side length of the first image to generate.
        step: Amount added to the side length after each successful run.
        outdir: Directory passed to ``imagine_image_files`` for the images.
        csv_path: Where the collected measurements are written.

    Raises:
        RuntimeError: If the active device is not CUDA.
        ValueError: If ``step`` is not positive (the loop would never grow).
    """
    # An explicit raise (rather than ``assert``) keeps the check alive
    # when Python runs with -O.
    if get_device() != "cuda":
        raise RuntimeError("assess_memory_usage requires a CUDA device")
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")

    # One tiny generation before measuring starts.
    # NOTE(review): presumably this loads the model weights up front so the
    # per-size peaks below are not dominated by model loading -- confirm.
    warmup_prompt = ImaginePrompt("strawberries", width=64, height=64, seed=1)
    imagine_image_files([warmup_prompt], outdir=outdir)

    datalog = []
    img_size = start_size
    try:
        while True:
            # Restart peak tracking so each size gets its own measurement.
            torch.cuda.reset_peak_memory_stats()
            prompt = ImaginePrompt(
                "beautiful landscape, Unreal Engine 5, RTX, AAA Game, Detailed 3D Render, Cinema4D",
                width=img_size,
                height=img_size,
                seed=1,
            )
            try:
                imagine_image_files([prompt], outdir=outdir)
            except OutOfMemoryError as e:
                print(f"Out of memory at {img_size}x{img_size} size image.")
                print(e)
                break
            # max_memory_allocated() is divided by 1024**3 to report GiB.
            max_used = torch.cuda.max_memory_allocated() / 1024**3
            datalog.append((img_size, max_used))
            print(f"{img_size},{max_used:.2f}\n")
            img_size += step
    finally:
        # Written in ``finally`` so measurements already collected survive
        # an interruption (Ctrl-C, unexpected error), not only a clean OOM.
        with open(csv_path, "w", encoding="utf-8") as f:
            f.write("img_size,max_used\n")
            f.writelines(f"{size},{used:.2f}\n" for size, used in datalog)


if __name__ == "__main__":
    assess_memory_usage()
Loading…
Cancel
Save