Add vectorized version of beam_search

pull/109/head
Artem Chumachenko 2 years ago
parent 0a1cd3b9ba
commit c242232c52

@@ -80,42 +80,44 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
         self._cur_num_beams = 1
         self.batch_size = batch_size
 
-        self._batch_beams = [list() for _ in range(batch_size)]
+        self._batch_beams = torch.zeros((batch_size, num_beams))
 
-    def __call__(self, logits: torch.Tensor):
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         probs = torch.log_softmax(sorted_logits, -1)
 
-        if len(self._batch_beams[0]) > 0:
-            for batch_idx in range(self.batch_size):
-                new_beams = []
-                cur_beams = self._batch_beams[batch_idx]
-                for beam_idx in range(len(cur_beams)):
-                    probs_idx = batch_idx + beam_idx * self.batch_size
-                    new_beam = cur_beams[beam_idx]
-                    for hypo_idx in range(self.num_beams):
-                        new_beams.append(
-                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
-                        )
-                self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
+        hypo_ids = None
+        if self._cur_num_beams > 1:
+            permuted_indexes = torch.cat(
+                [torch.arange(0, self.num_beams) * self.batch_size + i for i in range(self.batch_size)], dim=0
+            )
+            probs = probs[:, : self.num_beams][permuted_indexes]
+            probs = probs.view(self.batch_size, self.num_beams, self.num_beams)
+            self._batch_beams = self._batch_beams[:, :, None] + probs
+            self._batch_beams = self._batch_beams.view(self.batch_size, -1)
+            sorted_batch_beams, sorted_hypo_ids = torch.sort(self._batch_beams, descending=True, dim=-1)
+            self._batch_beams = sorted_batch_beams[:, : self.num_beams]
+            hypo_ids = sorted_hypo_ids[:, : self.num_beams]
         else:
-            for batch_idx in range(self.batch_size):
-                for beam_idx in range(self.num_beams):
-                    self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))
+            self._batch_beams = probs[: self.batch_size, : self.num_beams]
+            self._cur_num_beams = self.num_beams
+            hypo_ids = torch.tile(
+                torch.arange(self.num_beams),
+                (self.batch_size, 1),
+            )
 
         return_hypos = []
         return_tokens = []
         for batch_idx in range(self.batch_size):
-            cur_beam = self._batch_beams[batch_idx]
-            return_hypos.append(list())
-            return_tokens.append(list())
-            for beam in cur_beam:
-                beam_idx = beam[1] // self.num_beams
-                hypo_idx = batch_idx + beam_idx * self.batch_size
-                token_idx = beam[1] % self.num_beams
-                return_hypos[-1].append(hypo_idx)
-                return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
-        return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
-        return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]
-
-        return torch.tensor(return_tokens), torch.tensor(return_hypos)
+            cur_beam = hypo_ids[batch_idx]
+            hypo_idx = batch_idx + torch.floor_divide(cur_beam, self.num_beams) * self.batch_size
+            return_hypos.append(hypo_idx)
+            return_tokens.append(sorted_indices[hypo_idx, cur_beam % self.num_beams].unsqueeze(-1))
+        return_indexes = torch.cat(
+            [torch.arange(0, self.batch_size) * self.num_beams + i for i in range(self.num_beams)], dim=0
+        )
+        return_tokens = torch.cat(return_tokens, 0)
+        return_hypos = torch.cat(return_hypos, 0)
+        return return_tokens[return_indexes], return_hypos[return_indexes]
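
The diff swaps per-example Python loops for batched tensor arithmetic: cumulative beam scores are broadcast-added to the per-beam token log-probs, the resulting (beam, token) candidate grid is flattened so one sort per batch item ranks all continuations at once, and integer division / modulo on the surviving flat indices recover each winner's parent hypothesis and chosen token. Below is a minimal sketch of that trick, assuming a simpler (batch, num_beams, vocab) layout rather than the interleaved (num_beams * batch, vocab) rows the commit actually works with; the names beam_step, beam_scores, and step_logprobs, and the use of topk in place of a full sort, are illustrative and not from the commit.

    import torch

    def beam_step(beam_scores: torch.Tensor, step_logprobs: torch.Tensor):
        """One vectorized beam-search step (illustrative, not the commit's code).

        beam_scores:   (batch, num_beams)         cumulative log-probs per beam
        step_logprobs: (batch, num_beams, vocab)  log-probs for the next token
        Returns (new_scores, parent_ids, token_ids), each (batch, num_beams).
        """
        batch, num_beams, vocab = step_logprobs.shape
        # Broadcast each beam's running score over its candidate continuations,
        # then flatten the (beam, token) grid so a single top-k ranks them all.
        candidates = (beam_scores[:, :, None] + step_logprobs).view(batch, -1)
        new_scores, flat_ids = torch.topk(candidates, num_beams, dim=-1)
        # div/mod on the flat index recover which beam a winner extends and
        # which token it appends -- the same arithmetic the commit performs
        # with torch.floor_divide and `% self.num_beams`.
        parent_ids = torch.div(flat_ids, vocab, rounding_mode="floor")
        token_ids = flat_ids % vocab
        return new_scores, parent_ids, token_ids

    # Toy run: batch of 2, 3 beams, vocabulary of 5.
    torch.manual_seed(0)
    scores = torch.zeros(2, 3)
    logprobs = torch.log_softmax(torch.randn(2, 3, 5), dim=-1)
    scores, parents, tokens = beam_step(scores, logprobs)
    print(parents)  # which old beam each new hypothesis extends
    print(tokens)   # the token each new hypothesis appends

Two details of the real code that the sketch glosses over: it pre-truncates each beam's candidates to that beam's top num_beams tokens (probs[:, : self.num_beams]), which is lossless because at most num_beams winners survive overall, so no lower-ranked token can make the cut; and it special-cases the first step (when self._cur_num_beams is still 1), where all beams are identical copies of the prompt and ranking them jointly would just pick the same token num_beams times.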
