diff --git a/src/client/remote_generation.py b/src/client/remote_generation.py
index 21d2fa8..a7cbb2d 100644
--- a/src/client/remote_generation.py
+++ b/src/client/remote_generation.py
@@ -140,11 +140,13 @@ class RemoteGenerationMixin:
                 :, seq_idx : seq_idx + 1
             ] + pad_token_mask * last_token_id

-            if torch.all(last_token_id == eos_token_id) or len(outputs) >= max_new_tokens:
-                break
+            if num_beams > 1:
+                outputs[-1] = outputs[-1][hypo_ids]

             outputs.append(last_token_id)
             seq_idx += 1
+            if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
+                break

         return torch.cat(outputs, dim=-1)

diff --git a/src/utils/generation_algorithms.py b/src/utils/generation_algorithms.py
index c6e4096..a0cd2de 100644
--- a/src/utils/generation_algorithms.py
+++ b/src/utils/generation_algorithms.py
@@ -80,29 +80,33 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
         self._cur_num_beams = 1
         self.batch_size = batch_size

-        self._logits = torch.zeros(
-            (
-                self.batch_size,
-                self._cur_num_beams,
-            )
-        )
+        self._beams = []

     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-        probs = torch.softmax(sorted_logits, -1)
+        probs = torch.log_softmax(sorted_logits, -1)

-        new_logits = torch.cat([self._logits] * self.num_beams, dim=-1)
+        if len(self._beams) > 0:
+            new_beams = []
+            for batch_idx in range(self.batch_size):
+                for beam_idx in range(self.num_beams):
+                    new_beam = self._beams[beam_idx]
+                    for hypo_idx in range(self.num_beams):
+                        probs_idx = batch_idx + beam_idx * self.batch_size
+                        new_beams.append((beam_idx, new_beam[1] + probs[probs_idx, hypo_idx].item()))
+            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
+            self._beams = new_beams[: self.batch_size * self.num_beams]
+        else:
+            for batch_idx in range(self.batch_size):
+                for beam_idx in range(self.num_beams):
+                    self._beams.append((beam_idx, probs[batch_idx, beam_idx].item()))
+
+        return_hypos = []
+        return_tokens = []
         for batch_idx in range(self.batch_size):
-            for cur_beam_idx in range(self._cur_num_beams):
-                for new_beam_idx in range(self.num_beams):
-                    logit = probs[cur_beam_idx * self.batch_size + batch_idx, new_beam_idx]
-                    new_logits[batch_idx, cur_beam_idx * self.num_beams + new_beam_idx] += logit
-        self._cur_num_beams = self.num_beams
+            for beam_idx in range(self.num_beams):
+                hypo_idx = batch_idx + beam_idx * self.batch_size
+                return_hypos.append(self._beams[hypo_idx][0])
+                return_tokens.append([sorted_indices[batch_idx, beam_idx].item()])

-        new_sorted_logits, new_sorted_indices = torch.sort(new_logits, descending=True, dim=-1)
-        new_sorted_indices = new_sorted_indices[:, : self.num_beams].T.flatten()
-        self._logits = new_sorted_logits[:, : self.num_beams]
-        result_tokens = sorted_indices[torch.arange(self.num_beams * self.batch_size), new_sorted_indices]
-        result_hypos = torch.div(new_sorted_indices, self.num_beams, rounding_mode="floor")
-
-        return result_tokens.unsqueeze(-1), result_hypos
+        return torch.tensor(return_tokens), torch.tensor(return_hypos)
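Note on the BeamSearchAlgorithm rewrite above: the class now tracks self._beams, a flat
list of (parent hypothesis index, cumulative log-probability) pairs, and moves from
softmax to log_softmax so that per-step scores can be summed instead of multiplied.
The sketch below isolates that expand-score-prune bookkeeping for the single-prompt
case (batch_size == 1); fake_logits and the surrounding loop are invented stand-ins
for illustration, not part of the patch.

    import torch

    num_beams = 2
    beams = []  # each entry: (parent hypothesis index, cumulative log-probability)

    for step in range(3):
        # One logits row per live hypothesis; a single row seeds the first step.
        fake_logits = torch.randn(len(beams) or 1, 10)
        log_probs = torch.log_softmax(fake_logits, dim=-1)
        top_scores, top_tokens = log_probs.topk(num_beams, dim=-1)
        if beams:
            # Expand each live beam with its top-k continuations, then keep
            # the best `num_beams` candidates overall by cumulative score.
            candidates = [
                (b, beams[b][1] + top_scores[b, k].item())
                for b in range(num_beams)
                for k in range(num_beams)
            ]
            beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
        else:
            # First step: all beams descend from the single prompt row.
            beams = [(0, top_scores[0, k].item()) for k in range(num_beams)]
        print(step, beams)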
diff --git a/tests/test_full_model.py b/tests/test_full_model.py
index b0ce824..c74324a 100644
--- a/tests/test_full_model.py
+++ b/tests/test_full_model.py
@@ -3,6 +3,7 @@ import torch
 import transformers
 from hivemind import get_logger, use_hivemind_log_handler
 from test_utils import *
+from transformers.generation_utils import BeamSearchScorer

 from src.bloom.model import BloomForCausalLM
 from src.client.remote_model import DistributedBloomForCausalLM
@@ -89,3 +90,29 @@ def test_greedy_generation(max_new_tokens=4):
     assert torch.allclose(
         remote_outputs_batch, hf_outputs_batch
     ), "Greedy search are not identical to HF in multibatch mode"
+
+
+@pytest.mark.forked
+def test_beam_search_generation(max_new_tokens=4, num_beams=2):
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+    remote_outputs = model.generate(
+        inputs,
+        max_new_tokens=max_new_tokens,
+        num_beams=num_beams,
+    )
+    beam_scorer = BeamSearchScorer(
+        batch_size=inputs.size(0),
+        num_beams=num_beams,
+        device=inputs.device,
+        length_penalty=0,
+        do_early_stopping=False,
+        num_beam_hyps_to_keep=2,
+    )
+    hf_outputs = BloomForCausalLM.beam_search(
+        model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
+    )
+    assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"
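Aside on the outputs[-1] = outputs[-1][hypo_ids] line in the remote_generation.py hunk:
when num_beams > 1, a decoding step may reorder the surviving hypotheses, so the token
column emitted at the previous step has to be gathered to follow its new parent beams.
A toy illustration with invented values:

    import torch

    prev_tokens = torch.tensor([[11], [22]])  # last tokens emitted for beams 0 and 1
    hypo_ids = torch.tensor([1, 0])           # this step swapped the beam ranking
    prev_tokens = prev_tokens[hypo_ids]       # rows now follow their parent hypotheses
    print(prev_tokens)                        # tensor([[22], [11]])

In the new test, length_penalty=0 and do_early_stopping=False keep HF's BeamSearchScorer
ranking hypotheses by summed log-probability alone, which matches what the simplified
client-side algorithm computes.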