Mirror of https://github.com/bigscience-workshop/petals (synced 2024-10-31 09:20:41 +00:00)

Make beam_search identical

parent ce22b6a47b
commit 9bde866eb3
@@ -140,11 +140,13 @@ class RemoteGenerationMixin:
                    :, seq_idx : seq_idx + 1
                ] + pad_token_mask * last_token_id

                if torch.all(last_token_id == eos_token_id) or len(outputs) >= max_new_tokens:
                    break
                if num_beams > 1:
                    outputs[-1] = outputs[-1][hypo_ids]

                outputs.append(last_token_id)
                seq_idx += 1
                if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
                    break

        return torch.cat(outputs, dim=-1)
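The reordering introduced in this hunk keeps previously generated tokens consistent with the hypotheses that survive each beam-search step. A minimal sketch of the idea (illustrative only; `outputs` and `hypo_ids` follow the names in the diff, the concrete values are assumptions):

    import torch

    # Last step's token ids, one row per beam (two beams here).
    outputs = [torch.tensor([[5], [7]])]

    # The decoding algorithm reports which previous hypothesis each new beam
    # continues; here both surviving beams continue hypothesis 1.
    hypo_ids = torch.tensor([1, 1])

    # Reorder the already-generated tokens so row i matches the hypothesis that
    # the i-th new token extends; otherwise beams and their histories drift apart.
    outputs[-1] = outputs[-1][hypo_ids]
    print(outputs[-1])  # tensor([[7], [7]])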
@@ -80,29 +80,33 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
        self._cur_num_beams = 1
        self.batch_size = batch_size

        self._logits = torch.zeros(
            (
                self.batch_size,
                self._cur_num_beams,
            )
        )
        self._beams = []

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        probs = torch.softmax(sorted_logits, -1)
        probs = torch.log_softmax(sorted_logits, -1)

        new_logits = torch.cat([self._logits] * self.num_beams, dim=-1)
        if len(self._beams) > 0:
            new_beams = []
            for batch_idx in range(self.batch_size):
                for beam_idx in range(self.num_beams):
                    new_beam = self._beams[beam_idx]
                    for hypo_idx in range(self.num_beams):
                        probs_idx = batch_idx + beam_idx * self.batch_size
                        new_beams.append((beam_idx, new_beam[1] + probs[probs_idx, hypo_idx].item()))
            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
            self._beams = new_beams[: self.batch_size * self.num_beams]
        else:
            for batch_idx in range(self.batch_size):
                for beam_idx in range(self.num_beams):
                    self._beams.append((beam_idx, probs[batch_idx, beam_idx].item()))

        return_hypos = []
        return_tokens = []
        for batch_idx in range(self.batch_size):
            for cur_beam_idx in range(self._cur_num_beams):
                for new_beam_idx in range(self.num_beams):
                    logit = probs[cur_beam_idx * self.batch_size + batch_idx, new_beam_idx]
                    new_logits[batch_idx, cur_beam_idx * self.num_beams + new_beam_idx] += logit
            self._cur_num_beams = self.num_beams
            for beam_idx in range(self.num_beams):
                hypo_idx = batch_idx + beam_idx * self.batch_size
                return_hypos.append(self._beams[hypo_idx][0])
                return_tokens.append([sorted_indices[batch_idx, beam_idx].item()])

        new_sorted_logits, new_sorted_indices = torch.sort(new_logits, descending=True, dim=-1)
        new_sorted_indices = new_sorted_indices[:, : self.num_beams].T.flatten()
        self._logits = new_sorted_logits[:, : self.num_beams]
        result_tokens = sorted_indices[torch.arange(self.num_beams * self.batch_size), new_sorted_indices]
        result_hypos = torch.div(new_sorted_indices, self.num_beams, rounding_mode="floor")

        return result_tokens.unsqueeze(-1), result_hypos
        return torch.tensor(return_tokens), torch.tensor(return_hypos)
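The switch from `softmax` to `log_softmax` in this hunk makes per-step scores additive, which is what the `new_beam[1] + probs[...]` accumulation relies on: a hypothesis's running score is the sum of the log-probabilities of its tokens. A minimal sketch of that scoring step (illustrative only, not the repository's code; the example logits and the running scores in `beams` are made up):

    import torch

    num_beams = 2
    # Fake next-token logits for each current hypothesis (rows) over a tiny vocab (cols).
    logits = torch.tensor([[2.0, 1.0, 0.5], [1.5, 1.2, 0.1]])

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    step_scores = torch.log_softmax(sorted_logits, -1)  # log-probs: additive across steps

    # Running beam scores are sums of log-probs; extend every beam with its top candidates.
    beams = [(0, -0.3), (1, -0.9)]  # (hypothesis id, cumulative log-prob)
    candidates = [
        (hypo_id, score + step_scores[hypo_id, k].item())
        for hypo_id, score in beams
        for k in range(num_beams)
    ]
    # Keep only the best continuations, as the sorted()/slice in the diff does.
    beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:num_beams]
    print(beams)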
@@ -3,6 +3,7 @@ import torch
import transformers
from hivemind import get_logger, use_hivemind_log_handler
from test_utils import *
from transformers.generation_utils import BeamSearchScorer

from src.bloom.model import BloomForCausalLM
from src.client.remote_model import DistributedBloomForCausalLM
@@ -89,3 +90,29 @@ def test_greedy_generation(max_new_tokens=4):
    assert torch.allclose(
        remote_outputs_batch, hf_outputs_batch
    ), "Greedy search are not identical to HF in multibatch mode"


@pytest.mark.forked
def test_greedy_generation(max_new_tokens=4, num_beams=2):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
    remote_outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
    )
    beam_scorer = BeamSearchScorer(
        batch_size=inputs.size(0),
        num_beams=num_beams,
        device=inputs.device,
        length_penalty=0,
        do_early_stopping=False,
        num_beam_hyps_to_keep=2,
    )
    hf_outputs = BloomForCausalLM.beam_search(
        model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
    )
    assert torch.allclose(remote_outputs, hf_outputs), "Beam search are not identical to HF"