Add Beam Search decoding algorithm (#87)

Add beam_search
pull/96/head
Artem Chumachenko 1 year ago committed by GitHub
parent fef7257fe0
commit fdb3583a8c

@@ -1,10 +1,18 @@
from typing import List, Optional
import torch
import torch.nn.functional as F
from hivemind.utils.logging import get_logger
from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, NucleusAlgorithm, TopKAlgorithm
from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint, MaxNewTokensConstraint
from src.utils.generation_algorithms import (
BeamSearchAlgorithm,
DecodingAlgorithm,
GreedyAlgorithm,
NucleusAlgorithm,
TopKAlgorithm,
)
from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint
logger = get_logger(__file__)
class RemoteGenerationMixin:
@@ -13,8 +21,9 @@ class RemoteGenerationMixin:
This class exposes generation methods that can be used for:
- *greedy decoding*.
- *multinomial sampling*.
- *beam-search decoding*.
This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used instead of it. However, it has some differences.
This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used instead of it. However, it has some differences for remote usage.
"""
@torch.no_grad()
@@ -25,6 +34,7 @@ class RemoteGenerationMixin:
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
num_beams: Optional[int] = 1,
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
@@ -32,6 +42,7 @@ class RemoteGenerationMixin:
max_new_tokens: Optional[int] = None,
decoding_algorithm: Optional[DecodingAlgorithm] = None,
provided_constraints: List[ABCBloomConstraint] = [],
num_return_sequences: Optional[int] = None,
**model_kwargs,
) -> torch.LongTensor:
"""
@@ -42,6 +53,7 @@ class RemoteGenerationMixin:
:param temperature: The temperature to use for sampling.
:param top_k: The number of highest-probability vocabulary tokens kept for top-k sampling.
:param top_p: The cumulative probability threshold used for nucleus (top-p) sampling.
:param num_beams: The number of beams to use for beam search.
:param bos_token_id: The id of the beginning of sentence token.
:param eos_token_id: The id of the end of sentence token.
:param pad_token_id: The id of the padding token.
@@ -49,6 +61,7 @@ class RemoteGenerationMixin:
:param decoding_algorithm: The decoding algorithm to use.
:param provided_constraints: A list of constraints to use.
:param model_kwargs: Additional arguments to pass to the model.
:param num_return_sequences: How many hypotheses from the beam search will be returned.
"""
assert (
@@ -69,6 +82,8 @@ class RemoteGenerationMixin:
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
batch_size = inputs.size(0)
assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
if max_length is not None and max_new_tokens is None:
max_new_tokens = max_length - prefix_length
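For reference, here is a minimal usage sketch of the new arguments. It mirrors the test added at the bottom of this commit; MODEL_NAME and INITIAL_PEERS come from test_utils, as there.

import transformers
from src.client.remote_model import DistributedBloomForCausalLM
from test_utils import *  # provides MODEL_NAME and INITIAL_PEERS, as in the tests below

tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
outputs = model.generate(
    inputs,
    max_new_tokens=4,
    num_beams=2,             # with do_sample unset, this selects BeamSearchAlgorithm
    num_return_sequences=1,  # must be <= num_beams
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))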
@@ -78,24 +93,43 @@
if inputs is None:
assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
inputs = torch.tensor([[bos_token_id]])
inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
if decoding_algorithm is None:
if do_sample:
decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
elif num_beams is not None and num_beams > 1:
decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
else:
decoding_algorithm = GreedyAlgorithm()
if num_beams > 1:
inputs = torch.cat([inputs] * num_beams, dim=0)
if batch_size > 1:
# TODO: resolve padding problem
logger.warning(
f"You set batch_size {batch_size} within beam search generation. Be careful, results on sequences with different length may be padded wrong way"
)
if num_return_sequences is None:
num_return_sequences = 1
assert num_return_sequences <= num_beams, (
"You requested more output sequences than there are beams."
f" Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
)
constraints = self._get_constraints(
inputs=inputs,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
max_new_tokens=max_new_tokens,
provided_constraints=provided_constraints,
)
with self.transformer.h.inference_session(max_length=max_length) as sess:
outputs = []
# Find samples with padded inputs.
# Generation for them starts earlier; the other samples keep copying their input tokens until every input is fully consumed.
if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs
outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
else:
@@ -117,19 +151,34 @@ class RemoteGenerationMixin:
for constraint in constraints:
lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
last_token_id, hypo_ids = decoding_algorithm(lm_logits)
if seq_idx < inputs.size(1): # TODO: why is it not a constraint?
# If some samples were padded, change only these samples
if seq_idx < inputs.size(1):
pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
last_token_id = (~pad_token_mask) * inputs[
:, seq_idx : seq_idx + 1
] + pad_token_mask * last_token_id
if torch.all(last_token_id == eos_token_id):
break
# TODO: refactor outputs
if num_beams > 1:
for i in range(len(outputs), 1, -1):
outputs[i - 1] = outputs[i - 1][hypo_ids]
outputs.append(last_token_id)
seq_idx += 1
if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
break
outputs = torch.cat(outputs, dim=-1)
return torch.cat(outputs, dim=-1)
if num_beams > 1:
pre_return_idx = [
torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
]
return_idx = torch.cat(pre_return_idx, dim=0)
outputs = outputs[return_idx]
return outputs
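The reordering above regroups beam-major rows per input sample. A small standalone sketch of the same index arithmetic (illustrative values only):

import torch

# After tiling, rows are laid out beam-major: rows 0..batch_size-1 hold beam 0,
# rows batch_size..2*batch_size-1 hold beam 1, and so on.
batch_size, num_return_sequences = 2, 2
pre_return_idx = [
    torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
]
return_idx = torch.cat(pre_return_idx, dim=0)
print(return_idx)  # tensor([0, 2, 1, 3]): hypotheses of sample 0 first, then sample 1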
def greedy_search(
self,
@@ -198,13 +247,38 @@ class RemoteGenerationMixin:
def beam_search(
self,
input_ids: torch.LongTensor,
num_beams: int = 1,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
**model_kwargs,
) -> torch.LongTensor:
raise NotImplementedError
"""
Generates sequences of token ids for models with a language modeling head. Uses beam search.
:param input_ids: The input tokens to the model.
:param num_beams: The number of beams to use.
:param max_length: The maximum length of the sequence to generate.
:param pad_token_id: The id of the padding token.
:param eos_token_id: The id of the end of sentence token.
:param provided_constraints: A list of constraints to use.
:param model_kwargs: Additional kwargs to pass to the model.
"""
decoding_algorithm = BeamSearchAlgorithm(
num_beams=num_beams,
batch_size=input_ids.size(0),
)
return self.generate(
inputs=input_ids,
num_beams=num_beams,
max_new_tokens=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoding_algorithm=decoding_algorithm,
provided_constraints=provided_constraints,
**model_kwargs,
)
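This wrapper just builds a BeamSearchAlgorithm and delegates to generate(); note that, as written, its max_length argument is forwarded as max_new_tokens. A minimal call sketch, reusing the model and inputs from the generate() sketch above:

outputs = model.beam_search(
    input_ids=inputs,
    num_beams=2,
    max_length=4,  # forwarded to generate() as max_new_tokens by this wrapper
)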
def beam_sample(
self,
@@ -246,12 +320,9 @@ class RemoteGenerationMixin:
inputs: Optional[torch.Tensor] = None,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
max_new_tokens: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> List[ABCBloomConstraint]:
constraints = []
constraints.extend(provided_constraints)
if max_new_tokens is not None:
constraints.append(MaxNewTokensConstraint(inputs, max_new_tokens, eos_token_id, pad_token_id))
constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
return constraints

@@ -59,6 +59,7 @@ class TransformerBackend(ModuleBackend):
with self.memory_cache.use_cache(attention_cache_handle) as cache:
assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
if not is_dummy(hypo_ids):
assert hypo_ids.shape[0] == cache.shape[1]
cache[:, :] = cache[:, hypo_ids] # in-place reorder cache by hypo ids
layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
logger.debug(f"Metadata: {cache_metadata}, past_k.shape={past_k.shape}, past_v.shape={past_v.shape}")

@@ -48,7 +48,6 @@ class SamplingAlgorithm(DecodingAlgorithm):
class TopKAlgorithm(SamplingAlgorithm):
# TODO: Add NumHypos, maxBatchSize
def __init__(self, top_k: int, temperature: float = 1.0) -> None:
self.top_k = top_k
self.temperature = temperature
@@ -75,4 +74,48 @@ class NucleusAlgorithm(SamplingAlgorithm):
return self.sample(logits, indices_to_remove)
# TODO: In generate function we need to check usage of top_k or sampling algorithm
class BeamSearchAlgorithm(DecodingAlgorithm):
def __init__(self, num_beams: int, batch_size: int) -> None:
self.num_beams = num_beams
self._cur_num_beams = 1
self.batch_size = batch_size
self._batch_beams = [list() for _ in range(batch_size)]
def __call__(self, logits: torch.Tensor):
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
probs = torch.log_softmax(sorted_logits, -1)
if len(self._batch_beams[0]) > 0:
for batch_idx in range(self.batch_size):
new_beams = []
cur_beams = self._batch_beams[batch_idx]
for beam_idx in range(len(cur_beams)):
probs_idx = batch_idx + beam_idx * self.batch_size
new_beam = cur_beams[beam_idx]
for hypo_idx in range(self.num_beams):
new_beams.append(
(new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
)
self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
else:
for batch_idx in range(self.batch_size):
for beam_idx in range(self.num_beams):
self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))
return_hypos = []
return_tokens = []
for batch_idx in range(self.batch_size):
cur_beam = self._batch_beams[batch_idx]
return_hypos.append(list())
return_tokens.append(list())
for beam in cur_beam:
beam_idx = beam[1] // self.num_beams
hypo_idx = batch_idx + beam_idx * self.batch_size
token_idx = beam[1] % self.num_beams
return_hypos[-1].append(hypo_idx)
return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]
return torch.tensor(return_tokens), torch.tensor(return_hypos)
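A short sketch of driving BeamSearchAlgorithm by hand with random logits (batch_size=1, num_beams=2). Shapes mirror the generate() loop above, where logits always have batch_size * num_beams rows:

import torch
from src.utils.generation_algorithms import BeamSearchAlgorithm

algo = BeamSearchAlgorithm(num_beams=2, batch_size=1)
vocab_size = 8

# First step: all rows hold the same prefix, so the top-2 tokens of row 0 seed the beams.
token_ids, hypo_ids = algo(torch.randn(2, vocab_size))
print(token_ids.shape, hypo_ids.shape)  # torch.Size([2, 1]) torch.Size([2])

# Later steps: each row scores one beam; hypo_ids tells which previous row (and which
# server-side cache entry) every surviving beam descends from.
token_ids, hypo_ids = algo(torch.randn(2, vocab_size))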

@@ -21,39 +21,6 @@ class ABCBloomConstraint(ABC):
pass
class MaxNewTokensConstraint(ABCBloomConstraint):
"""
Constraint that forbids generating more than max_new_tokens tokens after the prefix.
Args:
prefix: The prefix of the sequence.
max_new_tokens: The maximum number of tokens that can be generated after the prefix.
eos_token_id: The id of the end of sentence token.
pad_token_id: The id of the padding token.
min_logits: The minimum logits that can be generated. Default: -1e6.
"""
def __init__(
self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8
) -> None:
self.max_new_tokens = max_new_tokens
self.current_generated_tokens = None
self.eos_token_id = eos_token_id
self.min_logits = min_logits
max_pad_size = (prefix == pad_token_id).sum(1).unsqueeze(1).max()
self.current_generated_tokens = (prefix == pad_token_id).sum(1).unsqueeze(1) - max_pad_size
def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
if tokens_id is not None:
self.current_generated_tokens += 1
mask = self.current_generated_tokens >= self.max_new_tokens
logits += self.min_logits * mask
logits[mask[:, 0], self.eos_token_id] = 0
return logits
class EosConstraint(ABCBloomConstraint):
"""
This constraint repeats the EOS token if it was generated on the previous step.

@@ -3,6 +3,7 @@ import torch
import transformers
from hivemind import get_logger, use_hivemind_log_handler
from test_utils import *
from transformers.generation_utils import BeamSearchScorer
from src.bloom.model import BloomForCausalLM
from src.client.remote_model import DistributedBloomForCausalLM
@@ -89,3 +90,30 @@ def test_greedy_generation(max_new_tokens=4):
assert torch.allclose(
remote_outputs_batch, hf_outputs_batch
), "Greedy search are not identical to HF in multibatch mode"
@pytest.mark.forked
def test_beam_search_generation(max_new_tokens=4, num_beams=2):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
text = "A cat sat on a mat"
inputs = tokenizer(text, return_tensors="pt")["input_ids"]
remote_outputs = model.generate(
inputs,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
)
beam_scorer = BeamSearchScorer(
batch_size=inputs.size(0),
num_beams=num_beams,
device=inputs.device,
length_penalty=0,
do_early_stopping=False,
)
hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
hf_outputs = BloomForCausalLM.beam_search(
model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
)
assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"
