Make beam_search identical

Artem Chumachenko 2022-11-13 20:06:45 +04:00
parent ce22b6a47b
commit 9bde866eb3
3 changed files with 55 additions and 22 deletions


@@ -140,11 +140,13 @@ class RemoteGenerationMixin:
                 :, seq_idx : seq_idx + 1
             ] + pad_token_mask * last_token_id
-            if torch.all(last_token_id == eos_token_id) or len(outputs) >= max_new_tokens:
-                break
+            if num_beams > 1:
+                outputs[-1] = outputs[-1][hypo_ids]
             outputs.append(last_token_id)
             seq_idx += 1
+            if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
+                break
         return torch.cat(outputs, dim=-1)
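Note on the hypo_ids reordering above: with num_beams > 1, each generation step may continue a different parent hypothesis, so the token column appended at the previous step has to be permuted by hypo_ids before the new column is appended. A minimal standalone sketch of that bookkeeping (toy tensors, not the project's API):

import torch

# Two beams produced tokens [[5], [9]] at the previous step; the current
# step reports that both surviving hypotheses descend from beam 1.
outputs = [torch.tensor([[5], [9]])]
hypo_ids = torch.tensor([1, 1])           # parent beam of each survivor
last_token_id = torch.tensor([[3], [7]])  # freshly chosen tokens

outputs[-1] = outputs[-1][hypo_ids]   # reorder the past column -> [[9], [9]]
outputs.append(last_token_id)
print(torch.cat(outputs, dim=-1))     # tensor([[9, 3], [9, 7]])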


@@ -80,29 +80,33 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
-        self._cur_num_beams = 1
         self.batch_size = batch_size
-        self._logits = torch.zeros(
-            (
-                self.batch_size,
-                self._cur_num_beams,
-            )
-        )
+        self._beams = []

     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-        probs = torch.softmax(sorted_logits, -1)
+        probs = torch.log_softmax(sorted_logits, -1)
-        new_logits = torch.cat([self._logits] * self.num_beams, dim=-1)
+        if len(self._beams) > 0:
+            new_beams = []
+            for batch_idx in range(self.batch_size):
+                for beam_idx in range(self.num_beams):
+                    new_beam = self._beams[beam_idx]
+                    for hypo_idx in range(self.num_beams):
+                        probs_idx = batch_idx + beam_idx * self.batch_size
+                        new_beams.append((beam_idx, new_beam[1] + probs[probs_idx, hypo_idx].item()))
+            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
+            self._beams = new_beams[: self.batch_size * self.num_beams]
+        else:
+            for batch_idx in range(self.batch_size):
+                for beam_idx in range(self.num_beams):
+                    self._beams.append((beam_idx, probs[batch_idx, beam_idx].item()))

+        return_hypos = []
+        return_tokens = []
         for batch_idx in range(self.batch_size):
-            for cur_beam_idx in range(self._cur_num_beams):
-                for new_beam_idx in range(self.num_beams):
-                    logit = probs[cur_beam_idx * self.batch_size + batch_idx, new_beam_idx]
-                    new_logits[batch_idx, cur_beam_idx * self.num_beams + new_beam_idx] += logit
-        self._cur_num_beams = self.num_beams
+            for beam_idx in range(self.num_beams):
+                hypo_idx = batch_idx + beam_idx * self.batch_size
+                return_hypos.append(self._beams[hypo_idx][0])
+                return_tokens.append([sorted_indices[batch_idx, beam_idx].item()])

-        new_sorted_logits, new_sorted_indices = torch.sort(new_logits, descending=True, dim=-1)
-        new_sorted_indices = new_sorted_indices[:, : self.num_beams].T.flatten()
-        self._logits = new_sorted_logits[:, : self.num_beams]
-        result_tokens = sorted_indices[torch.arange(self.num_beams * self.batch_size), new_sorted_indices]
-        result_hypos = torch.div(new_sorted_indices, self.num_beams, rounding_mode="floor")
-        return result_tokens.unsqueeze(-1), result_hypos
+        return torch.tensor(return_tokens), torch.tensor(return_hypos)
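Note on the softmax -> log_softmax change above: log-probabilities make beam scores additive, so a hypothesis's cumulative score is a running sum over steps, which is equivalent to multiplying raw probabilities but does not underflow. A minimal check of that equivalence (toy logits, independent of the class above):

import torch

step1 = torch.log_softmax(torch.tensor([2.0, 1.0, 0.5]), dim=-1)
step2 = torch.log_softmax(torch.tensor([0.3, 1.5, 0.2]), dim=-1)

score_log = step1[0] + step2[1]                       # additive in log space
score_prob = (step1[0].exp() * step2[1].exp()).log()  # same value, less stable
assert torch.isclose(score_log, score_prob)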


@@ -3,6 +3,7 @@ import torch
 import transformers
 from hivemind import get_logger, use_hivemind_log_handler
 from test_utils import *
+from transformers.generation_utils import BeamSearchScorer

 from src.bloom.model import BloomForCausalLM
 from src.client.remote_model import DistributedBloomForCausalLM
@@ -89,3 +90,29 @@ def test_greedy_generation(max_new_tokens=4):
     assert torch.allclose(
         remote_outputs_batch, hf_outputs_batch
     ), "Greedy search is not identical to HF in multibatch mode"
+
+
+@pytest.mark.forked
+def test_beam_search_generation(max_new_tokens=4, num_beams=2):
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+    remote_outputs = model.generate(
+        inputs,
+        max_new_tokens=max_new_tokens,
+        num_beams=num_beams,
+    )
+    beam_scorer = BeamSearchScorer(
+        batch_size=inputs.size(0),
+        num_beams=num_beams,
+        device=inputs.device,
+        length_penalty=0,
+        do_early_stopping=False,
+        num_beam_hyps_to_keep=2,
+    )
+    hf_outputs = BloomForCausalLM.beam_search(
+        model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
+    )
+    assert torch.allclose(remote_outputs, hf_outputs), "Beam search is not identical to HF"
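Note on length_penalty=0 in the test above: as far as we can tell from the Hugging Face implementation, the scorer ranks finished hypotheses by sum_logprobs / length ** length_penalty, so a zero penalty makes the divisor 1 and the comparison uses the same unnormalized cumulative scores as the remote side. A quick sketch of that ranking rule (helper name is ours, for illustration):

def final_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    # Assumed HF BeamSearchScorer rule: sum of log-probs, length-normalized.
    return sum_logprobs / length ** length_penalty

assert final_score(-4.2, 7, length_penalty=0) == -4.2  # penalty 0 -> no normalization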