|
|
@ -16,7 +16,7 @@ class DecodingAlgorithm(ABC):
|
|
|
|
@abstractmethod
|
|
|
|
@abstractmethod
|
|
|
|
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
|
|
|
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
:param logits: A tensor of shape (batch_size, seq_lenth, vocab_size)
|
|
|
|
:param logits: A tensor of shape (batch_size, seq_length, vocab_size)
|
|
|
|
:return: A tuple of selected token ids and corresponding hypotheses.
|
|
|
|
:return: A tuple of selected token ids and corresponding hypotheses.
|
|
|
|
The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
|
|
|
|
The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|