Make client compatible with transformers' GenerationMixin (#464)

This PR drops the custom generation code and introduces compatibility with `transformers.GenerationMixin` instead. This adds support for more sampling options (`top_p`, `top_k`, and `repetition_penalty`, requested in #460) and beam search, all of which now behave identically to running the model locally with transformers.

Most features (excluding beam search and a few other rarely used options) are also compatible with resuming existing sessions.
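
For illustration, the standard Hugging Face generation arguments can now be passed straight to `.generate()`. This is a minimal sketch; the model name is just the one from the test matrix, and the argument values are placeholders:

```python
# Sketch: sampling and beam-search options are now handled by transformers' GenerationMixin
import torch
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "bigscience/bloom-560m"  # placeholder; any Petals-supported model works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

# top_p / top_k / repetition_penalty (requested in #460) now work like in local transformers
outputs = model.generate(
    inputs, do_sample=True, temperature=0.9, top_k=50, top_p=0.9,
    repetition_penalty=1.2, max_new_tokens=16,
)

# Beam search is also supported (but not compatible with resuming sessions, see below)
beam_outputs = model.generate(inputs, num_beams=4, max_new_tokens=16)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```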

### Breaking changes

If `.generate()` or forward passes are run inside an `.inference_session()` context, they now use the open session by default. So, these two snippets are now equivalent:

```python
# Using default session
with model.inference_session(max_length=100):
    output_ids = model.generate(input_ids, max_new_tokens=3)

# Explicitly specifying a session
with model.inference_session(max_length=100) as sess:
    output_ids = model.generate(input_ids, max_new_tokens=3, session=sess)
```

Previously, the first snippet created a new session, which is not what most people expected (such code was likely to introduce a bug; this is now fixed).
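
The default-session behavior also makes interactive, multi-call generation straightforward. A hedged sketch, assuming `model` and `tokenizer` from the snippet above (servers keep the attention caches between calls, so the prefix is not rerun):

```python
# Sketch: several .generate() calls share one session's server-side attention caches
prompt_ids = tokenizer("User: Hi! How are you?\nBot:", return_tensors="pt")["input_ids"]

with model.inference_session(max_length=256):
    # The first call processes the prompt and generates a few tokens
    outputs = model.generate(prompt_ids, max_new_tokens=8)

    # Passing inputs=None continues from the tokens generated so far
    # (useful for streaming tokens one by one)
    outputs = model.generate(None, max_new_tokens=1)

    # New user input can be appended mid-session, e.g. in a chat bot
    next_turn_ids = tokenizer("\nUser: Tell me a joke.\nBot:", return_tensors="pt")["input_ids"]
    outputs = model.generate(next_turn_ids, max_new_tokens=16)
```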

@ -41,6 +41,7 @@ jobs:
pip install .[dev]
- name: Test
run: |
set -x # Print executed commands
export MODEL_NAME="${{ matrix.model }}"
export REF_NAME="${{ matrix.model }}"
export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"

@ -3,7 +3,7 @@ import json
import os
import re
import tempfile
import threading
from contextvars import ContextVar
from typing import List, Optional, Tuple, Union
import torch
@ -47,18 +47,16 @@ class FromPretrainedMixin:
)
_shard_config = threading.local()
_shard_config.ignored_keys = None
_ignored_keys = ContextVar("ignored_keys", default=None)
@contextlib.contextmanager
def ignore_keys(patterns: List[str]):
token = _ignored_keys.set(patterns)
try:
prev_patterns = _shard_config.ignored_keys
_shard_config.ignored_keys = patterns
yield
finally:
_shard_config.ignored_keys = prev_patterns
_ignored_keys.reset(token)
def patched_get_checkpoint_shard_files(
@ -66,7 +64,7 @@ def patched_get_checkpoint_shard_files(
) -> Tuple[List[str], dict]:
"""Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys."""
should_ignore_keys = _shard_config.ignored_keys is not None
should_ignore_keys = _ignored_keys.get() is not None
tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
with tempdir_ctx as tempdir:
if should_ignore_keys:
@ -77,7 +75,7 @@ def patched_get_checkpoint_shard_files(
index["weight_map"] = {
param_name: filename
for param_name, filename in index["weight_map"].items()
if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys)
if all(re.search(pattern, param_name) is None for pattern in _ignored_keys.get())
}
n_loaded_shards = len(set(index["weight_map"].values()))
logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")

@ -230,7 +230,7 @@ class InferenceSession:
self._server_sessions = []
self._position = 0
self._max_length = max_length
self.last_token_id = None
self.output_ids = None
@property
def num_blocks(self) -> int:
@ -377,3 +377,13 @@ class InferenceSession:
def __del__(self):
self.close()
@property
def last_token_id(self) -> Optional[torch.Tensor]: # Backward compatibility with Petals < 2.1.0
return self.output_ids[:, -1:] if self.output_ids is not None else None
@last_token_id.setter
def last_token_id(self, value: torch.Tensor): # Backward compatibility with Petals < 2.1.0
if self.output_ids is None:
raise RuntimeError("Can't override `last_token_id` since the session has not stepped yet")
self.output_ids[:, -1:] = value
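
A short illustration of the backward-compatible property above (a sketch assuming `model`, `prompt_ids`, and `torch` from the earlier snippets):

```python
# Sketch: the legacy last_token_id view is derived from session.output_ids
with model.inference_session(max_length=64) as sess:
    model.generate(prompt_ids, max_new_tokens=3)  # fills sess.output_ids

    # New attribute: all token ids accumulated so far, shape [batch, length]
    print(sess.output_ids.shape)

    # Petals < 2.1.0 attribute: just the last column of output_ids
    assert torch.equal(sess.last_token_id, sess.output_ids[:, -1:])
```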

@ -70,8 +70,8 @@ class LMHead(nn.Module):
if not self._bf16_warning_shown:
if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
logger.warning(
"Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
"Consider loading the model with torch_dtype='float32'"
"Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
"To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
)
self._bf16_warning_shown = True

@ -76,9 +76,9 @@ def force_non_empty_weights():
[1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
"""
possibly_patched_register_parameter = nn.Module.register_parameter
nn.Module.register_parameter = _original_register_parameter
try:
possibly_patched_register_parameter = nn.Module.register_parameter
nn.Module.register_parameter = _original_register_parameter
yield
finally:
nn.Module.register_parameter = possibly_patched_register_parameter

@ -1,349 +1,142 @@
import contextlib
from typing import List, Optional
import dataclasses
from contextvars import ContextVar
from typing import ContextManager, List, Optional
import torch
import transformers
from hivemind.utils.logging import get_logger
from transformers.generation.utils import ModelOutput
from petals.client.inference_session import InferenceSession
from petals.utils.generation_algorithms import (
BeamSearchAlgorithm,
DecodingAlgorithm,
GreedyAlgorithm,
NucleusAlgorithm,
SamplingAlgorithm,
TopKAlgorithm,
)
from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
from petals.client.remote_sequential import RemoteSequential
from petals.utils.misc import DUMMY, docstring_from
logger = get_logger(__name__)
class RemoteGenerationMixin:
"""
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
The class exposes can be used for:
- *greedy decoding*.
- *multinomial, top-k and top-p sampling*.
- *beam-search decoding*
This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it.
However, it has some differences for remote usage.
"""
def inference_session(self, **kwargs) -> InferenceSession:
"""
Returns an inference session for the model's RemoteSequential module.
@dataclasses.dataclass(frozen=True)
class RemotePastKeyValues:
"""A mock class representing the fact that `past_key_values` do exist but are stored on remote servers."""
:param max_length: Maximal expected length of inference results. Servers use this parameter
to calculate the size of attention caches allocated to this client.
"""
hypo_ids: Optional[torch.LongTensor] = None
return self.transformer.h.inference_session(**kwargs)
def __getitem__(self, _index: int) -> List[torch.Tensor]:
return [DUMMY] # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
@torch.inference_mode()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
*,
do_sample: Optional[bool] = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
num_beams: Optional[int] = 1,
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
max_length: Optional[int] = None,
max_new_tokens: Optional[int] = None,
decoding_algorithm: Optional[DecodingAlgorithm] = None,
provided_constraints: List[ABCBloomConstraint] = [],
num_return_sequences: Optional[int] = None,
session: Optional[InferenceSession] = None,
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head.
:param inputs: The input tokens to the model.
:param do_sample: Whether to sample from the model predictions or take the argmax.
:param temperature: The temperature to use for sampling.
:param top_k: The number of results to return.
:param top_p: The cumulative probability of results to return.
:param num_beams: The number of beams to use for beam search.
:param bos_token_id: The id of the beginning of sentence token.
:param eos_token_id: The id of the end of sentence token.
:param pad_token_id: The id of the padding token.
:param max_length: The maximum number of tokens in the output (including input tokens).
:param max_new_tokens: The maximum number of tokens to generate.
:param decoding_algorithm: The decoding algorithm to use.
:param provided_constraints: A list of constraints to use.
:param num_return_sequences: How many hypothesis from the beam will be in output.
"""
_skipped_tokens = ContextVar("skipped_tokens", default=0)
prefix_length = 0 if inputs is None else inputs.size(1)
prefix_length += self.config.pre_seq_len
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
class _SkipTokensMixin:
# This override is used in RemoteGenerationMixin but has to be defined in a class not named "GenerationMixin"
# due to how transformers.PreTrainedModel.can_generate() works
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
input_ids = input_ids[:, _skipped_tokens.get() :]
_skipped_tokens.set(0)
return super().prepare_inputs_for_generation(input_ids, **kwargs)
assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
if max_length is not None and max_new_tokens is None:
max_new_tokens = max_length - prefix_length
assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
elif max_length is None and max_new_tokens is not None:
max_length = prefix_length + max_new_tokens
resuming_session = session is not None and session.last_token_id is not None
if num_beams > 1 and resuming_session:
raise NotImplementedError(
"Resuming inference session in .generate() along with beam search is not supported yet"
)
class RemoteGenerationMixin(_SkipTokensMixin):
"""
This class is an upgrade to `transformers.GenerationMixin` that:
- Designed to be compatible with most `transformers.GenerationMixin` strategies and options
- Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and
you don't have to rerun the prefix through all the servers to generate each new token
- Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation
by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or
accept prompts from a user in a chat bot (multiple calls like `.generate(new_prompts, ...)`).
- If there is no active session, `.generate()` will create a new InferenceSession with proper `max_length`.
Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that.
"""
if inputs is not None:
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
if resuming_session:
inputs = torch.cat([session.last_token_id, inputs], dim=1)
else:
if resuming_session:
inputs = session.last_token_id
else:
assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
batch_size = inputs.size(0)
@docstring_from(RemoteSequential.active_session)
@property
def active_session(self) -> Optional[InferenceSession]:
return self.transformer.h.active_session
if decoding_algorithm is None:
if do_sample:
decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
elif num_beams is not None and num_beams > 1:
decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
else:
if top_k is not None or top_p is not None:
logger.warning("You passed top_k or top_p but did not pass do_sample=True. Running greedy sampling")
decoding_algorithm = GreedyAlgorithm()
@docstring_from(RemoteSequential.use_session)
def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
return self.transformer.h.use_session(session)
if num_beams > 1:
inputs = torch.cat([inputs] * num_beams, dim=0)
if batch_size > 1:
# TODO: resolve padding problem
logger.warning(
f"You set batch_size {batch_size} within beam search generation. "
f"Be careful, results on sequences with different length may be padded wrong way"
)
@docstring_from(RemoteSequential.inference_session)
def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
return self.transformer.h.inference_session(**kwargs)
if num_return_sequences is None:
num_return_sequences = 1
@docstring_from(transformers.GenerationMixin.generate.__doc__)
def generate(
self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
):
self._fix_generate_kwargs(kwargs)
if session is not None:
# If a session specified explicitly, use it
context_manager = self.use_session(session)
elif self.active_session is not None:
# If there's an active session, don't do anything
context_manager = contextlib.nullcontext(self.active_session)
else:
# If there's no active session, create a new one
assert num_return_sequences <= num_beams, (
f"You want more sequences than the beam has."
" Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
)
max_length = kwargs.get("max_length")
max_new_tokens = kwargs.get("max_new_tokens")
assert (max_length is None) != (
max_new_tokens is None
), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
constraints = self._get_constraints(
inputs=inputs,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
provided_constraints=provided_constraints,
)
if max_length is not None:
session_max_length = max_length
else:
session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
context_manager = self.inference_session(max_length=session_max_length)
if session is None:
context_manager = self.inference_session(max_length=max_length)
else:
context_manager = contextlib.nullcontext(session) # Doesn't actually enter session or exit from it
with context_manager as session:
outputs = []
# Find samples with padded inputs.
# They will be changed before all of the samples have right length.
if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs
outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
# Prepend the tokens from the previous .generate() call
n_prev_tokens = session.output_ids.shape[1] if session.output_ids is not None else 0
if n_prev_tokens > 0:
if kwargs.get("num_beams", 1) > 1:
logger.warning(
"Beam search will not work properly in the resumed petals.InferenceSession "
"since intermediate beam entries are lost"
)
if inputs is not None:
inputs = torch.cat([session.output_ids, inputs], dim=1)
else:
inputs = session.output_ids
# Don't actually run all previous tokens through the transformer,
# but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
_skipped_tokens.set(max(0, n_prev_tokens - 1))
result = super().generate(inputs, *args, **kwargs)
sequences = result.sequences if isinstance(result, ModelOutput) else result
# Save tokens from this .generate() call
session.output_ids = sequences
# Crop the last tokens from the previous call
sequences = sequences[:, n_prev_tokens:].clone()
if isinstance(result, ModelOutput):
result.sequences = sequences
else:
outputs += [inputs]
last_token_id = None
seq_idx = outputs[0].size(1)
hypo_ids = torch.arange(outputs[0].size(0))
while True:
hidden_state = self.transformer.word_embeddings(outputs[-1])
intermediate_prompts = None
if self.config.pre_seq_len > 0 and len(outputs) == 1:
prompts, intermediate_prompts = self.transformer.get_prompt(hidden_state.size(0))
hidden_state = torch.cat([prompts, hidden_state], dim=1)
hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
result = sequences
hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
return result
hidden_state = self.transformer.ln_f(hidden_state)
lm_logits = self.lm_head(hidden_state)
@staticmethod
def _fix_generate_kwargs(kwargs: dict) -> dict:
# Suppress inappropriate "Both max_new_tokens and max_length" HF warning
if "max_length" in kwargs and kwargs["max_length"] is None:
del kwargs["max_length"]
for constraint in constraints:
lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
last_token_id, hypo_ids = decoding_algorithm(lm_logits)
# Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0
do_sample = kwargs.get("do_sample")
if isinstance(do_sample, int):
kwargs["do_sample"] = bool(do_sample)
# If some samples were padded, change only these samples
if seq_idx < inputs.size(1):
pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
last_token_id = (~pad_token_mask) * inputs[
:, seq_idx : seq_idx + 1
] + pad_token_mask * last_token_id
# TODO: refactor outputs
if num_beams > 1:
for i in range(len(outputs), 1, -1):
outputs[i - 1] = outputs[i - 1][hypo_ids]
outputs.append(last_token_id)
session.last_token_id = last_token_id
seq_idx += 1
if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
break
outputs = torch.cat(outputs, dim=-1)
if resuming_session:
outputs = outputs[:, 1:]
if num_beams > 1:
pre_return_idx = [
torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
]
return_idx = torch.cat(pre_return_idx, dim=0)
outputs = outputs[return_idx]
return outputs
def greedy_search(
self,
input_ids: torch.LongTensor,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head. Uses greedy search.
:param input_ids: The input tokens to the model.
:param max_length: The maximum length of the sequence to generate.
:param pad_token_id: The id of the padding token.
:param eos_token_id: The id of the end of sentence token.
:param provided_constraints: A list of constraints to use.
"""
return self.generate(
inputs=input_ids,
max_new_tokens=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoding_algorithm=GreedyAlgorithm(),
provided_constraints=provided_constraints,
)
def sample(
self,
input_ids: torch.LongTensor,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
:param: input_ids: The input tokens to the model.
:param: temperature: The temperature to use for sampling.
:param: top_k: The number of samples to use for top_k sampling.
:param: top_p: The probability of using top_p sampling.
:param: max_length: The maximum length of the sequence to generate.
:param: pad_token_id: The id of the padding token.
:param: eos_token_id: The id of the end of sentence token.
:param: provided_constraints: A list of constraints to use.
"""
return self.generate(
inputs=input_ids,
max_new_tokens=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
provided_constraints=provided_constraints,
)
def beam_search(
self,
input_ids: torch.LongTensor,
num_beams: int = 1,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head. Uses beam search.
:param input_ids: The input tokens to the model.
:param num_beams: The number of beams to use.
:param max_length: The maximum length of the sequence to generate.
:param pad_token_id: The id of the padding token.
:param eos_token_id: The id of the end of sentence token.
:param provided_constraints: A list of constraints to use.
"""
decoding_algorithm = BeamSearchAlgorithm(
num_beams=num_beams,
batch_size=input_ids.size(0),
)
return self.generate(
inputs=input_ids,
num_beams=num_beams,
max_new_tokens=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoding_algorithm=decoding_algorithm,
provided_constraints=provided_constraints,
)
def beam_sample(
self,
input_ids: torch.LongTensor,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
raise NotImplementedError
def group_beam_search(
self,
input_ids: torch.LongTensor,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
raise NotImplementedError
def _choose_sample_algorithm(
self,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
) -> DecodingAlgorithm:
if (top_k is not None) and (top_p is not None):
raise ValueError("You have to provide only top_k or top_p for sampling")
if top_k is not None:
return TopKAlgorithm(top_k, temperature)
elif top_p is not None:
return NucleusAlgorithm(top_p, temperature)
else:
return SamplingAlgorithm(temperature)
return kwargs
def _get_constraints(
self,
inputs: Optional[torch.Tensor] = None,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> List[ABCBloomConstraint]:
constraints = []
constraints.extend(provided_constraints)
constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
return constraints
@staticmethod
def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
return dataclasses.replace(past_key_values, hypo_ids=beam_idx)

@ -1,5 +1,7 @@
from __future__ import annotations
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Optional, Union
import torch
@ -11,7 +13,6 @@ from petals.client.inference_session import InferenceSession
from petals.client.routing import RemoteSequenceManager
from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
from petals.data_structures import UID_DELIMITER
from petals.utils.misc import DUMMY
logger = get_logger(__name__)
@ -46,11 +47,52 @@ class RemoteSequential(nn.Module):
sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
self.sequence_manager = sequence_manager
def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
self._active_session = ContextVar("active_session", default=None)
def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
assert inputs.shape[1] <= 2048, "The sequence length is capped at 2048 tokens in this version"
outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
return outputs
if self.active_session is None:
assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
else:
return self.active_session.step(inputs, prompts, **kwargs)
@property
def active_session(self) -> Optional[InferenceSession]:
"""
If called inside `with model.inference_session(...):` or `with model.use_session(...):`,
returns an active InferenceSession. Otherwise, returns None.
"""
return self._active_session.get()
@property
def position(self) -> int:
"""Returns the prefix length (in tokens) in the active inference session or zero if no session is active."""
return self.active_session.position if self.active_session is not None else 0
@contextmanager
def use_session(self, session: Optional[InferenceSession]) -> InferenceSession:
"""Inside this context, forward() will use an _existing_ InferenceSession provided as the argument."""
token = self._active_session.set(session)
try:
yield session
finally:
self._active_session.reset(token)
@contextmanager
def inference_session(self, **kwargs) -> InferenceSession:
"""
Inside this context, forward() will use a _new_ InferenceSession created with given parameters.
:param max_length: Maximal expected length of inference results. Servers use this parameter
to calculate the size of attention caches allocated to this client.
"""
with InferenceSession(self.sequence_manager, **kwargs) as session, self.use_session(session):
yield session
def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
return RemoteSequential(
@ -65,8 +107,5 @@ class RemoteSequential(nn.Module):
def __len__(self):
return len(self.sequence_manager)
def inference_session(self, **kwargs) -> InferenceSession:
return InferenceSession(self.sequence_manager, **kwargs)
def extra_repr(self) -> str:
return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

@ -230,7 +230,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
if is_dummy(prompts):
if prompts is None or is_dummy(prompts):
prompt_batches = [DUMMY] * len(input_batches)
else:
prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)

@ -10,7 +10,7 @@ from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassifi
from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.bloom.config import DistributedBloomConfig
@ -39,16 +39,15 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
past_key_values: Optional[RemotePastKeyValues] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
for k, v in kwargs.items():
if not (v is None or v is False):
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@ -59,21 +58,34 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# The causal mask will be added on the server-side
assert (
attention_mask is None or (attention_mask == 1).all()
), f"Custom attention masks are not supported, {attention_mask=}"
assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
assert not output_attentions, f"{output_attentions=} is not supported"
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
batch_size = inputs_embeds.shape[0]
prompts, intermediate_prompts = self.get_prompt(batch_size)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
else:
prompts = intermediate_prompts = None
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
output_shape = input_shape + (hidden_states.size(-1),)
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
else:
hidden_states = self.h(hidden_states)
hidden_states = self.h(
hidden_states,
prompts=intermediate_prompts,
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
)
# Remove prefix
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
@ -84,7 +96,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=None,
past_key_values=RemotePastKeyValues(),
hidden_states=None,
attentions=None,
)

@ -10,7 +10,7 @@ from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassifi
from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.llama.config import DistributedLlamaConfig
@ -39,16 +39,15 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[RemotePastKeyValues] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> BaseModelOutputWithPast:
assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
for k, v in kwargs.items():
if not (v is None or v is False):
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@ -59,21 +58,36 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# The causal mask will be added on the server-side
assert (
attention_mask is None or (attention_mask == 1).all()
), f"Custom attention masks are not supported, {attention_mask=}"
assert (
position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
), f"Non-consecutive position_ids are not supported, {position_ids=}"
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
assert not output_attentions, f"{output_attentions=} is not supported"
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0:
batch_size = inputs_embeds.shape[0]
prompts, intermediate_prompts = self.get_prompt(batch_size)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
else:
prompts = intermediate_prompts = None
hidden_states = inputs_embeds
output_shape = input_shape + (hidden_states.size(-1),)
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
else:
hidden_states = self.layers(hidden_states)
hidden_states = self.layers(
hidden_states,
prompts=intermediate_prompts,
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
)
# Remove prefix
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
@ -84,7 +98,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=None,
past_key_values=RemotePastKeyValues(),
hidden_states=None,
attentions=None,
)

@ -196,7 +196,7 @@ async def iterate_rpc_inference(
hypo_ids,
points=point_per_piece,
requested_uids=requested_uids,
type="short_inference" if can_merge_pools else "inference",
type="inference",
)
# A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.

@ -14,9 +14,7 @@ class TaskPrioritizerBase(ABC):
class DummyTaskPrioritizer(TaskPrioritizerBase):
def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
# Inference steps (especially short ones) go first since they are more latency-sensitive
if kwargs.get("type") == "short_inference":
return 1.0
# Inference steps go first since they are more latency-sensitive
if kwargs.get("type") == "inference":
return 2.0
return 3.0 # Forward, backward
return 1.0
return 2.0 # Forward, backward

@ -1,128 +0,0 @@
from abc import ABC, abstractmethod
from typing import Tuple
import torch
TokenIds = torch.Tensor
HypoIds = torch.Tensor
class DecodingAlgorithm(ABC):
"""
An abstract class for decoding algorithms. Describes the base function of those algorithms:
they have to select new tokens and provide the corresponding hypotheses.
"""
@abstractmethod
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
"""
:param logits: A tensor of shape (batch_size, seq_length, vocab_size)
:return: A tuple of selected token ids and corresponding hypotheses.
The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
"""
pass
class GreedyAlgorithm(DecodingAlgorithm):
"""
The simplest algorithm for decoding. It selects the most probable token.
"""
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
"""
Returns the most probable token. The second returned object is always a range of integers
from 0 to batch_size - 1.
"""
return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))
class SamplingAlgorithm(DecodingAlgorithm):
def __init__(self, temperature: float = 1.0):
self.temperature = temperature
def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
"""
:param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
:param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
:return: A tuple of selected token ids and corresponding hypotheses.
The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size).
"""
logits[indices_to_remove] = -float("Inf")
probs = torch.softmax(logits / self.temperature, -1)
return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
return self.sample(logits, indices_to_remove)
class TopKAlgorithm(SamplingAlgorithm):
def __init__(self, top_k: int, temperature: float = 1.0) -> None:
super().__init__(temperature=temperature)
self.top_k = top_k
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
return self.sample(logits, indices_to_remove)
class NucleusAlgorithm(SamplingAlgorithm):
def __init__(self, top_p: float, temperature: float = 1.0) -> None:
super().__init__(temperature=temperature)
self.top_p = top_p
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
probs = torch.softmax(sorted_logits / self.temperature, -1)
cumulative_probs = torch.cumsum(probs, dim=-1)
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
return self.sample(logits, indices_to_remove)
class BeamSearchAlgorithm(DecodingAlgorithm):
def __init__(self, num_beams: int, batch_size: int) -> None:
self.num_beams = num_beams
self.batch_size = batch_size
self._batch_beams = [list() for _ in range(batch_size)]
def __call__(self, logits: torch.Tensor):
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
probs = torch.log_softmax(sorted_logits, -1)
if len(self._batch_beams[0]) > 0:
for batch_idx in range(self.batch_size):
new_beams = []
cur_beams = self._batch_beams[batch_idx]
for beam_idx in range(len(cur_beams)):
probs_idx = batch_idx + beam_idx * self.batch_size
new_beam = cur_beams[beam_idx]
for hypo_idx in range(self.num_beams):
new_beams.append(
(new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
)
self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
else:
for batch_idx in range(self.batch_size):
for beam_idx in range(self.num_beams):
self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))
return_hypos = []
return_tokens = []
for batch_idx in range(self.batch_size):
cur_beam = self._batch_beams[batch_idx]
return_hypos.append(list())
return_tokens.append(list())
for beam in cur_beam:
beam_idx = beam[1] // self.num_beams
hypo_idx = batch_idx + beam_idx * self.batch_size
token_idx = beam[1] % self.num_beams
return_hypos[-1].append(hypo_idx)
return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]
return torch.tensor(return_tokens), torch.tensor(return_hypos)

@ -1,51 +0,0 @@
from abc import ABC
import torch
class ABCBloomConstraint(ABC):
"""
Base class of all kind of decoding constraints. It can be used to implement a new constraint.
"""
def __init__(self) -> None:
pass
def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
"""
This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
:param tokens_id: The token id of the last chosen token.
:param logits: The logits from the Bloom model.
:param hypo_ids: The hypothesis ids of the last tokens.
"""
pass
class EosConstraint(ABCBloomConstraint):
"""
This constrained repeats EOS token if it was generated on the previous step.
Args:
prefix: The prefix of the sequence.
eos_token_id: The id of the end of sentence token.
pad_token_id: The id of the padding token.
min_logits: The minimum logits that can be generated. Default: -1e6.
"""
def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
self.eos_token_id = eos_token_id
self.min_logits = min_logits
self.past_tokens = None
self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)
def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
if self.past_tokens is not None:
mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
logits += self.min_logits * mask
logits[mask[:, 0], self.eos_token_id] = 0
if tokens_id is not None:
self.past_tokens = tokens_id
self.wait_until_starting -= 1
return logits

@ -5,5 +5,13 @@ DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter par
DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
def is_dummy(tensor: torch.Tensor):
def is_dummy(tensor: torch.Tensor) -> bool:
return tensor.numel() == 0
def docstring_from(source):
def add_docstring(dest):
dest.__doc__ = source.__doc__
return dest
return add_docstring
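
And a tiny sketch of how the new `docstring_from` helper is meant to be used:

```python
# Sketch: docstring_from copies a docstring from one callable to another
def source():
    """Explains the shared behavior once."""

@docstring_from(source)
def dest():
    ...

assert dest.__doc__ == source.__doc__
```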

@ -3,7 +3,6 @@ import pytest
import torch
import transformers
from hivemind import get_logger
from transformers.generation import BeamSearchScorer, GenerationMixin as HfGenerationMixin
from petals import AutoDistributedModelForCausalLM
from test_utils import *
@ -17,18 +16,29 @@ def tokenizer():
return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
@pytest.fixture
def model():
return AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
@pytest.fixture
def ref_model():
return transformers.AutoModelForCausalLM.from_pretrained(
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
@pytest.mark.forked
@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME,
initial_peers=INITIAL_PEERS,
torch_dtype=torch.float32,
active_adapter=ADAPTER_NAME if use_peft else None,
)
config = model.config
assert len(model.transformer.h) == model.config.num_hidden_layers
def test_full_model_exact_match(tokenizer, model, ref_model, use_peft, pass_empty_tensors, atol=1e-3):
if use_peft:
model.config.active_adapter = ADAPTER_NAME
ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
ref_model.train(False)
test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"]
@ -42,7 +52,7 @@ def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_fo
recurrent_outputs = []
with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
if pass_empty_tensors:
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
for t in range(embs.shape[1]):
if t == 4:
@ -53,52 +63,39 @@ def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_fo
recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
if t == 2 and pass_empty_tensors:
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
recurrent_outputs = model.lm_head(recurrent_outputs)
assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
logger.info("Inference is consistent with forward")
del model, embs, recurrent_outputs
if REF_NAME:
ref_model = transformers.AutoModelForCausalLM.from_pretrained(
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
if use_peft:
ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
ref_model.train(False)
if config.vocab_size < ref_model.config.vocab_size:
ref_model.resize_token_embeddings(config.vocab_size)
logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
# note: this creates a dummy mask to make the test compatible with older transformer versions
# prior to https://github.com/huggingface/transformers/pull/17837
ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
del ref_model, ref_outputs, dummy_mask
else:
logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
assert False
assert torch.allclose(
recurrent_outputs, parallel_outputs, rtol=0, atol=atol
), "Inference differs from forward pass"
ref_outputs = ref_model.forward(test_inputs).logits.float()
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol), "Outputs are not identical to HF"
def make_generate_calls(model, inputs, *, max_new_tokens, multiple_calls=False, **kwargs):
if not multiple_calls:
return model.generate(inputs, max_new_tokens=max_new_tokens, **kwargs)
with model.inference_session(max_length=inputs.shape[1] + max_new_tokens) as sess:
return torch.cat(
[
# Sessions provided both explicitly and implicitly should work
model.generate(inputs, max_new_tokens=1, **kwargs, session=sess),
model.generate(None, max_new_tokens=max_new_tokens - 2, **kwargs),
model.generate(None, max_new_tokens=1, **kwargs),
],
dim=1,
)
@pytest.mark.forked
def test_greedy_generation(tokenizer, max_new_tokens=4):
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
remote_outputs = model.generate(
inputs,
max_new_tokens=max_new_tokens,
)
hf_outputs = HfGenerationMixin.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
def test_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
@ -106,85 +103,49 @@ def test_greedy_generation(tokenizer, max_new_tokens=4):
"input_ids"
]
remote_outputs_batch = model.generate(
inputs_batch,
max_new_tokens=max_new_tokens,
)
hf_outputs_batch = HfGenerationMixin.greedy_search(
model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
)
assert torch.allclose(
remote_outputs_batch, hf_outputs_batch
), "Greedy search results are not identical to HF in multibatch mode"
options = dict(max_new_tokens=max_new_tokens, do_sample=False)
for multiple_calls in [False, True]:
for inputs in [inputs_single, inputs_batch]:
outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)
ref_outputs = ref_model.generate(inputs, **options)
assert torch.allclose(
outputs, ref_outputs
), f"Greedy generation is not identical to HF with {multiple_calls=}, {inputs.shape=}"
@pytest.mark.forked
@pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
@pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
def test_sampling(tokenizer, sampling_options, max_new_tokens=4):
torch.manual_seed(0)
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
logits_warper = HfGenerationMixin._get_logits_warper(model, num_beams=1, **sampling_options)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
with torch.random.fork_rng():
remote_outputs = model.generate(
inputs,
max_new_tokens=max_new_tokens,
do_sample=True,
**sampling_options,
)
with torch.random.fork_rng():
hf_outputs = HfGenerationMixin.sample(
model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
)
assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"
def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):
inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
"input_ids"
]
with torch.random.fork_rng():
remote_outputs_batch = model.generate(
inputs_batch,
max_new_tokens=max_new_tokens,
do_sample=True,
**sampling_options,
)
with torch.random.fork_rng():
hf_outputs_batch = HfGenerationMixin.sample(
model,
input_ids=inputs_batch,
max_length=inputs_batch.size(1) + max_new_tokens,
logits_warper=logits_warper,
)
assert torch.allclose(
remote_outputs_batch, hf_outputs_batch
), "Sampling results are not identical to HF in multibatch mode"
for options in [
dict(do_sample=True, temperature=0.5, top_k=5, top_p=0.9),
dict(do_sample=True, temperature=0.5, repetition_penalty=1.2),
]:
options.update(max_new_tokens=max_new_tokens)
for multiple_calls in [False, True]:
for inputs in [inputs_single, inputs_batch]:
torch.manual_seed(0)
outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)
torch.manual_seed(0)
ref_outputs = ref_model.generate(inputs, **options)
assert torch.allclose(
outputs, ref_outputs
), f"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}"
@pytest.mark.forked
def test_beam_search_generation(tokenizer, max_new_tokens=4, num_beams=2):
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
text = "A cat sat on a mat"
inputs = tokenizer(text, return_tensors="pt")["input_ids"]
remote_outputs = model.generate(
inputs,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
)
beam_scorer = BeamSearchScorer(
batch_size=inputs.size(0),
num_beams=num_beams,
device=inputs.device,
length_penalty=0,
do_early_stopping=False,
)
hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
hf_outputs = HfGenerationMixin.beam_search(
model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
)
assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"
def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
options = dict(max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
outputs = make_generate_calls(model, inputs, **options)
ref_outputs = ref_model.generate(inputs, **options)
assert torch.allclose(outputs, ref_outputs), f"Beam search results are not identical to HF"
