|
|
|
@ -1,4 +1,5 @@
|
|
|
|
|
import math
|
|
|
|
|
from functools import lru_cache
|
|
|
|
|
|
|
|
|
|
import psutil
|
|
|
|
|
import torch
|
|
|
|
@ -151,6 +152,11 @@ def get_mem_free_total(device):
|
|
|
|
|
return mem_free_total
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=1)
|
|
|
|
|
def get_mps_gb_ram():
|
|
|
|
|
return psutil.virtual_memory().total / (1024**3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CrossAttention(nn.Module):
|
|
|
|
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
|
|
|
|
super().__init__()
|
|
|
|
@ -235,23 +241,34 @@ class CrossAttention(nn.Module):
|
|
|
|
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
|
|
|
|
modifier = 3 if q.element_size() == 2 else 2.5
|
|
|
|
|
mem_required = tensor_size * modifier
|
|
|
|
|
|
|
|
|
|
steps = 1
|
|
|
|
|
|
|
|
|
|
if mem_required > mem_free_total:
|
|
|
|
|
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
|
|
|
|
if "mps" in get_device():
|
|
|
|
|
# https://github.com/brycedrennan/imaginAIry/issues/175
|
|
|
|
|
# https://github.com/invoke-ai/InvokeAI/issues/1244
|
|
|
|
|
mps_gb = get_mps_gb_ram()
|
|
|
|
|
factor = 32 / mps_gb
|
|
|
|
|
|
|
|
|
|
if steps > 64:
|
|
|
|
|
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
f"Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). "
|
|
|
|
|
f"Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free"
|
|
|
|
|
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1] * 16 * factor))
|
|
|
|
|
else:
|
|
|
|
|
if mem_required > mem_free_total:
|
|
|
|
|
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
|
|
|
|
|
|
|
|
|
if steps > 64:
|
|
|
|
|
max_res = (
|
|
|
|
|
math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
|
|
|
|
)
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
f"Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). "
|
|
|
|
|
f"Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free"
|
|
|
|
|
)
|
|
|
|
|
slice_size = (
|
|
|
|
|
q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
|
|
|
|
if get_device() == "mps":
|
|
|
|
|
# https://github.com/brycedrennan/imaginAIry/issues/175
|
|
|
|
|
# https://github.com/invoke-ai/InvokeAI/issues/1244
|
|
|
|
|
slice_size = min(slice_size, 2**30)
|
|
|
|
|
# steps = len(range(0, q.shape[1], slice_size))
|
|
|
|
|
# print(f"Splitting attention into {steps} steps of {slice_size} slices")
|
|
|
|
|
|
|
|
|
|
for i in range(0, q.shape[1], slice_size):
|
|
|
|
|
end = i + slice_size
|
|
|
|
|