diff --git a/src/client/remote_generation.py b/src/client/remote_generation.py
index d2be2c9..e2da719 100644
--- a/src/client/remote_generation.py
+++ b/src/client/remote_generation.py
@@ -1,10 +1,18 @@
 from typing import List, Optional
 
 import torch
-import torch.nn.functional as F
+from hivemind.utils.logging import get_logger
 
-from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, NucleusAlgorithm, TopKAlgorithm
-from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint, MaxNewTokensConstraint
+from src.utils.generation_algorithms import (
+    BeamSearchAlgorithm,
+    DecodingAlgorithm,
+    GreedyAlgorithm,
+    NucleusAlgorithm,
+    TopKAlgorithm,
+)
+from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint
+
+logger = get_logger(__file__)
 
 
 class RemoteGenerationMixin:
@@ -13,8 +21,9 @@ class RemoteGenerationMixin:
     The class exposes can be used for:
     - *greedy decoding*.
     - *multinomial sampling*.
+    - *beam-search decoding*.
 
-    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences.
+    This class is similar to transformer's [`generation_utils.GenerationMixin`] and can be used instead of it. However, it has some differences for remote usage.
     """
 
     @torch.no_grad()
@@ -25,6 +34,7 @@ class RemoteGenerationMixin:
         temperature: float = 1.0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        num_beams: Optional[int] = 1,
         bos_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
@@ -32,6 +42,7 @@ class RemoteGenerationMixin:
         max_new_tokens: Optional[int] = None,
         decoding_algorithm: Optional[DecodingAlgorithm] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
+        num_return_sequences: Optional[int] = None,
         **model_kwargs,
     ) -> torch.LongTensor:
         """
@@ -42,6 +53,7 @@ class RemoteGenerationMixin:
         :param temperature: The temperature to use for sampling.
         :param top_k: The number of results to return.
         :param top_p: The cumulative probability of results to return.
+        :param num_beams: The number of beams to use for beam search.
         :param bos_token_id: The id of the beginning of sentence token.
         :param eos_token_id: The id of the end of sentence token.
         :param pad_token_id: The id of the padding token.
@@ -49,6 +61,7 @@ class RemoteGenerationMixin:
         :param decoding_algorithm: The decoding algorithm to use.
         :param provided_constraints: A list of constraints to use.
         :param model_kwargs: Additional arguments to pass to the model.
+        :param num_return_sequences: How many hypotheses from the beam search will be in the output.
""" assert ( @@ -69,6 +82,8 @@ class RemoteGenerationMixin: pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + batch_size = inputs.size(0) + assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)" if max_length is not None and max_new_tokens is None: max_new_tokens = max_length - prefix_length @@ -78,24 +93,43 @@ class RemoteGenerationMixin: if inputs is None: assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs" - inputs = torch.tensor([[bos_token_id]]) + inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device) if decoding_algorithm is None: if do_sample: decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p) + elif num_beams is not None and num_beams > 1: + decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size) else: decoding_algorithm = GreedyAlgorithm() + if num_beams > 1: + inputs = torch.cat([inputs] * num_beams, dim=0) + if batch_size > 1: + # TODO: resolve padding problem + logger.warning( + f"You set batch_size {batch_size} within beam search generation. Be careful, results on sequences with different length may be padded wrong way" + ) + + if num_return_sequences is None: + num_return_sequences = 1 + + assert num_return_sequences <= num_beams, ( + f"You want more sequences than the beam has." + " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}." + ) + constraints = self._get_constraints( inputs=inputs, eos_token_id=eos_token_id, pad_token_id=pad_token_id, - max_new_tokens=max_new_tokens, provided_constraints=provided_constraints, ) with self.transformer.h.inference_session(max_length=max_length) as sess: outputs = [] + # Find samples with padded inputs. + # They will be changed before all of the samples have right length. if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]] else: @@ -117,19 +151,34 @@ class RemoteGenerationMixin: for constraint in constraints: lm_logits = constraint(last_token_id, lm_logits, hypo_ids) last_token_id, hypo_ids = decoding_algorithm(lm_logits) - if seq_idx < inputs.size(1): # TODO: why is it not a constraint? 
+
+                # If some samples were padded, change only these samples
+                if seq_idx < inputs.size(1):
                     pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
                     last_token_id = (~pad_token_mask) * inputs[
                         :, seq_idx : seq_idx + 1
                     ] + pad_token_mask * last_token_id
 
-                if torch.all(last_token_id == eos_token_id):
-                    break
+                # TODO: refactor outputs
+                if num_beams > 1:
+                    for i in range(len(outputs), 1, -1):
+                        outputs[i - 1] = outputs[i - 1][hypo_ids]
 
                 outputs.append(last_token_id)
                 seq_idx += 1
+                if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
+                    break
+
+        outputs = torch.cat(outputs, dim=-1)
 
-        return torch.cat(outputs, dim=-1)
+        if num_beams > 1:
+            pre_return_idx = [
+                torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
+            ]
+            return_idx = torch.cat(pre_return_idx, dim=0)
+            outputs = outputs[return_idx]
+
+        return outputs
 
     def greedy_search(
         self,
@@ -198,13 +247,38 @@ class RemoteGenerationMixin:
     def beam_search(
         self,
         input_ids: torch.LongTensor,
+        num_beams: int = 1,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
         **model_kwargs,
     ) -> torch.LongTensor:
-        raise NotImplementedError
+        """
+        Generates sequences of token ids for models with a language modeling head. Uses beam search.
+
+        :param input_ids: The input tokens to the model.
+        :param num_beams: The number of beams to use.
+        :param max_length: The maximum length of the sequence to generate.
+        :param pad_token_id: The id of the padding token.
+        :param eos_token_id: The id of the end of sentence token.
+        :param provided_constraints: A list of constraints to use.
+        :param model_kwargs: Additional kwargs to pass to the model.
+ """ + decoding_algorithm = BeamSearchAlgorithm( + num_beams=num_beams, + batch_size=input_ids.size(0), + ) + return self.generate( + inputs=input_ids, + num_beams=num_beams, + max_new_tokens=max_length, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + decoding_algorithm=decoding_algorithm, + provided_constraints=provided_constraints, + **model_kwargs, + ) def beam_sample( self, @@ -246,12 +320,9 @@ class RemoteGenerationMixin: inputs: Optional[torch.Tensor] = None, eos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, - max_new_tokens: Optional[int] = None, provided_constraints: List[ABCBloomConstraint] = [], ) -> List[ABCBloomConstraint]: constraints = [] constraints.extend(provided_constraints) - if max_new_tokens is not None: - constraints.append(MaxNewTokensConstraint(inputs, max_new_tokens, eos_token_id, pad_token_id)) constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id)) return constraints diff --git a/src/server/backend.py b/src/server/backend.py index 00f55dd..a8a1a28 100644 --- a/src/server/backend.py +++ b/src/server/backend.py @@ -59,6 +59,7 @@ class TransformerBackend(ModuleBackend): with self.memory_cache.use_cache(attention_cache_handle) as cache: assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5 if not is_dummy(hypo_ids): + assert hypo_ids.shape[0] == cache.shape[1] cache[:, :] = cache[:, hypo_ids] # in-place reorder cache by hypo ids layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length] logger.debug(f"Metadata: {cache_metadata}, past_k.shape={past_k.shape}, past_v.shape={past_v.shape}") diff --git a/src/utils/generation_algorithms.py b/src/utils/generation_algorithms.py index 8507a49..399eb8d 100644 --- a/src/utils/generation_algorithms.py +++ b/src/utils/generation_algorithms.py @@ -48,7 +48,6 @@ class SamplingAlgorithm(DecodingAlgorithm): class TopKAlgorithm(SamplingAlgorithm): - # TODO: Add NumHypos, maxBatchSize def __init__(self, top_k: int, temperature: float = 1.0) -> None: self.top_k = top_k self.temperature = temperature @@ -75,4 +74,48 @@ class NucleusAlgorithm(SamplingAlgorithm): return self.sample(logits, indices_to_remove) -# TODO: In generate function we need to check usage of top_k or sampling algorithm +class BeamSearchAlgorithm(DecodingAlgorithm): + def __init__(self, num_beams: int, batch_size: int) -> None: + self.num_beams = num_beams + self._cur_num_beams = 1 + self.batch_size = batch_size + + self._batch_beams = [list() for _ in range(batch_size)] + + def __call__(self, logits: torch.Tensor): + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + probs = torch.log_softmax(sorted_logits, -1) + + if len(self._batch_beams[0]) > 0: + for batch_idx in range(self.batch_size): + new_beams = [] + cur_beams = self._batch_beams[batch_idx] + for beam_idx in range(len(cur_beams)): + probs_idx = batch_idx + beam_idx * self.batch_size + new_beam = cur_beams[beam_idx] + for hypo_idx in range(self.num_beams): + new_beams.append( + (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx) + ) + self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams] + else: + for batch_idx in range(self.batch_size): + for beam_idx in range(self.num_beams): + self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx)) + + return_hypos = [] + return_tokens = [] + for batch_idx in range(self.batch_size): + cur_beam = self._batch_beams[batch_idx] + 
+            return_hypos.append(list())
+            return_tokens.append(list())
+            for beam in cur_beam:
+                beam_idx = beam[1] // self.num_beams
+                hypo_idx = batch_idx + beam_idx * self.batch_size
+                token_idx = beam[1] % self.num_beams
+                return_hypos[-1].append(hypo_idx)
+                return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
+        return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
+        return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]
+
+        return torch.tensor(return_tokens), torch.tensor(return_hypos)
diff --git a/src/utils/generation_constraints.py b/src/utils/generation_constraints.py
index 72c526f..c48bde8 100644
--- a/src/utils/generation_constraints.py
+++ b/src/utils/generation_constraints.py
@@ -21,39 +21,6 @@ class ABCBloomConstraint(ABC):
         pass
 
 
-class MaxNewTokensConstraint(ABCBloomConstraint):
-    """
-    Constraint that forbids to generate more than max_new_tokens tokens after the prefix.
-
-    Args:
-        prefix: The prefix of the sequence.
-        max_new_tokens: The maximum number of tokens that can be generated after the prefix.
-        eos_token_id: The id of the end of sentence token.
-        pad_token_id: The id of the padding token.
-        min_logits: The minimum logits that can be generated. Default: -1e6.
-    """
-
-    def __init__(
-        self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8
-    ) -> None:
-        self.max_new_tokens = max_new_tokens
-        self.current_generated_tokens = None
-        self.eos_token_id = eos_token_id
-        self.min_logits = min_logits
-
-        max_pad_size = (prefix == pad_token_id).sum(1).unsqueeze(1).max()
-        self.current_generated_tokens = (prefix == pad_token_id).sum(1).unsqueeze(1) - max_pad_size
-
-    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
-        if tokens_id is not None:
-            self.current_generated_tokens += 1
-
-        mask = self.current_generated_tokens >= self.max_new_tokens
-        logits += self.min_logits * mask
-        logits[mask[:, 0], self.eos_token_id] = 0
-        return logits
-
-
 class EosConstraint(ABCBloomConstraint):
     """
     This constrained repeats EOS token if it was generated on the previous step.
diff --git a/tests/test_full_model.py b/tests/test_full_model.py
index b0ce824..b81e9ca 100644
--- a/tests/test_full_model.py
+++ b/tests/test_full_model.py
@@ -3,6 +3,7 @@ import torch
 import transformers
 from hivemind import get_logger, use_hivemind_log_handler
 from test_utils import *
+from transformers.generation_utils import BeamSearchScorer
 
 from src.bloom.model import BloomForCausalLM
 from src.client.remote_model import DistributedBloomForCausalLM
@@ -89,3 +90,30 @@ def test_greedy_generation(max_new_tokens=4):
     assert torch.allclose(
         remote_outputs_batch, hf_outputs_batch
     ), "Greedy search are not identical to HF in multibatch mode"
+
+
+@pytest.mark.forked
+def test_beam_search_generation(max_new_tokens=4, num_beams=2):
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+    text = "A cat sat on a mat"
+    inputs = tokenizer(text, return_tensors="pt")["input_ids"]
+    remote_outputs = model.generate(
+        inputs,
+        max_new_tokens=max_new_tokens,
+        num_beams=num_beams,
+    )
+    beam_scorer = BeamSearchScorer(
+        batch_size=inputs.size(0),
+        num_beams=num_beams,
+        device=inputs.device,
+        length_penalty=0,
+        do_early_stopping=False,
+    )
+    hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
+    hf_outputs = BloomForCausalLM.beam_search(
+        model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
+    )
+    assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"
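Usage sketch (illustrative, not part of the patch above): a minimal way a client might drive the new beam-search path through `RemoteGenerationMixin.generate()`, mirroring `test_beam_search_generation`. It assumes `MODEL_NAME` and `INITIAL_PEERS` are the same placeholders the tests pull in via `from test_utils import *`; the explicit import below is a hypothetical convenience, not something this diff adds.

```python
import torch
import transformers

from src.client.remote_model import DistributedBloomForCausalLM
from test_utils import INITIAL_PEERS, MODEL_NAME  # assumed exports; the tests use `from test_utils import *`

tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
)

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

# generate() now accepts num_beams and num_return_sequences; with num_beams > 1 and
# do_sample unset it routes decoding through the new BeamSearchAlgorithm.
outputs = model.generate(inputs, max_new_tokens=4, num_beams=2, num_return_sequences=1)
print(tokenizer.decode(outputs[0]))
```

The dedicated `beam_search()` helper added by this patch wraps the same `generate()` call, constructing the `BeamSearchAlgorithm` itself and forwarding `num_beams` and the other arguments.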
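A second sketch (also illustrative, not part of the patch): the contract of `BeamSearchAlgorithm` in isolation. Each call takes per-hypothesis logits of shape `[batch_size * num_beams, vocab_size]` and returns the chosen tokens together with `hypo_ids`, which `generate()` (and the server-side cache reorder in `backend.py`) use to keep every row attached to the hypothesis it continues. The random logits and the three-step loop are stand-ins for real model outputs.

```python
import torch

from src.utils.generation_algorithms import BeamSearchAlgorithm

batch_size, num_beams, vocab_size = 1, 2, 8
algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)

outputs = []  # per-step [batch_size * num_beams, 1] token tensors, like `outputs` in generate()
for step in range(3):
    lm_logits = torch.randn(batch_size * num_beams, vocab_size)  # stand-in for model logits
    last_token_id, hypo_ids = algorithm(lm_logits)
    # Reorder earlier steps so each row keeps following its own hypothesis,
    # analogous to generate() reindexing outputs[i - 1][hypo_ids].
    outputs = [prev[hypo_ids] for prev in outputs]
    outputs.append(last_token_id)

print(torch.cat(outputs, dim=-1))  # [batch_size * num_beams, num_steps] token ids
```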