@@ -1,5 +1,6 @@
 import math
 
+import psutil
 import torch
 import torch.nn.functional as F
 from einops import rearrange
@@ -122,6 +123,23 @@ class SpatialSelfAttention(nn.Module):
         return x + h_
 
 
+def get_mem_free_total(device):
+    device_type = "mps" if device.type == "mps" else "cuda"
+    if device_type == "cuda":
+        stats = torch.cuda.memory_stats(device)
+        mem_active = stats["active_bytes.all.current"]
+        mem_reserved = stats["reserved_bytes.all.current"]
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total *= 0.9
+    else:
+        # if we don't add a buffer, larger images come out as noise
+        mem_free_total = psutil.virtual_memory().available * 0.6
+
+    return mem_free_total
+
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
@@ -145,8 +163,8 @@ class CrossAttention(nn.Module):
         # if mask is None and _global_mask_hack is not None:
         #     mask = _global_mask_hack.to(torch.bool)
 
-        if get_device() == "cuda":
-            return self.forward_cuda(x, context=context, mask=mask)
+        if get_device() == "cuda" or "mps" in get_device():
+            return self.forward_splitmem(x, context=context, mask=mask)
 
         h = self.heads
 
@@ -174,12 +192,12 @@ class CrossAttention(nn.Module):
         out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
         return self.to_out(out)
 
-    def forward_cuda(self, x, context=None, mask=None):  # noqa
+    def forward_splitmem(self, x, context=None, mask=None):  # noqa
         h = self.heads
 
         q_in = self.to_q(x)
         context = context if context is not None else x
-        k_in = self.to_k(context)
+        k_in = self.to_k(context) * self.scale
         v_in = self.to_v(context)
         del context, x
 
@@ -190,12 +208,7 @@ class CrossAttention(nn.Module):
 
         r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
 
-        stats = torch.cuda.memory_stats(q.device)
-        mem_active = stats["active_bytes.all.current"]
-        mem_reserved = stats["reserved_bytes.all.current"]
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
-        mem_free_torch = mem_reserved - mem_active
-        mem_free_total = mem_free_cuda + mem_free_torch
+        mem_free_total = get_mem_free_total(q.device)
 
         gb = 1024**3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
@@ -205,8 +218,6 @@ class CrossAttention(nn.Module):
 
         if mem_required > mem_free_total:
             steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
-            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
-            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
 
         if steps > 64:
             max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
@@ -218,7 +229,7 @@ class CrossAttention(nn.Module):
         slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
         for i in range(0, q.shape[1], slice_size):
             end = i + slice_size
-            s1 = einsum("b i d, b j d -> b i j", q[:, i:end], k) * self.scale
+            s1 = einsum("b i d, b j d -> b i j", q[:, i:end], k)
 
             s2 = s1.softmax(dim=-1, dtype=q.dtype)
             del s1