|
|
|
@ -49,7 +49,7 @@ class RemoteSequential(nn.Module):
|
|
|
|
|
logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
|
|
|
|
|
self.sequence_manager = sequence_manager
|
|
|
|
|
assert isinstance(sequence_manager.block_uids, list)
|
|
|
|
|
self.is_subsequence = self.sequence_manager.block_uids == block_uids
|
|
|
|
|
self.is_subsequence = self.sequence_manager.block_uids != block_uids
|
|
|
|
|
|
|
|
|
|
def forward(self, inputs: torch.Tensor):
|
|
|
|
|
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
|
|
|
|
|