@@ -10,13 +10,16 @@ from torch import einsum, nn
 from imaginairy.modules.diffusion.util import checkpoint as checkpoint_eval
 from imaginairy.utils import get_device
 
+XFORMERS_IS_AVAILABLE = False
+
 try:
     if get_device() == "cuda":
         import xformers  # noqa
         import xformers.ops  # noqa
 
-        XFORMERS_IS_AVAILBLE = True
+        XFORMERS_IS_AVAILABLE = True
 except ImportError:
-    XFORMERS_IS_AVAILBLE = False
+    pass
 
+
 ALLOW_SPLITMEM = True
@@ -181,7 +184,7 @@ class CrossAttention(nn.Module):
         # mask = _global_mask_hack.to(torch.bool)
 
         if get_device() == "cuda" or "mps" in get_device():
-            if not XFORMERS_IS_AVAILBLE and ALLOW_SPLITMEM:
+            if not XFORMERS_IS_AVAILABLE and ALLOW_SPLITMEM:
                 return self.forward_splitmem(x, context=context, mask=mask)
 
         h = self.heads
@@ -368,7 +371,7 @@ class BasicTransformerBlock(nn.Module):
         disable_self_attn=False,
     ):
         super().__init__()
-        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILABLE else "softmax"
         assert attn_mode in self.ATTENTION_MODES
         attn_cls = self.ATTENTION_MODES[attn_mode]
         self.disable_self_attn = disable_self_attn
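
The net effect of the first hunk is the usual optional-dependency guard, now spelled consistently as XFORMERS_IS_AVAILABLE. A minimal sketch of the resulting pattern, using only names that already appear in the patch (get_device comes from imaginairy.utils):

    from imaginairy.utils import get_device

    XFORMERS_IS_AVAILABLE = False

    try:
        if get_device() == "cuda":
            import xformers  # noqa
            import xformers.ops  # noqa

            XFORMERS_IS_AVAILABLE = True
    except ImportError:
        # xformers is optional; the attention code falls back to the split-memory path
        pass

Because the flag is initialized before the try block, it is defined (as False) even when get_device() is not "cuda" or the xformers import fails, and call sites such as the CrossAttention and BasicTransformerBlock hunks branch on the single, consistently spelled name.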