|
|
|
@ -83,41 +83,27 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
|
|
|
|
|
self._batch_beams = torch.zeros((batch_size, num_beams))
|
|
|
|
|
|
|
|
|
|
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
    """Advance beam search by one decoding step.

    Keeps the ``num_beams`` highest-scoring continuations per batch element,
    where a continuation's score is the running beam score plus the
    log-probability of the newly chosen token.

    Args:
        logits: unnormalized next-token scores, one row per live hypothesis.
            On the first call (``self._cur_num_beams == 1``) only the first
            ``batch_size`` rows are used; afterwards the tensor must hold
            ``batch_size * num_beams`` rows.
            NOTE(review): rows are assumed to be ordered beam-major
            (``row = beam * batch_size + batch``), as implied by the
            ``batch + beam * batch_size`` hypothesis index -- confirm
            against the caller.

    Returns:
        A ``(token_ids, hypo_ids)`` pair. ``token_ids`` has shape
        ``(batch_size * num_beams, 1)`` and holds the chosen vocabulary
        indices; ``hypo_ids`` has shape ``(batch_size * num_beams,)`` and
        holds, for every output row, the input row it continues. Both are
        emitted in the same beam-major order that is expected on input.

    Side effects:
        Overwrites ``self._batch_beams`` with the new ``(batch_size,
        num_beams)`` cumulative log-probabilities and sets
        ``self._cur_num_beams`` to ``num_beams`` after the first step.
    """
    num_beams, batch_size = self.num_beams, self.batch_size

    # Only the best `num_beams` tokens of a row can survive, so a top-k over
    # the log-probabilities replaces a full sort of the vocabulary.
    log_probs = torch.log_softmax(logits, dim=-1)
    probs, topk_indices = torch.topk(log_probs, k=num_beams, dim=-1)
    device = probs.device

    if self._cur_num_beams > 1:
        # (beam * batch, k) -> (batch, beam, k): regroup rows per batch element.
        probs = probs.reshape(num_beams, batch_size, num_beams).transpose(0, 1)
        # Score every (source beam, token rank) pair, then keep the best ones.
        candidates = (self._batch_beams[:, :, None] + probs).reshape(batch_size, -1)
        self._batch_beams, hypo_ids = torch.topk(candidates, k=num_beams, dim=-1)
    else:
        # First step: a single hypothesis per batch element fans out into
        # `num_beams` beams, all of which continue that one hypothesis.
        self._batch_beams = probs[:batch_size, :num_beams]
        self._cur_num_beams = num_beams
        hypo_ids = torch.tile(
            torch.arange(num_beams, device=device),
            (batch_size, 1),
        )

    # A flat candidate id encodes `source_beam * num_beams + token_rank`.
    source_beams = torch.div(hypo_ids, num_beams, rounding_mode="floor")
    token_ranks = hypo_ids % num_beams

    # Input row of each surviving hypothesis (beam-major addressing).
    batch_offsets = torch.arange(batch_size, device=device)[:, None]
    hypo_rows = batch_offsets + source_beams * batch_size
    tokens = topk_indices[hypo_rows, token_ranks]

    # (batch, beam) -> flat beam-major, so the output layout matches the input.
    return_tokens = tokens.t().reshape(-1, 1)
    return_hypos = hypo_rows.t().reshape(-1)

    return return_tokens, return_hypos
|
|
|
|
|