diff --git a/src/petals/utils/generation_algorithms.py b/src/petals/utils/generation_algorithms.py index d58f073..d085e8b 100644 --- a/src/petals/utils/generation_algorithms.py +++ b/src/petals/utils/generation_algorithms.py @@ -16,7 +16,7 @@ class DecodingAlgorithm(ABC): @abstractmethod def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: """ - :param logits: A tensor of shape (batch_size, seq_lenth, vocab_size) + :param logits: A tensor of shape (batch_size, seq_length, vocab_size) :return: A tuple of selected token ids and corresponding hypotheses. The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size) """