@@ -12,7 +12,6 @@ from functools import partial
 
 import torch
 import torch.nn as nn
-from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from timm.models.helpers import adapt_input_conv
 from timm.models.layers import DropPath, trunc_normal_
 from timm.models.vision_transformer import PatchEmbed
@@ -144,8 +143,7 @@ class Block(nn.Module):
         )
 
         if use_grad_checkpointing:
-            self.attn = checkpoint_wrapper(self.attn)
-            self.mlp = checkpoint_wrapper(self.mlp)
+            raise RuntimeError("not supported")
 
     def forward(self, x, register_hook=False):
         x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
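Note: after this change, constructing a Block with use_grad_checkpointing=True raises RuntimeError instead of wrapping self.attn and self.mlp. If activation checkpointing were still wanted without the fairscale dependency, a torch-native substitute along these lines could work (a sketch, not part of this diff; CheckpointWrapper is a hypothetical helper, and use_reentrant=False needs a reasonably recent PyTorch):

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointWrapper(nn.Module):
    """Recompute the wrapped module's forward pass during backward,
    trading compute for activation memory (hypothetical stand-in for
    fairscale's checkpoint_wrapper)."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        # The non-reentrant path (use_reentrant=False) also accepts
        # keyword arguments such as register_hook.
        return checkpoint(self.module, *args, use_reentrant=False, **kwargs)

With such a wrapper, the removed branch could instead read:

        if use_grad_checkpointing:
            self.attn = CheckpointWrapper(self.attn)
            self.mlp = CheckpointWrapper(self.mlp)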