|
|
|
@ -449,16 +449,19 @@ class SpatialTransformer(nn.Module):
|
|
|
|
|
b, c, h, w = x.shape # noqa
|
|
|
|
|
x_in = x
|
|
|
|
|
x = self.norm(x)
|
|
|
|
|
if not self.use_linear:
|
|
|
|
|
x = self.proj_in(x)
|
|
|
|
|
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
|
|
|
|
if self.use_linear:
|
|
|
|
|
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
|
|
|
|
x = self.proj_in(x)
|
|
|
|
|
for i, block in enumerate(self.transformer_blocks):
|
|
|
|
|
x = block(x, context=context[i])
|
|
|
|
|
if self.use_linear:
|
|
|
|
|
for i, block in enumerate(self.transformer_blocks):
|
|
|
|
|
x = block(x, context=context[i])
|
|
|
|
|
x = self.proj_out(x)
|
|
|
|
|
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
|
|
|
|
if not self.use_linear:
|
|
|
|
|
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
|
|
|
|
else:
|
|
|
|
|
x = self.proj_in(x)
|
|
|
|
|
x = rearrange(x, "b c h w -> b (h w) c")
|
|
|
|
|
for i, block in enumerate(self.transformer_blocks):
|
|
|
|
|
x = block(x, context=context[i])
|
|
|
|
|
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
|
|
|
|
x = self.proj_out(x)
|
|
|
|
|
|
|
|
|
|
return x + x_in
|
|
|
|
|