Add missing methods for SamplingAlgorithm, fix docstrings (#107)

* Add missing methods for SamplingAlgorithm, fix docstrings

* Add SamplingAlgorithm to _choose_sample_algorithm

* Add test_sampling

* Add a warning if sampling options were passed, but do_sample=False

* Skip the sampling test for now

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Max Ryabinin committed by GitHub
parent a0e8bbd28d
commit bd91be27ea

@@ -10,6 +10,7 @@ from petals.utils.generation_algorithms import (
DecodingAlgorithm,
GreedyAlgorithm,
NucleusAlgorithm,
SamplingAlgorithm,
TopKAlgorithm,
)
from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
@@ -22,7 +23,7 @@ class RemoteGenerationMixin:
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
The class can be used for:
- *greedy decoding*.
- *multinomial sampling*.
- *multinomial, top-k and top-p sampling*.
- *beam-search decoding*
This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used instead of it.
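A minimal usage sketch (illustration only, not part of the diff) of the decoding modes listed above. It assumes `model` is a DistributedBloomForCausalLM and `tokenizer` is a BloomTokenizerFast, constructed as in the tests further below.

# Illustrative only: the decoding mode is chosen through generate() kwargs.
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
greedy_out = model.generate(inputs, max_new_tokens=4)                            # greedy decoding
sample_out = model.generate(inputs, max_new_tokens=4, do_sample=True)            # multinomial sampling
top_k_out = model.generate(inputs, max_new_tokens=4, do_sample=True, top_k=5)    # top-k sampling
top_p_out = model.generate(inputs, max_new_tokens=4, do_sample=True, top_p=0.9)  # nucleus (top-p) sampling
beam_out = model.generate(inputs, max_new_tokens=4, num_beams=2)                 # beam-search decoding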
@@ -126,6 +127,8 @@ class RemoteGenerationMixin:
elif num_beams is not None and num_beams > 1:
decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
else:
if top_k is not None or top_p is not None:
logger.warning("You passed top_k or top_p but did not pass do_sample=True. Running greedy decoding")
decoding_algorithm = GreedyAlgorithm()
if num_beams > 1:
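A small sketch (illustration only, not part of the diff) of the fallback added above, reusing `model` and `inputs` from the earlier sketch: sampling options without do_sample=True now log a warning and decode greedily.

# Illustrative only: do_sample is left unset, so the warning above fires
# and GreedyAlgorithm is used; the result matches a plain greedy call.
out_with_warning = model.generate(inputs, max_new_tokens=4, top_k=5)
out_greedy = model.generate(inputs, max_new_tokens=4)
# Both calls produce the same greedy continuation; only the first logs the warning.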
@@ -252,7 +255,8 @@ class RemoteGenerationMixin:
**model_kwargs,
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head. Uses sampling. Uses multinomial sampling algorithm. If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
:param: input_ids: The input tokens to the model.
:param: temperature: The temperature to use for sampling.
@@ -341,10 +345,12 @@
) -> DecodingAlgorithm:
if (top_k is not None) and (top_p is not None):
raise ValueError("You have to provide only top_k or top_p for sampling")
if top_k:
if top_k is not None:
return TopKAlgorithm(top_k, temperature)
elif top_p:
elif top_p is not None:
return NucleusAlgorithm(top_p, temperature)
else:
return SamplingAlgorithm(temperature)
def _get_constraints(
self,
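A short sketch (illustration only, not part of the diff) of what the updated _choose_sample_algorithm returns for each combination of sampling kwargs, assuming do_sample=True and num_beams=1.

# Illustrative only: the constructor calls mirror the branches above.
from petals.utils.generation_algorithms import NucleusAlgorithm, SamplingAlgorithm, TopKAlgorithm

plain = SamplingAlgorithm(temperature=0.8)           # neither top_k nor top_p given (new branch)
with_top_k = TopKAlgorithm(5, temperature=0.8)       # top_k=5
with_top_p = NucleusAlgorithm(0.9, temperature=0.8)  # top_p=0.9
# Passing both top_k and top_p raises ValueError, as in the first branch above.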

@@ -1,4 +1,4 @@
from abc import ABC
from abc import ABC, abstractmethod
from typing import Tuple
import torch
@@ -9,16 +9,16 @@ HypoIds = torch.Tensor
class DecodingAlgorithm(ABC):
"""
An abstract class for decoding algorithms. Describe base function of those algorithms: they have to select new tokens and provide the corresponding hypothesis.
An abstract class for decoding algorithms. Describes the base function of those algorithms:
they have to select new tokens and provide the corresponding hypotheses.
"""
def __init__(self) -> None:
pass
@abstractmethod
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
"""
:param logits: A tensor of shape (batch_size, seq_length, vocab_size)
:return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size)
:return: A tuple of selected token ids and corresponding hypotheses.
The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
"""
pass
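A quick illustration (not part of the diff) of the effect of marking __call__ as @abstractmethod: the base class can no longer be instantiated, only concrete subclasses can.

# Illustrative only.
from petals.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm

GreedyAlgorithm()        # concrete subclass with its own __call__: fine
try:
    DecodingAlgorithm()  # abstract __call__ is not implemented
except TypeError as e:
    print(e)             # "Can't instantiate abstract class DecodingAlgorithm ..."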
@@ -30,27 +30,36 @@ class GreedyAlgorithm(DecodingAlgorithm):
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
"""
Returns the most propable token. The second return object always are range of integers from 0 to batch_size - 1.
Returns the most probable token. The second returned object is always a range of integers
from 0 to batch_size - 1.
"""
return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))
class SamplingAlgorithm(DecodingAlgorithm):
def __init__(self, temperature: float = 1.0):
self.temperature = temperature
def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
"""
:param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
:param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
:return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size).
:return: A tuple of selected token ids and corresponding hypotheses.
The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size).
"""
logits[indices_to_remove] = -float("Inf")
probs = torch.softmax(logits / self.temperature, -1)
return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
return self.sample(logits, indices_to_remove)
class TopKAlgorithm(SamplingAlgorithm):
def __init__(self, top_k: int, temperature: float = 1.0) -> None:
super().__init__(temperature=temperature)
self.top_k = top_k
self.temperature = temperature
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
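A tiny worked example (illustration only, not part of the diff) of the sampling classes above, on a batch of two hypotheses over a four-token vocabulary.

# Illustrative only: shapes and behaviour of SamplingAlgorithm / TopKAlgorithm.
import torch
from petals.utils.generation_algorithms import SamplingAlgorithm, TopKAlgorithm

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0, 0.3]])

token_ids, hypo_ids = SamplingAlgorithm(temperature=1.0)(logits)
# token_ids: shape (2, 1), one token drawn per row from softmax(logits / temperature)
# hypo_ids:  tensor([0, 1]), mapping each drawn token back to its hypothesis

token_ids, hypo_ids = TopKAlgorithm(top_k=2, temperature=1.0)(logits)
# Everything but the two largest logits per row is masked to -inf before sampling,
# so row 0 can only return token 0 or 1, and row 1 only token 2 or 3.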
@@ -59,18 +68,17 @@ class TopKAlgorithm(SamplingAlgorithm):
class NucleusAlgorithm(SamplingAlgorithm):
def __init__(self, top_p: float, temperature: float = 1.0) -> None:
super().__init__(temperature=temperature)
self.top_p = top_p
self.temperature = temperature
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
probs = torch.softmax(sorted_logits / self.temperature, -1)
cumulative_probs = torch.cumsum(probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > self.top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False
indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
return self.sample(logits, indices_to_remove)
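A worked example (illustration only, not part of the diff) of the reworked top-p mask above, on a single distribution with top_p=0.9 and temperature 1.0. Sorting in ascending order and dropping every token whose cumulative mass is at most 1 - top_p removes the low-probability tail and keeps a head of the distribution carrying at least top_p of the total mass.

# Illustrative only: the steps mirror NucleusAlgorithm.__call__ above (temperature = 1.0).
import torch

top_p = 0.9
logits = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))
sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
probs = torch.softmax(sorted_logits, -1)                    # [0.05, 0.15, 0.30, 0.50]
cumulative_probs = torch.cumsum(probs, dim=-1)              # [0.05, 0.20, 0.50, 1.00]
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)  # [True, False, False, False]
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
# Only the 0.05 token is masked; the kept tokens carry 0.95 >= 0.9 of the probability mass.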

@@ -83,7 +83,7 @@ def test_greedy_generation(max_new_tokens=4):
max_new_tokens=max_new_tokens,
)
hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
assert torch.allclose(remote_outputs, hf_outputs), "Greedy search are not identical to HF"
assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
"input_ids"
@@ -97,7 +97,53 @@ def test_greedy_generation(max_new_tokens=4):
)
assert torch.allclose(
remote_outputs_batch, hf_outputs_batch
), "Greedy search are not identical to HF in multibatch mode"
), "Greedy search results are not identical to HF in multibatch mode"
@pytest.mark.forked
@pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
@pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
def test_sampling(sampling_options, max_new_tokens=4):
torch.manual_seed(0)
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
logits_warper = BloomForCausalLM._get_logits_warper(model, num_beams=1, **sampling_options)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
with torch.random.fork_rng():
remote_outputs = model.generate(
inputs,
max_new_tokens=max_new_tokens,
do_sample=True,
**sampling_options,
)
with torch.random.fork_rng():
hf_outputs = BloomForCausalLM.sample(
model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
)
assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"
inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
"input_ids"
]
with torch.random.fork_rng():
remote_outputs_batch = model.generate(
inputs_batch,
max_new_tokens=max_new_tokens,
do_sample=True,
**sampling_options,
)
with torch.random.fork_rng():
hf_outputs_batch = BloomForCausalLM.sample(
model,
input_ids=inputs_batch,
max_length=inputs_batch.size(1) + max_new_tokens,
logits_warper=logits_warper,
)
assert torch.allclose(
remote_outputs_batch, hf_outputs_batch
), "Sampling results are not identical to HF in multibatch mode"
@pytest.mark.forked

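A short sketch (illustration only, not part of the diff) of the torch.random.fork_rng pattern used in test_sampling above: fork_rng restores the RNG state on exit, so the remote and the Hugging Face generation calls start from identical random state and their sampled outputs can be compared directly.

# Illustrative only: both multinomial calls see the same RNG state.
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
with torch.random.fork_rng():
    first = torch.multinomial(probs, num_samples=5, replacement=True)
with torch.random.fork_rng():
    second = torch.multinomial(probs, num_samples=5, replacement=True)
assert torch.equal(first, second)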