fix: lower memory reqs for attention on M1s

pull/91/head
Bryce (committed by Bryce Drennan)
parent 64a0e0bb50
commit b8a88370de

@@ -1,5 +1,6 @@
 import math

+import psutil
 import torch
 import torch.nn.functional as F
 from einops import rearrange
@@ -122,6 +123,23 @@ class SpatialSelfAttention(nn.Module):
         return x + h_


+def get_mem_free_total(device):
+    device_type = "mps" if device.type == "mps" else "cuda"
+    if device_type == "cuda":
+        stats = torch.cuda.memory_stats(device)
+        mem_active = stats["active_bytes.all.current"]
+        mem_reserved = stats["reserved_bytes.all.current"]
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total *= 0.9
+    else:
+        # if we don't add a buffer, larger images come out as noise
+        mem_free_total = psutil.virtual_memory().available * 0.6
+    return mem_free_total
+
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
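On CUDA, the free-memory estimate above combines two pools: memory the driver reports as unallocated, plus memory PyTorch's caching allocator has reserved but is not actively using. A rough worked example of that arithmetic, with made-up numbers purely for illustration:

    mem_free_cuda = 6 * 1024**3   # driver reports 6 GB unallocated
    mem_reserved = 4 * 1024**3    # torch's caching allocator holds 4 GB
    mem_active = 3 * 1024**3      # 3 GB of that backs live tensors
    mem_free_torch = mem_reserved - mem_active                # 1 GB reusable cache
    mem_free_total = (mem_free_cuda + mem_free_torch) * 0.9   # 6.3 GB budget

On MPS, where torch.cuda.mem_get_info is unavailable, the helper instead budgets 60% of system RAM via psutil; per the comment, the larger buffer is what keeps big images from coming out as noise.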
@@ -145,8 +163,8 @@ class CrossAttention(nn.Module):
         # if mask is None and _global_mask_hack is not None:
         #     mask = _global_mask_hack.to(torch.bool)
-        if get_device() == "cuda":
-            return self.forward_cuda(x, context=context, mask=mask)
+        if get_device() == "cuda" or "mps" in get_device():
+            return self.forward_splitmem(x, context=context, mask=mask)

         h = self.heads
@@ -174,12 +192,12 @@ class CrossAttention(nn.Module):
         out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
         return self.to_out(out)

-    def forward_cuda(self, x, context=None, mask=None):  # noqa
+    def forward_splitmem(self, x, context=None, mask=None):  # noqa
         h = self.heads

         q_in = self.to_q(x)
         context = context if context is not None else x
-        k_in = self.to_k(context)
+        k_in = self.to_k(context) * self.scale
         v_in = self.to_v(context)
         del context, x
@@ -190,12 +208,7 @@ class CrossAttention(nn.Module):
         r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)

-        stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats["active_bytes.all.current"]
-        mem_reserved = stats["reserved_bytes.all.current"]
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total = get_mem_free_total(q.device)

         gb = 1024**3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
@@ -205,8 +218,6 @@ class CrossAttention(nn.Module):
         if mem_required > mem_free_total:
             steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
-            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
-            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

         if steps > 64:
             max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
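The slice count is rounded up to the next power of two, so when the query length is itself a power of two (e.g. 4096 tokens for a 64x64 latent) it divides evenly. A quick sketch of the rounding, with made-up numbers:

    import math

    mem_required, mem_free_total = 10.0, 1.0  # pretend we need 10x the free memory
    steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))
    assert steps == 16  # 2**3 = 8 < 10, so round up to 2**4 = 16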
@@ -218,7 +229,7 @@ class CrossAttention(nn.Module):
         slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
         for i in range(0, q.shape[1], slice_size):
             end = i + slice_size
-            s1 = einsum("b i d, b j d -> b i j", q[:, i:end], k) * self.scale
+            s1 = einsum("b i d, b j d -> b i j", q[:, i:end], k)
             s2 = s1.softmax(dim=-1, dtype=q.dtype)
             del s1
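Pulling the pieces together, here is a minimal self-contained sketch of the split-memory attention the diff implements; sliced_attention is a hypothetical free-standing name, and the steps > 64 bail-out above is omitted:

    import math

    import torch
    from torch import einsum

    def sliced_attention(q, k, v, scale, mem_free_total):
        # q, k, v: (batch*heads, tokens, dim_head), the layout forward_splitmem uses
        k = k * scale  # scale k once up front, matching to_k(context) * self.scale
        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
        # bytes for the full (b, i, j) attention matrix; the 2.5x fudge factor
        # leaves headroom for softmax temporaries, as in the diff
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        mem_required = tensor_size * 2.5
        steps = 1
        if mem_required > mem_free_total:
            steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))
        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            s1 = einsum("b i d, b j d -> b i j", q[:, i:end], k)
            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1
            r1[:, i:end] = einsum("b i j, b j d -> b i d", s2, v)
            del s2
        return r1

For example, a 64x64 latent with 8 heads gives q of shape (8, 4096, 64); with a 1 GB budget, the ~1.25 GB the scores would need is computed in two slices of 2048 query tokens each:

    q = k = v = torch.randn(8, 4096, 64)
    out = sliced_attention(q, k, v, scale=64**-0.5, mem_free_total=1024**3)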

@@ -12,11 +12,11 @@ addict==2.4.0
     # via basicsr
 aiohttp==3.8.3
     # via fsspec
-aiosignal==1.2.0
+aiosignal==1.3.1
     # via aiohttp
 antlr4-python3-runtime==4.8
     # via omegaconf
-astroid==2.12.11
+astroid==2.12.12
     # via pylint
 async-timeout==4.0.2
     # via aiohttp
@@ -46,18 +46,20 @@ click==8.1.3
     #   typer
 click-shell==2.1
     # via imaginAIry (setup.py)
-contourpy==1.0.5
+contourpy==1.0.6
     # via matplotlib
 coverage==6.5.0
     # via -r requirements-dev.in
 cycler==0.11.0
     # via matplotlib
-diffusers==0.5.1
+diffusers==0.7.2
     # via imaginAIry (setup.py)
-dill==0.3.5.1
+dill==0.3.6
     # via pylint
 einops==0.3.0
     # via imaginAIry (setup.py)
+exceptiongroup==1.0.1
+    # via pytest
 facexlib==0.2.5
     # via
     #   gfpgan
@@ -71,13 +73,13 @@ filelock==3.8.0
     #   transformers
 filterpy==1.4.5
     # via facexlib
-fonttools==4.37.4
+fonttools==4.38.0
     # via matplotlib
-frozenlist==1.3.1
+frozenlist==1.3.3
     # via
     #   aiohttp
     #   aiosignal
-fsspec[http]==2022.8.2
+fsspec[http]==2022.11.0
     # via pytorch-lightning
 ftfy==6.1.1
     # via imaginAIry (setup.py)
@@ -89,7 +91,7 @@ gfpgan==1.3.8
     # via
     #   imaginAIry (setup.py)
     #   realesrgan
-google-auth==2.12.0
+google-auth==2.14.1
     # via
     #   google-auth-oauthlib
     #   tb-nightly
@@ -98,7 +100,7 @@ google-auth-oauthlib==0.4.6
     # via
     #   tb-nightly
     #   tensorboard
-grpcio==1.49.1
+grpcio==1.50.0
     # via
     #   tb-nightly
     #   tensorboard
@@ -127,9 +129,9 @@ kiwisolver==1.4.4
     # via matplotlib
 kornia==0.6
     # via imaginAIry (setup.py)
-lazy-object-proxy==1.7.1
+lazy-object-proxy==1.8.0
     # via astroid
-libcst==0.4.7
+libcst==0.4.9
     # via pycln
 llvmlite==0.39.1
     # via numba
@@ -143,7 +145,7 @@ markdown==3.4.1
     #   tensorboard
 markupsafe==2.1.1
     # via werkzeug
-matplotlib==3.6.1
+matplotlib==3.6.2
     # via filterpy
 mccabe==0.7.0
     # via
@@ -157,9 +159,9 @@ mypy-extensions==0.4.3
     # via
     #   black
     #   typing-inspect
-networkx==2.8.7
+networkx==2.8.8
     # via scikit-image
-numba==0.56.3
+numba==0.56.4
     # via facexlib
 numpy==1.23.4
     # via
@@ -186,7 +188,7 @@ numpy==1.23.4
     #   torchmetrics
     #   torchvision
     #   transformers
-oauthlib==3.2.1
+oauthlib==3.2.2
     # via requests-oauthlib
 omegaconf==2.1.1
     # via imaginAIry (setup.py)
@@ -211,7 +213,7 @@ pathspec==0.9.0
     # via
     #   black
     #   pycln
-pillow==9.2.0
+pillow==9.3.0
     # via
     #   basicsr
     #   diffusers
@@ -222,26 +224,26 @@ pillow==9.2.0
     #   realesrgan
     #   scikit-image
     #   torchvision
-platformdirs==2.5.2
+platformdirs==2.5.4
     # via
     #   black
     #   pylint
 pluggy==1.0.0
     # via pytest
-protobuf==3.19.6
+protobuf==3.20.3
     # via
     #   imaginAIry (setup.py)
     #   tb-nightly
     #   tensorboard
-py==1.11.0
-    # via pytest
+psutil==5.9.4
+    # via imaginAIry (setup.py)
 pyasn1==0.4.8
     # via
     #   pyasn1-modules
     #   rsa
 pyasn1-modules==0.2.8
     # via google-auth
-pycln==2.1.1
+pycln==2.1.2
     # via -r requirements-dev.in
 pycodestyle==2.9.1
     # via pylama
@@ -255,20 +257,20 @@ pyflakes==2.5.0
     # via pylama
 pylama==8.4.1
     # via -r requirements-dev.in
-pylint==2.15.4
+pylint==2.15.5
     # via -r requirements-dev.in
 pyparsing==3.0.9
     # via
     #   matplotlib
     #   packaging
-pytest==7.1.3
+pytest==7.2.0
     # via
     #   -r requirements-dev.in
     #   pytest-randomly
     #   pytest-sugar
 pytest-randomly==3.12.0
     # via -r requirements-dev.in
-pytest-sugar==0.9.5
+pytest-sugar==0.9.6
     # via -r requirements-dev.in
 python-dateutil==2.8.2
     # via matplotlib
@@ -289,7 +291,7 @@ pyyaml==6.0
     #   transformers
 realesrgan==0.3.0
     # via imaginAIry (setup.py)
-regex==2022.9.13
+regex==2022.10.31
     # via
     #   diffusers
     #   transformers
@@ -313,7 +315,7 @@ rsa==4.9
     # via google-auth
 scikit-image==0.19.3
     # via basicsr
-scipy==1.9.2
+scipy==1.9.3
     # via
     #   basicsr
     #   facexlib
@@ -328,11 +330,11 @@ six==1.16.0
     #   python-dateutil
 snowballstemmer==2.2.0
     # via pydocstyle
-tb-nightly==2.11.0a20221016
+tb-nightly==2.12.0a20221113
     # via
     #   basicsr
     #   gfpgan
-tensorboard==2.10.1
+tensorboard==2.11.0
     # via pytorch-lightning
 tensorboard-data-server==0.6.1
     # via
@@ -342,7 +344,7 @@ tensorboard-plugin-wit==1.8.1
     # via
     #   tb-nightly
     #   tensorboard
-termcolor==2.0.1
+termcolor==2.1.0
     # via pytest-sugar
 tifffile==2022.10.10
     # via scikit-image
@@ -357,11 +359,11 @@ tomli==2.0.1
     #   black
     #   pylint
     #   pytest
-tomlkit==0.11.5
+tomlkit==0.11.6
     # via
     #   pycln
     #   pylint
-torch==1.12.1
+torch==1.13.0
     # via
     #   basicsr
     #   facexlib
@@ -381,7 +383,7 @@ torchmetrics==0.6.0
     # via
     #   imaginAIry (setup.py)
     #   pytorch-lightning
-torchvision==0.13.1
+torchvision==0.14.0
     # via
     #   basicsr
     #   facexlib
@@ -401,9 +403,9 @@ tqdm==4.64.1
     #   transformers
 transformers==4.19.2
     # via imaginAIry (setup.py)
-typer==0.6.1
+typer==0.7.0
     # via pycln
-types-toml==0.10.8
+types-toml==0.10.8.1
     # via responses
 typing-extensions==4.4.0
     # via
@@ -425,7 +427,7 @@ werkzeug==2.2.2
     # via
     #   tb-nightly
     #   tensorboard
-wheel==0.37.1
+wheel==0.38.4
     # via
     #   tb-nightly
     #   tensorboard
@@ -437,7 +439,7 @@ yapf==0.32.0
     #   gfpgan
 yarl==1.8.1
     # via aiohttp
-zipp==3.9.0
+zipp==3.10.0
     # via importlib-metadata

 # The following packages are considered to be unsafe in a requirements file:

@@ -45,6 +45,7 @@ setup(
         "diffusers",
         "imageio==2.9.0",
         "Pillow>=8.0.0",
+        "psutil",
         "pytorch-lightning==1.4.2",
         "omegaconf==2.1.1",
         "einops==0.3.0",
