Make client compatible with transformers' GenerationMixin (#464)
This PR drops the custom generation code and introduces compatibility with `transformers.GenerationMixin` instead. This includes support for more sampling options (`top_p`, `top_k`, `repetition_penalty` requested in #460) and beam search, all of which now works identically to running the model with transformers locally. Most features (excluding beam search and other rarely used options) are also compatible with resuming existing sessions.

### Breaking changes

If `.generate()` or forward passes are run inside an `.inference_session()` context, they now use the opened session by default. So, these snippets are now equivalent:

```python
# Using the default session
with model.inference_session(max_length=100):
    output_ids = model.generate(input_ids, max_new_tokens=3)

# Explicitly specifying a session
with model.inference_session(max_length=100) as sess:
    output_ids = model.generate(input_ids, max_new_tokens=3, session=sess)
```

Previously, the first snippet created a new session, which is not what most people expected (such code was most likely to introduce a bug, which is now fixed).
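Since decoding is now delegated to `transformers.GenerationMixin`, the standard generation arguments can be passed straight to `.generate()` on a distributed model. A minimal sketch under assumptions: the `petals.AutoDistributedModelForCausalLM` entry point from Petals 2.x and the model name are illustrative, not part of this PR:

```python
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM  # assumed Petals 2.x entry point

model_name = "bigscience/bloom-560m"  # illustrative; any Petals-served model works the same way
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

# These options are handled by transformers.GenerationMixin, exactly as for a local model
outputs = model.generate(
    inputs,
    do_sample=True,
    temperature=0.9,
    top_k=40,
    top_p=0.9,
    repetition_penalty=1.2,
    max_new_tokens=32,
)
print(tokenizer.decode(outputs[0]))
```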
parent 063e94b4c8
commit de2475f31c
**`petals/client/remote_generation.py`** (rewritten: 349 lines removed, 142 lines added)

The previous version implemented its own generation loop: it picked a `DecodingAlgorithm` (greedy, top-k, nucleus, or beam search), applied `ABCBloomConstraint` objects to the logits, and stepped through the remote `InferenceSession` token by token. It also exposed custom `greedy_search()`, `sample()`, `beam_search()`, `beam_sample()`, `group_beam_search()`, `_choose_sample_algorithm()`, and `_get_constraints()` methods. All of that is deleted; the new version delegates decoding to `transformers.GenerationMixin` and only manages the remote session and its attention caches:

```python
import contextlib
import dataclasses
from contextvars import ContextVar
from typing import ContextManager, List, Optional

import torch
import transformers
from hivemind.utils.logging import get_logger
from transformers.generation.utils import ModelOutput

from petals.client.inference_session import InferenceSession
from petals.client.remote_sequential import RemoteSequential
from petals.utils.misc import DUMMY, docstring_from

logger = get_logger(__name__)


@dataclasses.dataclass(frozen=True)
class RemotePastKeyValues:
    """A mock class representing the fact that `past_key_values` do exist but are stored on remote servers."""

    hypo_ids: Optional[torch.LongTensor] = None

    def __getitem__(self, _index: int) -> List[torch.Tensor]:
        return [DUMMY]  # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()


_skipped_tokens = ContextVar("skipped_tokens", default=0)


class _SkipTokensMixin:
    # This override is used in RemoteGenerationMixin but has to be defined in a class not named as "GenerationMixin"
    # due to how transformers.PreTrainedModel.can_generate() works
    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
        input_ids = input_ids[:, _skipped_tokens.get() :]
        _skipped_tokens.set(0)
        return super().prepare_inputs_for_generation(input_ids, **kwargs)


class RemoteGenerationMixin(_SkipTokensMixin):
    """
    This class is an upgrade to `transformers.GenerationMixin` that:

    - Designed to be compatible with most `transformers.GenerationMixin` strategies and options
    - Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and
      you don't have to rerun the prefix through all the servers to generate each new token
    - Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation
      by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or
      accept prompts from a user in a chat bot (multiple calls like `.generate(new_prompts, ...)`).
    - If there is no active session, `.generate()` will create a new InferenceSession with proper `max_length`.
      Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that.
    """

    @docstring_from(RemoteSequential.active_session)
    @property
    def active_session(self) -> Optional[InferenceSession]:
        return self.transformer.h.active_session

    @docstring_from(RemoteSequential.use_session)
    def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
        return self.transformer.h.use_session(session)

    @docstring_from(RemoteSequential.inference_session)
    def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
        return self.transformer.h.inference_session(**kwargs)

    @docstring_from(transformers.GenerationMixin.generate.__doc__)
    def generate(
        self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
    ):
        self._fix_generate_kwargs(kwargs)

        if session is not None:
            # If a session specified explicitly, use it
            context_manager = self.use_session(session)
        elif self.active_session is not None:
            # If there's an active session, don't do anything
            context_manager = contextlib.nullcontext(self.active_session)
        else:
            # If there's no active session, create a new one

            max_length = kwargs.get("max_length")
            max_new_tokens = kwargs.get("max_new_tokens")
            assert (max_length is None) != (
                max_new_tokens is None
            ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"

            if max_length is not None:
                session_max_length = max_length
            else:
                session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
            context_manager = self.inference_session(max_length=session_max_length)

        with context_manager as session:
            # Prepend the tokens from the previous .generate() call
            n_prev_tokens = session.output_ids.shape[1] if session.output_ids is not None else 0
            if n_prev_tokens > 0:
                if kwargs.get("num_beams", 1) > 1:
                    logger.warning(
                        "Beam search will not work properly in the resumed petals.InferenceSession "
                        "since intermediate beam entries are lost"
                    )

                if inputs is not None:
                    inputs = torch.cat([session.output_ids, inputs], dim=1)
                else:
                    inputs = session.output_ids

                # Don't actually run all previous tokens through the transformer,
                # but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
                _skipped_tokens.set(max(0, n_prev_tokens - 1))

            result = super().generate(inputs, *args, **kwargs)

            sequences = result.sequences if isinstance(result, ModelOutput) else result
            # Save tokens from this .generate() call
            session.output_ids = sequences
            # Crop the last tokens from the previous call
            sequences = sequences[:, n_prev_tokens:].clone()
            if isinstance(result, ModelOutput):
                result.sequences = sequences
            else:
                result = sequences

        return result

    @staticmethod
    def _fix_generate_kwargs(kwargs: dict) -> dict:
        # Suppress inappropriate "Both max_new_tokens and max_length" HF warning
        if "max_length" in kwargs and kwargs["max_length"] is None:
            del kwargs["max_length"]

        # Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0
        do_sample = kwargs.get("do_sample")
        if isinstance(do_sample, int):
            kwargs["do_sample"] = bool(do_sample)

        return kwargs

    @staticmethod
    def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
        return dataclasses.replace(past_key_values, hypo_ids=beam_idx)
```
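Because the rewritten `generate()` stores generated tokens in `session.output_ids` and prepends them on the next call, interactive multi-call generation inside one session works as described in the class docstring. A rough usage sketch (not part of the diff; `model` and `tokenizer` are assumed to be a Petals distributed model and its tokenizer):

```python
with model.inference_session(max_length=512) as sess:
    # First call: run the user's prompt through the remote servers
    prompt_ids = tokenizer("User: Hi! Assistant:", return_tensors="pt")["input_ids"]
    outputs = model.generate(prompt_ids, max_new_tokens=16)

    # A follow-up call reuses the same remote attention caches; only the new tokens are sent
    followup_ids = tokenizer(" User: Tell me more. Assistant:", return_tensors="pt")["input_ids"]
    outputs = model.generate(followup_ids, max_new_tokens=16)

    # Token-by-token streaming: pass None and request one new token at a time
    next_token = model.generate(None, max_new_tokens=1, do_sample=True)
```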
**`petals/utils/generation_algorithms.py`** (deleted, 128 lines)

The custom decoding algorithms are no longer needed, and the whole module is removed:

```python
from abc import ABC, abstractmethod
from typing import Tuple

import torch

TokenIds = torch.Tensor
HypoIds = torch.Tensor


class DecodingAlgorithm(ABC):
    """
    An abstract class for decoding algorithms. Describes the base function of those algorithms:
    they have to select new tokens and provide the corresponding hypotheses.
    """

    @abstractmethod
    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        """
        :param logits: A tensor of shape (batch_size, seq_length, vocab_size)
        :return: A tuple of selected token ids and corresponding hypotheses.
            The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
        """
        pass


class GreedyAlgorithm(DecodingAlgorithm):
    """
    The simplest algorithm for decoding. It selects the most probable token.
    """

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        """
        Returns the most probable token. The second returned object is always a range of integers
        from 0 to batch_size - 1.
        """
        return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))


class SamplingAlgorithm(DecodingAlgorithm):
    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        """
        :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
        :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
        :return: A tuple of selected token ids and corresponding hypotheses.
            The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size).
        """
        logits[indices_to_remove] = -float("Inf")
        probs = torch.softmax(logits / self.temperature, -1)
        return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
        return self.sample(logits, indices_to_remove)


class TopKAlgorithm(SamplingAlgorithm):
    def __init__(self, top_k: int, temperature: float = 1.0) -> None:
        super().__init__(temperature=temperature)
        self.top_k = top_k

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
        return self.sample(logits, indices_to_remove)


class NucleusAlgorithm(SamplingAlgorithm):
    def __init__(self, top_p: float, temperature: float = 1.0) -> None:
        super().__init__(temperature=temperature)
        self.top_p = top_p

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
        probs = torch.softmax(sorted_logits / self.temperature, -1)
        cumulative_probs = torch.cumsum(probs, dim=-1)

        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        return self.sample(logits, indices_to_remove)


class BeamSearchAlgorithm(DecodingAlgorithm):
    def __init__(self, num_beams: int, batch_size: int) -> None:
        self.num_beams = num_beams
        self.batch_size = batch_size

        self._batch_beams = [list() for _ in range(batch_size)]

    def __call__(self, logits: torch.Tensor):
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        probs = torch.log_softmax(sorted_logits, -1)

        if len(self._batch_beams[0]) > 0:
            for batch_idx in range(self.batch_size):
                new_beams = []
                cur_beams = self._batch_beams[batch_idx]
                for beam_idx in range(len(cur_beams)):
                    probs_idx = batch_idx + beam_idx * self.batch_size
                    new_beam = cur_beams[beam_idx]
                    for hypo_idx in range(self.num_beams):
                        new_beams.append(
                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
                        )
                self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
        else:
            for batch_idx in range(self.batch_size):
                for beam_idx in range(self.num_beams):
                    self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))

        return_hypos = []
        return_tokens = []
        for batch_idx in range(self.batch_size):
            cur_beam = self._batch_beams[batch_idx]
            return_hypos.append(list())
            return_tokens.append(list())
            for beam in cur_beam:
                beam_idx = beam[1] // self.num_beams
                hypo_idx = batch_idx + beam_idx * self.batch_size
                token_idx = beam[1] % self.num_beams
                return_hypos[-1].append(hypo_idx)
                return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
        return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
        return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]

        return torch.tensor(return_tokens), torch.tensor(return_hypos)
```
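The filtering these classes implemented is what `transformers` already provides through its logits warpers, which `GenerationMixin` now applies on the client side. An illustrative sketch (not part of the PR) of the equivalent top-k/top-p filtering:

```python
import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper

# Roughly what TopKAlgorithm / NucleusAlgorithm did, now handled inside transformers' generate()
warpers = LogitsProcessorList(
    [TemperatureLogitsWarper(0.9), TopKLogitsWarper(top_k=40), TopPLogitsWarper(top_p=0.9)]
)

logits = torch.randn(1, 50257)  # (batch_size, vocab_size) scores for the next token
input_ids = torch.zeros(1, 5, dtype=torch.long)  # dummy prefix; the warpers only look at its shape
filtered = warpers(input_ids, logits)
probs = torch.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```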
**`petals/utils/generation_constraints.py`** (deleted, 51 lines)

The logit constraints used by the old generation loop are also removed:

```python
from abc import ABC

import torch


class ABCBloomConstraint(ABC):
    """
    Base class of all kind of decoding constraints. It can be used to implement a new constraint.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        """
        This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
        :param tokens_id: The token id of the last chosen token.
        :param logits: The logits from the Bloom model.
        :param hypo_ids: The hypothesis ids of the last tokens.
        """
        pass


class EosConstraint(ABCBloomConstraint):
    """
    This constraint repeats the EOS token if it was generated on the previous step.
    Args:
        prefix: The prefix of the sequence.
        eos_token_id: The id of the end of sentence token.
        pad_token_id: The id of the padding token.
        min_logits: The minimum logits that can be generated. Default: -1e6.
    """

    def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
        self.eos_token_id = eos_token_id
        self.min_logits = min_logits
        self.past_tokens = None

        self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        if self.past_tokens is not None:
            mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
            logits += self.min_logits * mask
            logits[mask[:, 0], self.eos_token_id] = 0

        if tokens_id is not None:
            self.past_tokens = tokens_id
            self.wait_until_starting -= 1

        return logits
```
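With the constraints gone, end-of-sequence handling relies on the standard `transformers` machinery: once a sequence produces `eos_token_id`, `generate()` stops extending it and pads it with `pad_token_id`. An illustrative call (argument values are placeholders, not part of the diff):

```python
outputs = model.generate(
    inputs,
    max_new_tokens=64,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,  # used to pad sequences that finish early
)
```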