mirror of https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00

fix: lower memory reqs for attention on M1s

parent 64a0e0bb50
commit b8a88370de
@@ -1,5 +1,6 @@
 import math

+import psutil
 import torch
 import torch.nn.functional as F
 from einops import rearrange
@@ -122,6 +123,23 @@ class SpatialSelfAttention(nn.Module):
         return x + h_


+def get_mem_free_total(device):
+    device_type = "mps" if device.type == "mps" else "cuda"
+    if device_type == "cuda":
+        stats = torch.cuda.memory_stats(device)
+        mem_active = stats["active_bytes.all.current"]
+        mem_reserved = stats["reserved_bytes.all.current"]
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total *= 0.9
+    else:
+        # if we don't add a buffer, larger images come out as noise
+        mem_free_total = psutil.virtual_memory().available * 0.6
+
+    return mem_free_total
+
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
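Note (illustrative, not part of the diff): a minimal sketch of exercising the new helper on its own. The import path below is an assumption for the sketch; the commit does not name the file here.

import torch

# Assumed import path; adjust if the helper lives elsewhere in the package.
from imaginairy.modules.attention import get_mem_free_total

# On Apple Silicon the device type is "mps", so the helper takes the psutil branch
# (60% of the RAM psutil reports as available). On CUDA it returns 90% of
# (driver-free + torch-reserved-but-inactive) memory.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps")
budget_bytes = get_mem_free_total(device)
print(f"attention memory budget: {budget_bytes / 1024**3:.2f} GB")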
@@ -145,8 +163,8 @@ class CrossAttention(nn.Module):
         # if mask is None and _global_mask_hack is not None:
         #     mask = _global_mask_hack.to(torch.bool)

-        if get_device() == "cuda":
-            return self.forward_cuda(x, context=context, mask=mask)
+        if get_device() == "cuda" or "mps" in get_device():
+            return self.forward_splitmem(x, context=context, mask=mask)

         h = self.heads

@@ -174,12 +192,12 @@ class CrossAttention(nn.Module):
         out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
         return self.to_out(out)

-    def forward_cuda(self, x, context=None, mask=None):  # noqa
+    def forward_splitmem(self, x, context=None, mask=None):  # noqa
         h = self.heads

         q_in = self.to_q(x)
         context = context if context is not None else x
-        k_in = self.to_k(context)
+        k_in = self.to_k(context) * self.scale
         v_in = self.to_v(context)
         del context, x

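Aside (illustrative, not from the commit): folding self.scale into the keys here is algebraically the same as multiplying the attention scores by the scale afterwards, because the einsum is linear in k. A small self-contained check with stand-in tensors:

import torch
from torch import einsum

# Stand-in shapes (batch*heads, tokens, dim_head); values are arbitrary examples.
q = torch.randn(2, 8, 64)
k = torch.randn(2, 8, 64)
scale = 64 ** -0.5

scores_from_scaled_keys = einsum("b i d, b j d -> b i j", q, k * scale)
scores_scaled_afterwards = einsum("b i d, b j d -> b i j", q, k) * scale
assert torch.allclose(scores_from_scaled_keys, scores_scaled_afterwards, atol=1e-6)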
@@ -190,12 +208,7 @@ class CrossAttention(nn.Module):

         r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)

-        stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats["active_bytes.all.current"]
-        mem_reserved = stats["reserved_bytes.all.current"]
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total = get_mem_free_total(q.device)

         gb = 1024**3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
@@ -205,8 +218,6 @@ class CrossAttention(nn.Module):

         if mem_required > mem_free_total:
             steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
-            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
-            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

         if steps > 64:
             max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
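For intuition (numbers are made up, not from the commit): steps is the smallest power of two that brings the per-slice attention-score tensor under the memory budget.

import math

# Hypothetical example: a 3.2 GB score tensor against a 1.0 GB budget.
mem_required = 3.2 * 1024**3
mem_free_total = 1.0 * 1024**3

steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))
print(steps)  # 4 -> the query rows are processed in 4 slices, each ~1/4 the memory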
@@ -218,7 +229,7 @@ class CrossAttention(nn.Module):
         slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
         for i in range(0, q.shape[1], slice_size):
             end = i + slice_size
-            s1 = einsum("b i d, b j d -> b i j", q[:, i:end], k) * self.scale
+            s1 = einsum("b i d, b j d -> b i j", q[:, i:end], k)

             s2 = s1.softmax(dim=-1, dtype=q.dtype)
             del s1
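Sketch (illustrative, not the module code): the memory saving in forward_splitmem comes from materializing the score matrix for only one block of query rows at a time, so its size is (n/steps) x n instead of n x n. A minimal stand-alone version of that loop, with hypothetical shapes and a fixed slice count, checked against full attention:

import torch
from torch import einsum

def sliced_attention(q, k, v, steps=4):
    """Chunk the query rows so the score matrix never exceeds (n/steps) x n."""
    scale = q.shape[-1] ** -0.5
    k = k * scale  # pre-scale the keys, as the diff does via k_in
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        scores = einsum("b i d, b j d -> b i j", q[:, i:end], k)
        attn = scores.softmax(dim=-1, dtype=q.dtype)
        out[:, i:end] = einsum("b i j, b j d -> b i d", attn, v)
    return out

# Example shapes: (batch*heads, tokens, dim_head)
q = torch.randn(2, 16, 64)
k = torch.randn(2, 16, 64)
v = torch.randn(2, 16, 64)
full = (einsum("b i d, b j d -> b i j", q, k) * q.shape[-1] ** -0.5).softmax(dim=-1)
reference = einsum("b i j, b j d -> b i d", full, v)
assert torch.allclose(sliced_attention(q, k, v), reference, atol=1e-5)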

@@ -12,11 +12,11 @@ addict==2.4.0
     # via basicsr
 aiohttp==3.8.3
     # via fsspec
-aiosignal==1.2.0
+aiosignal==1.3.1
     # via aiohttp
 antlr4-python3-runtime==4.8
     # via omegaconf
-astroid==2.12.11
+astroid==2.12.12
     # via pylint
 async-timeout==4.0.2
     # via aiohttp
@@ -46,18 +46,20 @@ click==8.1.3
     #   typer
 click-shell==2.1
     # via imaginAIry (setup.py)
-contourpy==1.0.5
+contourpy==1.0.6
     # via matplotlib
 coverage==6.5.0
     # via -r requirements-dev.in
 cycler==0.11.0
     # via matplotlib
-diffusers==0.5.1
+diffusers==0.7.2
     # via imaginAIry (setup.py)
-dill==0.3.5.1
+dill==0.3.6
     # via pylint
 einops==0.3.0
     # via imaginAIry (setup.py)
+exceptiongroup==1.0.1
+    # via pytest
 facexlib==0.2.5
     # via
     #   gfpgan
@@ -71,13 +73,13 @@ filelock==3.8.0
     #   transformers
 filterpy==1.4.5
     # via facexlib
-fonttools==4.37.4
+fonttools==4.38.0
     # via matplotlib
-frozenlist==1.3.1
+frozenlist==1.3.3
     # via
     #   aiohttp
     #   aiosignal
-fsspec[http]==2022.8.2
+fsspec[http]==2022.11.0
     # via pytorch-lightning
 ftfy==6.1.1
     # via imaginAIry (setup.py)
@@ -89,7 +91,7 @@ gfpgan==1.3.8
     # via
     #   imaginAIry (setup.py)
     #   realesrgan
-google-auth==2.12.0
+google-auth==2.14.1
     # via
     #   google-auth-oauthlib
     #   tb-nightly
@@ -98,7 +100,7 @@ google-auth-oauthlib==0.4.6
     # via
     #   tb-nightly
     #   tensorboard
-grpcio==1.49.1
+grpcio==1.50.0
     # via
     #   tb-nightly
     #   tensorboard
@@ -127,9 +129,9 @@ kiwisolver==1.4.4
     # via matplotlib
 kornia==0.6
     # via imaginAIry (setup.py)
-lazy-object-proxy==1.7.1
+lazy-object-proxy==1.8.0
     # via astroid
-libcst==0.4.7
+libcst==0.4.9
     # via pycln
 llvmlite==0.39.1
     # via numba
@@ -143,7 +145,7 @@ markdown==3.4.1
     #   tensorboard
 markupsafe==2.1.1
     # via werkzeug
-matplotlib==3.6.1
+matplotlib==3.6.2
     # via filterpy
 mccabe==0.7.0
     # via
@@ -157,9 +159,9 @@ mypy-extensions==0.4.3
     # via
     #   black
     #   typing-inspect
-networkx==2.8.7
+networkx==2.8.8
     # via scikit-image
-numba==0.56.3
+numba==0.56.4
     # via facexlib
 numpy==1.23.4
     # via
@@ -186,7 +188,7 @@ numpy==1.23.4
     #   torchmetrics
     #   torchvision
     #   transformers
-oauthlib==3.2.1
+oauthlib==3.2.2
     # via requests-oauthlib
 omegaconf==2.1.1
     # via imaginAIry (setup.py)
@@ -211,7 +213,7 @@ pathspec==0.9.0
     # via
     #   black
     #   pycln
-pillow==9.2.0
+pillow==9.3.0
     # via
     #   basicsr
     #   diffusers
@@ -222,26 +224,26 @@ pillow==9.2.0
     #   realesrgan
     #   scikit-image
     #   torchvision
-platformdirs==2.5.2
+platformdirs==2.5.4
     # via
     #   black
     #   pylint
 pluggy==1.0.0
     # via pytest
-protobuf==3.19.6
+protobuf==3.20.3
     # via
     #   imaginAIry (setup.py)
     #   tb-nightly
     #   tensorboard
-py==1.11.0
-    # via pytest
+psutil==5.9.4
+    # via imaginAIry (setup.py)
 pyasn1==0.4.8
     # via
     #   pyasn1-modules
     #   rsa
 pyasn1-modules==0.2.8
     # via google-auth
-pycln==2.1.1
+pycln==2.1.2
     # via -r requirements-dev.in
 pycodestyle==2.9.1
     # via pylama
@@ -255,20 +257,20 @@ pyflakes==2.5.0
     # via pylama
 pylama==8.4.1
     # via -r requirements-dev.in
-pylint==2.15.4
+pylint==2.15.5
     # via -r requirements-dev.in
 pyparsing==3.0.9
     # via
     #   matplotlib
     #   packaging
-pytest==7.1.3
+pytest==7.2.0
     # via
     #   -r requirements-dev.in
     #   pytest-randomly
     #   pytest-sugar
 pytest-randomly==3.12.0
     # via -r requirements-dev.in
-pytest-sugar==0.9.5
+pytest-sugar==0.9.6
     # via -r requirements-dev.in
 python-dateutil==2.8.2
     # via matplotlib
@@ -289,7 +291,7 @@ pyyaml==6.0
     #   transformers
 realesrgan==0.3.0
     # via imaginAIry (setup.py)
-regex==2022.9.13
+regex==2022.10.31
     # via
     #   diffusers
     #   transformers
@@ -313,7 +315,7 @@ rsa==4.9
     # via google-auth
 scikit-image==0.19.3
     # via basicsr
-scipy==1.9.2
+scipy==1.9.3
     # via
     #   basicsr
     #   facexlib
@@ -328,11 +330,11 @@ six==1.16.0
     #   python-dateutil
 snowballstemmer==2.2.0
     # via pydocstyle
-tb-nightly==2.11.0a20221016
+tb-nightly==2.12.0a20221113
     # via
     #   basicsr
     #   gfpgan
-tensorboard==2.10.1
+tensorboard==2.11.0
     # via pytorch-lightning
 tensorboard-data-server==0.6.1
     # via
@@ -342,7 +344,7 @@ tensorboard-plugin-wit==1.8.1
     # via
     #   tb-nightly
     #   tensorboard
-termcolor==2.0.1
+termcolor==2.1.0
     # via pytest-sugar
 tifffile==2022.10.10
     # via scikit-image
@@ -357,11 +359,11 @@ tomli==2.0.1
     #   black
     #   pylint
     #   pytest
-tomlkit==0.11.5
+tomlkit==0.11.6
     # via
     #   pycln
     #   pylint
-torch==1.12.1
+torch==1.13.0
     # via
     #   basicsr
     #   facexlib
@@ -381,7 +383,7 @@ torchmetrics==0.6.0
     # via
     #   imaginAIry (setup.py)
     #   pytorch-lightning
-torchvision==0.13.1
+torchvision==0.14.0
     # via
     #   basicsr
     #   facexlib
@@ -401,9 +403,9 @@ tqdm==4.64.1
     #   transformers
 transformers==4.19.2
     # via imaginAIry (setup.py)
-typer==0.6.1
+typer==0.7.0
     # via pycln
-types-toml==0.10.8
+types-toml==0.10.8.1
     # via responses
 typing-extensions==4.4.0
     # via
@@ -425,7 +427,7 @@ werkzeug==2.2.2
     # via
     #   tb-nightly
     #   tensorboard
-wheel==0.37.1
+wheel==0.38.4
     # via
     #   tb-nightly
     #   tensorboard
@@ -437,7 +439,7 @@ yapf==0.32.0
     #   gfpgan
 yarl==1.8.1
     # via aiohttp
-zipp==3.9.0
+zipp==3.10.0
     # via importlib-metadata

 # The following packages are considered to be unsafe in a requirements file: