diff --git a/imaginairy/modules/attention.py b/imaginairy/modules/attention.py index 6d6221b..0f982b6 100644 --- a/imaginairy/modules/attention.py +++ b/imaginairy/modules/attention.py @@ -248,6 +248,11 @@ class CrossAttention(nn.Module): ) slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + if get_device() == "mps": + # https://github.com/brycedrennan/imaginAIry/issues/175 + # https://github.com/invoke-ai/InvokeAI/issues/1244 + slice_size = min(slice_size, 2**30) + for i in range(0, q.shape[1], slice_size): end = i + slice_size