style fixes

This commit is contained in:
Artem Chumachenko 2022-11-06 18:43:52 +04:00
parent eb1334e567
commit ce22b6a47b
2 changed files with 12 additions and 7 deletions

View File

@ -8,7 +8,7 @@ from src.utils.generation_algorithms import (
DecodingAlgorithm,
GreedyAlgorithm,
NucleusAlgorithm,
TopKAlgorithm
TopKAlgorithm,
)
from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint

View File

@ -80,12 +80,17 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
self._cur_num_beams = 1
self.batch_size = batch_size
self._logits = torch.zeros((self.batch_size, self._cur_num_beams,))
self._logits = torch.zeros(
(
self.batch_size,
self._cur_num_beams,
)
)
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
probs = torch.softmax(sorted_logits, -1)
new_logits = torch.cat([self._logits] * self.num_beams, dim=-1)
for batch_idx in range(self.batch_size):
for cur_beam_idx in range(self._cur_num_beams):
@ -95,9 +100,9 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
self._cur_num_beams = self.num_beams
new_sorted_logits, new_sorted_indices = torch.sort(new_logits, descending=True, dim=-1)
new_sorted_indices = new_sorted_indices[:, :self.num_beams].T.flatten()
self._logits = new_sorted_logits[:, :self.num_beams]
new_sorted_indices = new_sorted_indices[:, : self.num_beams].T.flatten()
self._logits = new_sorted_logits[:, : self.num_beams]
result_tokens = sorted_indices[torch.arange(self.num_beams * self.batch_size), new_sorted_indices]
result_hypos = torch.div(new_sorted_indices, self.num_beams, rounding_mode='floor')
result_hypos = torch.div(new_sorted_indices, self.num_beams, rounding_mode="floor")
return result_tokens.unsqueeze(-1), result_hypos