Make client compatible with transformers' GenerationMixin (#464)
This PR drops custom generation code and introduces compatibility with `transformers.GenerationMixin` instead. This includes support for more sampling options (`top_p`, `top_k`, `repetition_penalty` requested in #460) and beam search - all that is now identical to running the model with transformers locally. Most features (excluding beam search and other rarely used stuff) are also compatible with resuming existing sessions. ### Breaking changes If `.generate()` or forward passes are being run inside an `.inference_session()` context, they now use the opened session by default. So, these snippets are now equivalent: ```python # Using default session with model.inference_session(max_length=100): output_ids = model.generate(input_ids, max_new_tokens=3) # Explicitly specifying a session with model.inference_session(max_length=100) as sess: output_ids = model.generate(input_ids, max_new_tokens=3, session=sess) ``` Earlier, the 1st snippet was creating a new session, which is not what most people expected (= such code was most likely to introduce a bug, which is now fixed).
parent
063e94b4c8
commit
de2475f31c
@ -1,349 +1,142 @@
|
||||
import contextlib
|
||||
from typing import List, Optional
|
||||
import dataclasses
|
||||
from contextvars import ContextVar
|
||||
from typing import ContextManager, List, Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from hivemind.utils.logging import get_logger
|
||||
from transformers.generation.utils import ModelOutput
|
||||
|
||||
from petals.client.inference_session import InferenceSession
|
||||
from petals.utils.generation_algorithms import (
|
||||
BeamSearchAlgorithm,
|
||||
DecodingAlgorithm,
|
||||
GreedyAlgorithm,
|
||||
NucleusAlgorithm,
|
||||
SamplingAlgorithm,
|
||||
TopKAlgorithm,
|
||||
)
|
||||
from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
|
||||
from petals.client.remote_sequential import RemoteSequential
|
||||
from petals.utils.misc import DUMMY, docstring_from
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RemoteGenerationMixin:
|
||||
"""
|
||||
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
|
||||
The class exposes can be used for:
|
||||
- *greedy decoding*.
|
||||
- *multinomial, top-k and top-p sampling*.
|
||||
- *beam-search decoding*
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class RemotePastKeyValues:
|
||||
"""A mock class representing the fact that `past_key_values` do exist but are stored on remote servers."""
|
||||
|
||||
This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it.
|
||||
However, it has some differences for remote usage.
|
||||
"""
|
||||
hypo_ids: Optional[torch.LongTensor] = None
|
||||
|
||||
def inference_session(self, **kwargs) -> InferenceSession:
|
||||
"""
|
||||
Returns an inference session for the model's RemoteSequential module.
|
||||
def __getitem__(self, _index: int) -> List[torch.Tensor]:
|
||||
return [DUMMY] # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
|
||||
|
||||
:param max_length: Maximal expected length of inference results. Servers use this parameter
|
||||
to calculate the size of attention caches allocated to this client.
|
||||
"""
|
||||
|
||||
return self.transformer.h.inference_session(**kwargs)
|
||||
_skipped_tokens = ContextVar("skipped_tokens", default=0)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
*,
|
||||
do_sample: Optional[bool] = None,
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
num_beams: Optional[int] = 1,
|
||||
bos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
decoding_algorithm: Optional[DecodingAlgorithm] = None,
|
||||
provided_constraints: List[ABCBloomConstraint] = [],
|
||||
num_return_sequences: Optional[int] = None,
|
||||
session: Optional[InferenceSession] = None,
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Generates sequences of token ids for models with a language modeling head.
|
||||
|
||||
:param inputs: The input tokens to the model.
|
||||
:param do_sample: Whether to sample from the model predictions or take the argmax.
|
||||
:param temperature: The temperature to use for sampling.
|
||||
:param top_k: The number of results to return.
|
||||
:param top_p: The cumulative probability of results to return.
|
||||
:param num_beams: The number of beams to use for beam search.
|
||||
:param bos_token_id: The id of the beginning of sentence token.
|
||||
:param eos_token_id: The id of the end of sentence token.
|
||||
:param pad_token_id: The id of the padding token.
|
||||
:param max_length: The maximum number of tokens in the output (including input tokens).
|
||||
:param max_new_tokens: The maximum number of tokens to generate.
|
||||
:param decoding_algorithm: The decoding algorithm to use.
|
||||
:param provided_constraints: A list of constraints to use.
|
||||
:param num_return_sequences: How many hypothesis from the beam will be in output.
|
||||
"""
|
||||
class _SkipTokensMixin:
|
||||
# This override is used in RemoteGenerationMixin by has to be defined in a class not named as "GenerationMixin"
|
||||
# due to how transformers.PreTrainedModel.can_generate() works
|
||||
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
|
||||
input_ids = input_ids[:, _skipped_tokens.get() :]
|
||||
_skipped_tokens.set(0)
|
||||
return super().prepare_inputs_for_generation(input_ids, **kwargs)
|
||||
|
||||
prefix_length = 0 if inputs is None else inputs.size(1)
|
||||
prefix_length += self.config.pre_seq_len
|
||||
|
||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
class RemoteGenerationMixin(_SkipTokensMixin):
|
||||
"""
|
||||
This class is an upgrade to `transformers.GenerationMixin` that:
|
||||
|
||||
- Designed to be compatible with most `transformers.GenerationMixin` strategies and options
|
||||
- Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and
|
||||
you don't have to rerun the prefix through all the servers to generate each new token
|
||||
- Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation
|
||||
by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or
|
||||
accept prompts from a user in a chat bot (multiple calls like `.generate(new_prompts, ...)`).
|
||||
- If there is no active session, `.generate()` will create a new InferenceSession with proper `max_length`.
|
||||
Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that.
|
||||
"""
|
||||
|
||||
assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
|
||||
if max_length is not None and max_new_tokens is None:
|
||||
max_new_tokens = max_length - prefix_length
|
||||
assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
|
||||
elif max_length is None and max_new_tokens is not None:
|
||||
max_length = prefix_length + max_new_tokens
|
||||
@docstring_from(RemoteSequential.active_session)
|
||||
@property
|
||||
def active_session(self) -> Optional[InferenceSession]:
|
||||
return self.transformer.h.active_session
|
||||
|
||||
resuming_session = session is not None and session.last_token_id is not None
|
||||
if num_beams > 1 and resuming_session:
|
||||
raise NotImplementedError(
|
||||
"Resuming inference session in .generate() along with beam search is not supported yet"
|
||||
)
|
||||
@docstring_from(RemoteSequential.use_session)
|
||||
def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
|
||||
return self.transformer.h.use_session(session)
|
||||
|
||||
if inputs is not None:
|
||||
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
|
||||
if resuming_session:
|
||||
inputs = torch.cat([session.last_token_id, inputs], dim=1)
|
||||
else:
|
||||
if resuming_session:
|
||||
inputs = session.last_token_id
|
||||
else:
|
||||
assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
|
||||
inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
|
||||
batch_size = inputs.size(0)
|
||||
@docstring_from(RemoteSequential.inference_session)
|
||||
def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
|
||||
return self.transformer.h.inference_session(**kwargs)
|
||||
|
||||
if decoding_algorithm is None:
|
||||
if do_sample:
|
||||
decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
|
||||
elif num_beams is not None and num_beams > 1:
|
||||
decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
|
||||
@docstring_from(transformers.GenerationMixin.generate.__doc__)
|
||||
def generate(
|
||||
self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
|
||||
):
|
||||
self._fix_generate_kwargs(kwargs)
|
||||
|
||||
if session is not None:
|
||||
# If a session specified explicitly, use it
|
||||
context_manager = self.use_session(session)
|
||||
elif self.active_session is not None:
|
||||
# If there's an active session, don't do anything
|
||||
context_manager = contextlib.nullcontext(self.active_session)
|
||||
else:
|
||||
if top_k is not None or top_p is not None:
|
||||
logger.warning("You passed top_k or top_p but did not pass do_sample=True. Running greedy sampling")
|
||||
decoding_algorithm = GreedyAlgorithm()
|
||||
|
||||
if num_beams > 1:
|
||||
inputs = torch.cat([inputs] * num_beams, dim=0)
|
||||
if batch_size > 1:
|
||||
# TODO: resolve padding problem
|
||||
logger.warning(
|
||||
f"You set batch_size {batch_size} within beam search generation. "
|
||||
f"Be careful, results on sequences with different length may be padded wrong way"
|
||||
)
|
||||
# If there's no active session, create a new one
|
||||
|
||||
if num_return_sequences is None:
|
||||
num_return_sequences = 1
|
||||
max_length = kwargs.get("max_length")
|
||||
max_new_tokens = kwargs.get("max_new_tokens")
|
||||
assert (max_length is None) != (
|
||||
max_new_tokens is None
|
||||
), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
|
||||
|
||||
assert num_return_sequences <= num_beams, (
|
||||
f"You want more sequences than the beam has."
|
||||
" Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
|
||||
)
|
||||
|
||||
constraints = self._get_constraints(
|
||||
inputs=inputs,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
provided_constraints=provided_constraints,
|
||||
)
|
||||
|
||||
if session is None:
|
||||
context_manager = self.inference_session(max_length=max_length)
|
||||
else:
|
||||
context_manager = contextlib.nullcontext(session) # Doesn't actually enter session or exit from it
|
||||
with context_manager as session:
|
||||
outputs = []
|
||||
# Find samples with padded inputs.
|
||||
# They will be changed before all of the samples have right length.
|
||||
if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs
|
||||
outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
|
||||
if max_length is not None:
|
||||
session_max_length = max_length
|
||||
else:
|
||||
outputs += [inputs]
|
||||
last_token_id = None
|
||||
seq_idx = outputs[0].size(1)
|
||||
hypo_ids = torch.arange(outputs[0].size(0))
|
||||
while True:
|
||||
hidden_state = self.transformer.word_embeddings(outputs[-1])
|
||||
intermediate_prompts = None
|
||||
if self.config.pre_seq_len > 0 and len(outputs) == 1:
|
||||
prompts, intermediate_prompts = self.transformer.get_prompt(hidden_state.size(0))
|
||||
hidden_state = torch.cat([prompts, hidden_state], dim=1)
|
||||
hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
|
||||
session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
|
||||
context_manager = self.inference_session(max_length=session_max_length)
|
||||
|
||||
hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
|
||||
|
||||
hidden_state = self.transformer.ln_f(hidden_state)
|
||||
lm_logits = self.lm_head(hidden_state)
|
||||
|
||||
for constraint in constraints:
|
||||
lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
|
||||
last_token_id, hypo_ids = decoding_algorithm(lm_logits)
|
||||
|
||||
# If some samples were padded, change only these samples
|
||||
if seq_idx < inputs.size(1):
|
||||
pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
|
||||
last_token_id = (~pad_token_mask) * inputs[
|
||||
:, seq_idx : seq_idx + 1
|
||||
] + pad_token_mask * last_token_id
|
||||
|
||||
# TODO: refactor outputs
|
||||
if num_beams > 1:
|
||||
for i in range(len(outputs), 1, -1):
|
||||
outputs[i - 1] = outputs[i - 1][hypo_ids]
|
||||
|
||||
outputs.append(last_token_id)
|
||||
session.last_token_id = last_token_id
|
||||
seq_idx += 1
|
||||
if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
|
||||
break
|
||||
|
||||
outputs = torch.cat(outputs, dim=-1)
|
||||
|
||||
if resuming_session:
|
||||
outputs = outputs[:, 1:]
|
||||
if num_beams > 1:
|
||||
pre_return_idx = [
|
||||
torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
|
||||
]
|
||||
return_idx = torch.cat(pre_return_idx, dim=0)
|
||||
outputs = outputs[return_idx]
|
||||
|
||||
return outputs
|
||||
|
||||
def greedy_search(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
provided_constraints: List[ABCBloomConstraint] = [],
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Generates sequences of token ids for models with a language modeling head. Uses greedy search.
|
||||
|
||||
:param input_ids: The input tokens to the model.
|
||||
:param max_length: The maximum length of the sequence to generate.
|
||||
:param pad_token_id: The id of the padding token.
|
||||
:param eos_token_id: The id of the end of sentence token.
|
||||
:param provided_constraints: A list of constraints to use.
|
||||
"""
|
||||
return self.generate(
|
||||
inputs=input_ids,
|
||||
max_new_tokens=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoding_algorithm=GreedyAlgorithm(),
|
||||
provided_constraints=provided_constraints,
|
||||
with context_manager as session:
|
||||
# Prepend the tokens from the previous .generate() call
|
||||
n_prev_tokens = session.output_ids.shape[1] if session.output_ids is not None else 0
|
||||
if n_prev_tokens > 0:
|
||||
if kwargs.get("num_beams", 1) > 1:
|
||||
logger.warning(
|
||||
"Beam search will not work properly in the resumed petals.InferenceSession "
|
||||
"since intermediate beam entries are lost"
|
||||
)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
provided_constraints: List[ABCBloomConstraint] = [],
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
|
||||
If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
|
||||
if inputs is not None:
|
||||
inputs = torch.cat([session.output_ids, inputs], dim=1)
|
||||
else:
|
||||
inputs = session.output_ids
|
||||
|
||||
:param: input_ids: The input tokens to the model.
|
||||
:param: temperature: The temperature to use for sampling.
|
||||
:param: top_k: The number of samples to use for top_k sampling.
|
||||
:param: top_p: The probability of using top_p sampling.
|
||||
:param: max_length: The maximum length of the sequence to generate.
|
||||
:param: pad_token_id: The id of the padding token.
|
||||
:param: eos_token_id: The id of the end of sentence token.
|
||||
:param: provided_constraints: A list of constraints to use.
|
||||
"""
|
||||
# Don't actually run all previous tokens through the transformer,
|
||||
# but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
|
||||
_skipped_tokens.set(max(0, n_prev_tokens - 1))
|
||||
|
||||
return self.generate(
|
||||
inputs=input_ids,
|
||||
max_new_tokens=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
|
||||
provided_constraints=provided_constraints,
|
||||
)
|
||||
result = super().generate(inputs, *args, **kwargs)
|
||||
|
||||
def beam_search(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
num_beams: int = 1,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
provided_constraints: List[ABCBloomConstraint] = [],
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Generates sequences of token ids for models with a language modeling head. Uses beam search.
|
||||
sequences = result.sequences if isinstance(result, ModelOutput) else result
|
||||
# Save tokens from this .generate() call
|
||||
session.output_ids = sequences
|
||||
# Crop the last tokens from the previous call
|
||||
sequences = sequences[:, n_prev_tokens:].clone()
|
||||
if isinstance(result, ModelOutput):
|
||||
result.sequences = sequences
|
||||
else:
|
||||
result = sequences
|
||||
|
||||
:param input_ids: The input tokens to the model.
|
||||
:param num_beams: The number of beams to use.
|
||||
:param max_length: The maximum length of the sequence to generate.
|
||||
:param pad_token_id: The id of the padding token.
|
||||
:param eos_token_id: The id of the end of sentence token.
|
||||
:param provided_constraints: A list of constraints to use.
|
||||
"""
|
||||
decoding_algorithm = BeamSearchAlgorithm(
|
||||
num_beams=num_beams,
|
||||
batch_size=input_ids.size(0),
|
||||
)
|
||||
return self.generate(
|
||||
inputs=input_ids,
|
||||
num_beams=num_beams,
|
||||
max_new_tokens=max_length,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoding_algorithm=decoding_algorithm,
|
||||
provided_constraints=provided_constraints,
|
||||
)
|
||||
return result
|
||||
|
||||
def beam_sample(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
provided_constraints: List[ABCBloomConstraint] = [],
|
||||
) -> torch.LongTensor:
|
||||
raise NotImplementedError
|
||||
@staticmethod
|
||||
def _fix_generate_kwargs(kwargs: dict) -> dict:
|
||||
# Suppress inappropriate "Both max_new_tokens and max_length" HF warning
|
||||
if "max_length" in kwargs and kwargs["max_length"] is None:
|
||||
del kwargs["max_length"]
|
||||
|
||||
def group_beam_search(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
provided_constraints: List[ABCBloomConstraint] = [],
|
||||
) -> torch.LongTensor:
|
||||
raise NotImplementedError
|
||||
# Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0
|
||||
do_sample = kwargs.get("do_sample")
|
||||
if isinstance(do_sample, int):
|
||||
kwargs["do_sample"] = bool(do_sample)
|
||||
|
||||
def _choose_sample_algorithm(
|
||||
self,
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
) -> DecodingAlgorithm:
|
||||
if (top_k is not None) and (top_p is not None):
|
||||
raise ValueError("You have to provide only top_k or top_p for sampling")
|
||||
if top_k is not None:
|
||||
return TopKAlgorithm(top_k, temperature)
|
||||
elif top_p is not None:
|
||||
return NucleusAlgorithm(top_p, temperature)
|
||||
else:
|
||||
return SamplingAlgorithm(temperature)
|
||||
return kwargs
|
||||
|
||||
def _get_constraints(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
provided_constraints: List[ABCBloomConstraint] = [],
|
||||
) -> List[ABCBloomConstraint]:
|
||||
constraints = []
|
||||
constraints.extend(provided_constraints)
|
||||
constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
|
||||
return constraints
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
|
||||
return dataclasses.replace(past_key_values, hypo_ids=beam_idx)
|
||||
|
@ -1,128 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
# Readability aliases for the decoding interfaces below; both are plain tensors.
# TokenIds: selected token ids, shape (batch_size, 1) per step.
# HypoIds: index of the hypothesis each output row continues from, shape (batch_size,).
TokenIds = torch.Tensor
HypoIds = torch.Tensor
|
||||
|
||||
|
||||
class DecodingAlgorithm(ABC):
    """
    Interface for token-selection strategies.

    A concrete algorithm receives next-step logits and decides both which token
    each output row emits and which existing hypothesis that row continues from.
    """

    @abstractmethod
    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        """
        Select the next tokens.

        :param logits: A tensor of shape (batch_size, seq_length, vocab_size)
        :return: A tuple of selected token ids and corresponding hypotheses.
            The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
        """
|
||||
|
||||
|
||||
class GreedyAlgorithm(DecodingAlgorithm):
    """
    The simplest decoding strategy: always emit the most probable token.
    """

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        """
        Pick the argmax token for every row. The hypothesis ids are always the
        identity mapping 0..batch_size-1, since greedy decoding never reorders rows.
        """
        _, best_token = logits.max(-1)
        identity_hypos = torch.arange(logits.size(0))
        return best_token.unsqueeze(1), identity_hypos
|
||||
|
||||
|
||||
class SamplingAlgorithm(DecodingAlgorithm):
    """Multinomial sampling over the temperature-scaled softmax distribution."""

    def __init__(self, temperature: float = 1.0):
        # Higher temperature flattens the distribution; 1.0 leaves it unchanged.
        self.temperature = temperature

    def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        """
        Sample one token per row after masking out banned positions.

        :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
        :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
        :return: A tuple of selected token ids and corresponding hypotheses.
            The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size).
        """
        # Banned positions get -inf so softmax assigns them zero probability.
        logits.masked_fill_(indices_to_remove, -float("Inf"))
        probs = torch.softmax(logits / self.temperature, -1)
        chosen = torch.multinomial(probs, num_samples=1)
        return chosen, torch.arange(logits.size(0))

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        # Plain sampling bans nothing: pass an all-False mask through.
        nothing_banned = torch.zeros_like(logits, dtype=torch.bool)
        return self.sample(logits, nothing_banned)
|
||||
|
||||
|
||||
class TopKAlgorithm(SamplingAlgorithm):
    """Sampling restricted to the k highest-scoring tokens of each row."""

    def __init__(self, top_k: int, temperature: float = 1.0) -> None:
        super().__init__(temperature=temperature)
        self.top_k = top_k

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        # The k-th largest logit per row is the admission threshold;
        # everything strictly below it is banned before sampling.
        kth_best = torch.topk(logits, self.top_k, dim=-1).values[..., -1, None]
        return self.sample(logits, logits < kth_best)
|
||||
|
||||
|
||||
class NucleusAlgorithm(SamplingAlgorithm):
    """Nucleus (top-p) sampling: keep only the most likely tokens whose total mass reaches top_p."""

    def __init__(self, top_p: float, temperature: float = 1.0) -> None:
        super().__init__(temperature=temperature)
        self.top_p = top_p

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        # Sort ascending so the cumulative sum covers the least likely tokens first:
        # every token whose running mass stays within (1 - top_p) gets dropped.
        asc_logits, asc_indices = torch.sort(logits, descending=False, dim=-1)
        asc_probs = torch.softmax(asc_logits / self.temperature, -1)
        running_mass = torch.cumsum(asc_probs, dim=-1)

        drop_sorted = running_mass <= 1 - self.top_p

        # Map per-rank drop decisions back to their vocabulary positions.
        indices_to_remove = drop_sorted.scatter(1, asc_indices, drop_sorted)
        return self.sample(logits, indices_to_remove)
|
||||
|
||||
|
||||
class BeamSearchAlgorithm(DecodingAlgorithm):
    """
    Stateful beam search over `num_beams` hypotheses per sample.

    Beams are stored per sample as (cumulative log-prob, flat_index) tuples, where
    flat_index = parent_beam * num_beams + rank_of_token within the parent's sorted logits.
    Rows of the incoming logits are laid out batch-major: row = batch_idx + beam_idx * batch_size.
    """

    def __init__(self, num_beams: int, batch_size: int) -> None:
        self.num_beams = num_beams
        self.batch_size = batch_size

        # One beam list per sample; empty until the first __call__.
        self._batch_beams = [list() for _ in range(batch_size)]

    def __call__(self, logits: torch.Tensor):
        # Sort each row's logits descending so column j is the j-th best token of that row.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        # Log-probabilities so beam scores can be accumulated by addition.
        probs = torch.log_softmax(sorted_logits, -1)

        if len(self._batch_beams[0]) > 0:
            # Subsequent steps: expand each surviving beam with its num_beams best
            # continuations, then keep the overall num_beams best per sample.
            for batch_idx in range(self.batch_size):
                new_beams = []
                cur_beams = self._batch_beams[batch_idx]
                for beam_idx in range(len(cur_beams)):
                    # Row of `probs` holding this (sample, beam) pair — batch-major layout.
                    probs_idx = batch_idx + beam_idx * self.batch_size
                    new_beam = cur_beams[beam_idx]
                    for hypo_idx in range(self.num_beams):
                        # Score accumulates; flat index encodes (parent beam, token rank).
                        new_beams.append(
                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
                        )
                self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
        else:
            # First step: seed each sample's beams with its top num_beams tokens.
            for batch_idx in range(self.batch_size):
                for beam_idx in range(self.num_beams):
                    self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))

        # Decode the flat indices back into (hypothesis row, token id) pairs.
        return_hypos = []
        return_tokens = []
        for batch_idx in range(self.batch_size):
            cur_beam = self._batch_beams[batch_idx]
            return_hypos.append(list())
            return_tokens.append(list())
            for beam in cur_beam:
                beam_idx = beam[1] // self.num_beams  # parent beam this entry extends
                hypo_idx = batch_idx + beam_idx * self.batch_size  # its row in the logits
                token_idx = beam[1] % self.num_beams  # rank of the chosen token in that row
                return_hypos[-1].append(hypo_idx)
                return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
        # Transpose from per-sample lists to the batch-major row order used above.
        return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
        return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]

        return torch.tensor(return_tokens), torch.tensor(return_hypos)
