From 8491ed2bd30d89edca0cac06a86f28fe67475ea0 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 3 Dec 2022 14:02:48 +0300 Subject: [PATCH] Add checks for forward() inputs on the client side (#123) --- src/petals/client/remote_sequential.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index cb53979..aee8d67 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -53,6 +53,8 @@ class RemoteSequential(nn.Module): self.is_subsequence = self.sequence_manager.sequence_info.block_uids != block_uids def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY): + assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]" + assert inputs.shape[1] <= 2048, "The sequence length is capped at 2048 tokens in this version" outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager) return outputs