From bd91be27ea701e87d4d6925e20f856460b5da14f Mon Sep 17 00:00:00 2001
From: Max Ryabinin
Date: Tue, 13 Dec 2022 21:09:15 +0400
Subject: [PATCH] Add missing methods for SamplingAlgorithm, fix docstrings
 (#107)

* Add missing methods for SamplingAlgorithm, fix docstrings

* Add SamplingAlgorithm to _choose_sample_algorithm

* Add test_sampling

* Add a warning if sampling options were passed, but do_sample=False

* Skip the sampling test for now

Co-authored-by: Alexander Borzunov
---
 src/petals/client/remote_generation.py    | 14 +++++--
 src/petals/utils/generation_algorithms.py | 40 ++++++++++--------
 tests/test_full_model.py                  | 50 ++++++++++++++++++++++-
 3 files changed, 82 insertions(+), 22 deletions(-)

diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py
index 1a736ce..053e209 100644
--- a/src/petals/client/remote_generation.py
+++ b/src/petals/client/remote_generation.py
@@ -10,6 +10,7 @@ from petals.utils.generation_algorithms import (
     DecodingAlgorithm,
     GreedyAlgorithm,
     NucleusAlgorithm,
+    SamplingAlgorithm,
     TopKAlgorithm,
 )
 from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
@@ -22,7 +23,7 @@ class RemoteGenerationMixin:
     A class containing all functions for auto-regressive text generation, to be used as a mixin in
     [`BloomForCausalLM`]. The class exposes can be used for:
         - *greedy decoding*.
-        - *multinomial sampling*.
+        - *multinomial, top-k and top-p sampling*.
         - *beam-search decoding*

     This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it.
@@ -126,6 +127,8 @@ class RemoteGenerationMixin:
         elif num_beams is not None and num_beams > 1:
             decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
         else:
+            if top_k is not None or top_p is not None:
+                logger.warning("You passed top_k or top_p but did not pass do_sample=True. Running greedy decoding")
             decoding_algorithm = GreedyAlgorithm()

         if num_beams > 1:
@@ -252,7 +255,8 @@ class RemoteGenerationMixin:
         **model_kwargs,
     ) -> torch.LongTensor:
         """
-        Generates sequences of token ids for models with a language modeling head. Uses sampling. Uses multinomial sampling algorithm. If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
+        Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
+        If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.

         :param: input_ids: The input tokens to the model.
         :param: temperature: The temperature to use for sampling.
@@ -341,10 +345,12 @@ class RemoteGenerationMixin:
     ) -> DecodingAlgorithm:
         if (top_k is not None) and (top_p is not None):
             raise ValueError("You have to provide only top_k or top_p for sampling")
-        if top_k:
+        if top_k is not None:
             return TopKAlgorithm(top_k, temperature)
-        elif top_p:
+        elif top_p is not None:
             return NucleusAlgorithm(top_p, temperature)
+        else:
+            return SamplingAlgorithm(temperature)

     def _get_constraints(
         self,
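
Note (illustration, not part of the patch): with the SamplingAlgorithm fallback above, _choose_sample_algorithm now covers all three sampling modes: top_k selects TopKAlgorithm, top_p selects NucleusAlgorithm, and plain do_sample=True falls back to multinomial SamplingAlgorithm, while sampling options without do_sample=True trigger the new warning and greedy decoding. A minimal standalone sketch of that dispatch (the returned strings stand in for the petals algorithm objects; this is illustrative, not library code):

    from typing import Optional

    def choose_algorithm(
        do_sample: bool,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> str:
        # Mirrors the selection logic in RemoteGenerationMixin after this patch.
        if do_sample:
            if top_k is not None and top_p is not None:
                raise ValueError("You have to provide only top_k or top_p for sampling")
            if top_k is not None:
                return f"TopKAlgorithm(top_k={top_k}, temperature={temperature})"
            if top_p is not None:
                return f"NucleusAlgorithm(top_p={top_p}, temperature={temperature})"
            return f"SamplingAlgorithm(temperature={temperature})"  # new fallback
        if top_k is not None or top_p is not None:
            print("Warning: top_k/top_p passed without do_sample=True; running greedy decoding")
        return "GreedyAlgorithm()"

    assert choose_algorithm(do_sample=True) == "SamplingAlgorithm(temperature=1.0)"
    assert choose_algorithm(do_sample=False, top_k=5) == "GreedyAlgorithm()"
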
diff --git a/src/petals/utils/generation_algorithms.py b/src/petals/utils/generation_algorithms.py
index 43c8d34..9033371 100644
--- a/src/petals/utils/generation_algorithms.py
+++ b/src/petals/utils/generation_algorithms.py
@@ -1,4 +1,4 @@
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import Tuple

 import torch
@@ -9,16 +9,16 @@ HypoIds = torch.Tensor

 class DecodingAlgorithm(ABC):
     """
-    An abstract class for decoding algorithms. Describe base function of those algorithms: they have to select new tokens and provide the corresponding hypothesis.
+    An abstract class for decoding algorithms. Describes the base function of those algorithms:
+    they have to select new tokens and provide the corresponding hypotheses.
     """

-    def __init__(self) -> None:
-        pass
-
+    @abstractmethod
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
         :param logits: A tensor of shape (batch_size, seq_lenth, vocab_size)
-        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size)
+        :return: A tuple of selected token ids and corresponding hypotheses.
+        The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
         """
         pass
@@ -30,27 +30,36 @@ class GreedyAlgorithm(DecodingAlgorithm):

     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
-        Returns the most propable token. The second return object always are range of integers from 0 to batch_size - 1.
+        Returns the most probable token. The second returned object is always a range of integers
+        from 0 to batch_size - 1.
         """
         return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))


 class SamplingAlgorithm(DecodingAlgorithm):
+    def __init__(self, temperature: float = 1.0):
+        self.temperature = temperature
+
     def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
         :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
         :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
-        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size).
+        :return: A tuple of selected token ids and corresponding hypotheses.
+        The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size).
         """
         logits[indices_to_remove] = -float("Inf")
         probs = torch.softmax(logits / self.temperature, -1)
         return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))

+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
+        return self.sample(logits, indices_to_remove)
+

 class TopKAlgorithm(SamplingAlgorithm):
     def __init__(self, top_k: int, temperature: float = 1.0) -> None:
+        super().__init__(temperature=temperature)
         self.top_k = top_k
-        self.temperature = temperature

     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
@@ -59,18 +68,17 @@ class TopKAlgorithm(SamplingAlgorithm):

 class NucleusAlgorithm(SamplingAlgorithm):
     def __init__(self, top_p: float, temperature: float = 1.0) -> None:
+        super().__init__(temperature=temperature)
         self.top_p = top_p
-        self.temperature = temperature

     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
         probs = torch.softmax(sorted_logits / self.temperature, -1)
         cumulative_probs = torch.cumsum(probs, dim=-1)
-        sorted_indices_to_remove = cumulative_probs > self.top_p
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = False
-        indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
-        indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
+
+        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
+
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
         return self.sample(logits, indices_to_remove)
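
Note (illustration, not part of the patch): the NucleusAlgorithm rewrite above replaces the descending-sort formulation (remove once cumulative mass exceeds top_p, with the mask shifted by one so the most probable token always survives) with an ascending sort that removes the low-probability tail whose cumulative mass is at most 1 - top_p. The two masks coincide except for exact ties in cumulative mass at the top_p boundary, which occur with probability zero for continuous logits. A standalone sketch checking that equivalence in plain PyTorch, with temperature fixed at 1:

    import torch

    torch.manual_seed(0)
    logits = torch.randn(2, 10)
    top_p = 0.9

    # Ascending variant used in this patch: drop the tail with mass <= 1 - top_p.
    asc_logits, asc_idx = torch.sort(logits, descending=False, dim=-1)
    asc_cum = torch.cumsum(torch.softmax(asc_logits, dim=-1), dim=-1)
    asc_mask_sorted = asc_cum <= (1 - top_p)
    asc_mask = asc_mask_sorted.scatter(1, asc_idx, asc_mask_sorted)

    # Classic descending variant (as previously used here): remove once the
    # cumulative mass exceeds top_p, shifted so the most probable token stays.
    desc_logits, desc_idx = torch.sort(logits, descending=True, dim=-1)
    desc_cum = torch.cumsum(torch.softmax(desc_logits, dim=-1), dim=-1)
    desc_mask_sorted = desc_cum > top_p
    desc_mask_sorted[..., 1:] = desc_mask_sorted[..., :-1].clone()
    desc_mask_sorted[..., 0] = False
    desc_mask = torch.zeros_like(desc_mask_sorted).scatter(-1, desc_idx, desc_mask_sorted)

    # Equal up to boundary ties, which are measure-zero for random logits.
    assert torch.equal(asc_mask, desc_mask)
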
""" logits[indices_to_remove] = -float("Inf") probs = torch.softmax(logits / self.temperature, -1) return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0)) + def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: + indices_to_remove = torch.full_like(logits, False, dtype=torch.bool) + return self.sample(logits, indices_to_remove) + class TopKAlgorithm(SamplingAlgorithm): def __init__(self, top_k: int, temperature: float = 1.0) -> None: + super().__init__(temperature=temperature) self.top_k = top_k - self.temperature = temperature def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None] @@ -59,18 +68,17 @@ class TopKAlgorithm(SamplingAlgorithm): class NucleusAlgorithm(SamplingAlgorithm): def __init__(self, top_p: float, temperature: float = 1.0) -> None: + super().__init__(temperature=temperature) self.top_p = top_p - self.temperature = temperature def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1) probs = torch.softmax(sorted_logits / self.temperature, -1) cumulative_probs = torch.cumsum(probs, dim=-1) - sorted_indices_to_remove = cumulative_probs > self.top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = False - indices_to_remove = torch.zeros_like(sorted_indices_to_remove) - indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove) + + sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p) + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) return self.sample(logits, indices_to_remove) diff --git a/tests/test_full_model.py b/tests/test_full_model.py index 710ff33..e3c4730 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -83,7 +83,7 @@ def test_greedy_generation(max_new_tokens=4): max_new_tokens=max_new_tokens, ) hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens) - assert torch.allclose(remote_outputs, hf_outputs), "Greedy search are not identical to HF" + assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF" inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[ "input_ids" @@ -97,7 +97,53 @@ def test_greedy_generation(max_new_tokens=4): ) assert torch.allclose( remote_outputs_batch, hf_outputs_batch - ), "Greedy search are not identical to HF in multibatch mode" + ), "Greedy search results are not identical to HF in multibatch mode" + + +@pytest.mark.forked +@pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)]) +@pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers") +def test_sampling(sampling_options, max_new_tokens=4): + torch.manual_seed(0) + tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME) + model = DistributedBloomForCausalLM.from_pretrained( + MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32 + ) + logits_warper = BloomForCausalLM._get_logits_warper(model, num_beams=1, **sampling_options) + inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] + with torch.random.fork_rng(): + 
+        remote_outputs = model.generate(
+            inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            **sampling_options,
+        )
+    with torch.random.fork_rng():
+        hf_outputs = BloomForCausalLM.sample(
+            model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
+        )
+    assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"
+
+    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
+        "input_ids"
+    ]
+    with torch.random.fork_rng():
+        remote_outputs_batch = model.generate(
+            inputs_batch,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            **sampling_options,
+        )
+    with torch.random.fork_rng():
+        hf_outputs_batch = BloomForCausalLM.sample(
+            model,
+            input_ids=inputs_batch,
+            max_length=inputs_batch.size(1) + max_new_tokens,
+            logits_warper=logits_warper,
+        )
+    assert torch.allclose(
+        remote_outputs_batch, hf_outputs_batch
+    ), "Sampling results are not identical to HF in multibatch mode"


 @pytest.mark.forked
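
Note (illustration, not part of the patch): before this change, _choose_sample_algorithm had no else branch, so do_sample=True without top_k or top_p implicitly returned None; it now falls back to plain multinomial sampling. A hypothetical usage sketch of the new behavior (the model id, tokenizer class, and import path are placeholders that may differ between petals versions):

    import transformers
    from petals import DistributedBloomForCausalLM  # import path may vary by version

    MODEL_NAME = "bigscience/bloom-petals"  # placeholder model id
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)

    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    # With do_sample=True and neither top_k nor top_p, generation now uses
    # SamplingAlgorithm(temperature), i.e. pure multinomial sampling.
    outputs = model.generate(inputs, max_new_tokens=4, do_sample=True, temperature=0.8)
    print(tokenizer.decode(outputs[0]))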