|
@ -1,51 +0,0 @@
|
||||
from abc import ABC
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class ABCBloomConstraint(ABC):
    """
    Base class for decoding constraints; subclass it to implement a new constraint.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        """
        Apply the constraint. Called by the decoding algorithm; adjusts and returns new logits.

        :param tokens_id: The token id of the last chosen token.
        :param logits: The logits from the Bloom model.
        :param hypo_ids: The hypothesis ids of the last tokens.
        """
        pass
|
||||
|
||||
|
||||
class EosConstraint(ABCBloomConstraint):
    """
    This constraint repeats the EOS token if it was generated on the previous step,
    so finished hypotheses keep emitting EOS instead of new content.

    Args:
        prefix: The prefix of the sequence (used to count left-padding per sample).
        eos_token_id: The id of the end of sentence token.
        pad_token_id: The id of the padding token.
        min_logits: The minimum logits that can be generated. Default: -1e8.
    """

    def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
        self.eos_token_id = eos_token_id
        self.min_logits = min_logits
        # Last chosen token per hypothesis; None until the first decoding step.
        self.past_tokens = None

        # Per-sample count of pad tokens in the prefix: the constraint stays inactive
        # for a sample until this counter drops below zero (real generation started).
        self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        if self.past_tokens is not None:
            # Rows that already emitted EOS (and are past their padding prefix)
            # get all logits pushed down by min_logits, with EOS reset to 0 —
            # making EOS the only viable continuation. Mutates `logits` in place.
            mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
            logits += self.min_logits * mask
            logits[mask[:, 0], self.eos_token_id] = 0

        if tokens_id is not None:
            self.past_tokens = tokens_id
            self.wait_until_starting -= 1

        return logits
|
Loading…
Reference in New Issue