mirror of
https://github.com/bigscience-workshop/petals
synced 2024-10-31 09:20:41 +00:00
style fixes
This commit is contained in:
parent
eb1334e567
commit
ce22b6a47b
@ -8,7 +8,7 @@ from src.utils.generation_algorithms import (
|
||||
DecodingAlgorithm,
|
||||
GreedyAlgorithm,
|
||||
NucleusAlgorithm,
|
||||
TopKAlgorithm
|
||||
TopKAlgorithm,
|
||||
)
|
||||
from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint
|
||||
|
||||
|
@ -80,12 +80,17 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
|
||||
self._cur_num_beams = 1
|
||||
self.batch_size = batch_size
|
||||
|
||||
self._logits = torch.zeros((self.batch_size, self._cur_num_beams,))
|
||||
|
||||
self._logits = torch.zeros(
|
||||
(
|
||||
self.batch_size,
|
||||
self._cur_num_beams,
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
probs = torch.softmax(sorted_logits, -1)
|
||||
|
||||
|
||||
new_logits = torch.cat([self._logits] * self.num_beams, dim=-1)
|
||||
for batch_idx in range(self.batch_size):
|
||||
for cur_beam_idx in range(self._cur_num_beams):
|
||||
@ -95,9 +100,9 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
|
||||
self._cur_num_beams = self.num_beams
|
||||
|
||||
new_sorted_logits, new_sorted_indices = torch.sort(new_logits, descending=True, dim=-1)
|
||||
new_sorted_indices = new_sorted_indices[:, :self.num_beams].T.flatten()
|
||||
self._logits = new_sorted_logits[:, :self.num_beams]
|
||||
new_sorted_indices = new_sorted_indices[:, : self.num_beams].T.flatten()
|
||||
self._logits = new_sorted_logits[:, : self.num_beams]
|
||||
result_tokens = sorted_indices[torch.arange(self.num_beams * self.batch_size), new_sorted_indices]
|
||||
result_hypos = torch.div(new_sorted_indices, self.num_beams, rounding_mode='floor')
|
||||
result_hypos = torch.div(new_sorted_indices, self.num_beams, rounding_mode="floor")
|
||||
|
||||
return result_tokens.unsqueeze(-1), result_hypos
|
||||
|
Loading…
Reference in New Issue
Block a user