Make client compatible with transformers' GenerationMixin (#464)

This PR drops the custom generation code and makes the client compatible with `transformers.GenerationMixin` instead. This adds support for more sampling options (`top_p`, `top_k`, and `repetition_penalty`, requested in #460) and beam search, all of which now behave identically to running the model with transformers locally.

Most features (excluding beam search and a few other rarely used options) also work when resuming an existing inference session.
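
For example, the standard sampling options can now be passed straight to `.generate()`. A quick sketch (the model name and hyperparameters here are just for illustration):

```python
import torch
import transformers
from petals import AutoDistributedModelForCausalLM

model_name = "bigscience/bloom-560m"  # illustrative; any model supported by Petals works
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
outputs = model.generate(
    inputs,
    max_new_tokens=16,
    do_sample=True,          # same semantics as in transformers
    temperature=0.9,
    top_p=0.9,
    repetition_penalty=1.2,
)
print(tokenizer.decode(outputs[0]))
```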

### Breaking changes

If `.generate()` or a forward pass is run inside an `.inference_session()` context, it now uses the opened session by default. As a result, these two snippets are now equivalent:

```python
# Using default session
with model.inference_session(max_length=100):
    output_ids = model.generate(input_ids, max_new_tokens=3)

# Explicitly specifying a session
with model.inference_session(max_length=100) as sess:
    output_ids = model.generate(input_ids, max_new_tokens=3, session=sess)
```

Previously, the first snippet created a new session, which is not what most people expected, so such code was likely to introduce a bug. It now does the expected thing.
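
Resuming a session over several `.generate()` calls also works with the standard API now. A sketch of interactive generation, reusing the `model` and `tokenizer` from the example above (token budgets are arbitrary):

```python
with model.inference_session(max_length=512):
    # The servers keep the attention caches, so the prefix is not re-sent on every call
    prompt_ids = tokenizer("A chat between a user and an assistant.", return_tensors="pt")["input_ids"]
    output_ids = model.generate(prompt_ids, max_new_tokens=8)
    for _ in range(3):
        # Passing None continues generation from the tokens already stored in the session
        output_ids = model.generate(None, max_new_tokens=8, do_sample=True, top_k=40)
```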
Commit de2475f31c on branch pull/470/head (parent 063e94b4c8), authored by Alexander Borzunov and committed via GitHub.

```diff
@@ -41,6 +41,7 @@ jobs:
           pip install .[dev]
       - name: Test
         run: |
+          set -x  # Print executed commands
           export MODEL_NAME="${{ matrix.model }}"
           export REF_NAME="${{ matrix.model }}"
           export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
```

```diff
@@ -3,7 +3,7 @@ import json
 import os
 import re
 import tempfile
-import threading
+from contextvars import ContextVar
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -47,18 +47,16 @@ class FromPretrainedMixin:
 )
 
-_shard_config = threading.local()
-_shard_config.ignored_keys = None
+_ignored_keys = ContextVar("ignored_keys", default=None)
 
 
 @contextlib.contextmanager
 def ignore_keys(patterns: List[str]):
+    token = _ignored_keys.set(patterns)
     try:
-        prev_patterns = _shard_config.ignored_keys
-        _shard_config.ignored_keys = patterns
         yield
     finally:
-        _shard_config.ignored_keys = prev_patterns
+        _ignored_keys.reset(token)
 
 
 def patched_get_checkpoint_shard_files(
@@ -66,7 +64,7 @@ def patched_get_checkpoint_shard_files(
 ) -> Tuple[List[str], dict]:
     """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys."""
-    should_ignore_keys = _shard_config.ignored_keys is not None
+    should_ignore_keys = _ignored_keys.get() is not None
     tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
     with tempdir_ctx as tempdir:
         if should_ignore_keys:
@@ -77,7 +75,7 @@ def patched_get_checkpoint_shard_files(
             index["weight_map"] = {
                 param_name: filename
                 for param_name, filename in index["weight_map"].items()
-                if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys)
+                if all(re.search(pattern, param_name) is None for pattern in _ignored_keys.get())
             }
             n_loaded_shards = len(set(index["weight_map"].values()))
             logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")
```
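
For context, `contextvars.ContextVar` keeps a separate value per context (so it behaves correctly across threads and asyncio tasks), and `set()` returns a token that restores the previous value on `reset()`. A minimal self-contained sketch of the same pattern (the names are illustrative, not Petals code):

```python
import contextlib
from contextvars import ContextVar
from typing import List, Optional

_ignored: ContextVar[Optional[List[str]]] = ContextVar("ignored", default=None)

@contextlib.contextmanager
def ignore(patterns: List[str]):
    token = _ignored.set(patterns)  # remembers the previous value
    try:
        yield
    finally:
        _ignored.reset(token)  # restores it even if an exception was raised

with ignore(["lm_head.*"]):
    assert _ignored.get() == ["lm_head.*"]
assert _ignored.get() is None
```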

```diff
@@ -230,7 +230,7 @@ class InferenceSession:
         self._server_sessions = []
         self._position = 0
         self._max_length = max_length
-        self.last_token_id = None
+        self.output_ids = None
 
     @property
     def num_blocks(self) -> int:
@@ -377,3 +377,13 @@ class InferenceSession:
     def __del__(self):
         self.close()
+
+    @property
+    def last_token_id(self) -> Optional[torch.Tensor]:  # Backward compatibility with Petals < 2.1.0
+        return self.output_ids[:, -1:] if self.output_ids is not None else None
+
+    @last_token_id.setter
+    def last_token_id(self, value: torch.Tensor):  # Backward compatibility with Petals < 2.1.0
+        if self.output_ids is None:
+            raise RuntimeError("Can't override `last_token_id` since the session has not stepped yet")
+        self.output_ids[:, -1:] = value
```
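
A hedged sketch of what this shim preserves (assuming a `model` and `inputs` as in the examples above): client code written for Petals < 2.1.0 that reads `session.last_token_id` keeps working, since the property is now derived from `session.output_ids`.

```python
with model.inference_session(max_length=128) as sess:
    model.generate(inputs, max_new_tokens=2)
    last_token = sess.last_token_id  # Petals < 2.1.0 style access
    assert torch.equal(last_token, sess.output_ids[:, -1:])
```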

```diff
@@ -70,8 +70,8 @@ class LMHead(nn.Module):
         if not self._bf16_warning_shown:
             if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
                 logger.warning(
-                    "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
-                    "Consider loading the model with torch_dtype='float32'"
+                    "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
+                    "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
                 )
                 self._bf16_warning_shown = True
```

```diff
@@ -76,9 +76,9 @@ def force_non_empty_weights():
     [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
     """
 
-    possibly_patched_register_parameter = nn.Module.register_parameter
-    nn.Module.register_parameter = _original_register_parameter
     try:
+        possibly_patched_register_parameter = nn.Module.register_parameter
+        nn.Module.register_parameter = _original_register_parameter
         yield
     finally:
         nn.Module.register_parameter = possibly_patched_register_parameter
```

@@ -1,349 +1,142 @@

The previous client-side generation stack is removed from this file: the old `RemoteGenerationMixin` with its hand-rolled `generate()` loop, the `greedy_search()`, `sample()`, `beam_search()`, `beam_sample()`, and `group_beam_search()` wrappers, and the `_choose_sample_algorithm()` / `_get_constraints()` helpers built on `DecodingAlgorithm` and `ABCBloomConstraint`. The rewritten module delegates all decoding to `transformers.GenerationMixin` and only manages the remote inference session:

```python
import contextlib
import dataclasses
from contextvars import ContextVar
from typing import ContextManager, List, Optional

import torch
import transformers
from hivemind.utils.logging import get_logger
from transformers.generation.utils import ModelOutput

from petals.client.inference_session import InferenceSession
from petals.client.remote_sequential import RemoteSequential
from petals.utils.misc import DUMMY, docstring_from

logger = get_logger(__name__)


@dataclasses.dataclass(frozen=True)
class RemotePastKeyValues:
    """A mock class representing the fact that `past_key_values` do exist but are stored on remote servers."""

    hypo_ids: Optional[torch.LongTensor] = None

    def __getitem__(self, _index: int) -> List[torch.Tensor]:
        return [DUMMY]  # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()


_skipped_tokens = ContextVar("skipped_tokens", default=0)


class _SkipTokensMixin:
    # This override is used in RemoteGenerationMixin but has to be defined in a class not named as "GenerationMixin"
    # due to how transformers.PreTrainedModel.can_generate() works
    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
        input_ids = input_ids[:, _skipped_tokens.get() :]
        _skipped_tokens.set(0)
        return super().prepare_inputs_for_generation(input_ids, **kwargs)


class RemoteGenerationMixin(_SkipTokensMixin):
    """
    This class is an upgrade to `transformers.GenerationMixin` that:

    - Designed to be compatible with most `transformers.GenerationMixin` strategies and options
    - Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and
      you don't have to rerun the prefix through all the servers to generate each new token
    - Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation
      by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or
      accept prompts from a user in a chat bot (multiple calls like `.generate(new_prompts, ...)`).
    - If there is no active session, `.generate()` will create a new InferenceSession with proper `max_length`.
      Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that.
    """

    @docstring_from(RemoteSequential.active_session)
    @property
    def active_session(self) -> Optional[InferenceSession]:
        return self.transformer.h.active_session

    @docstring_from(RemoteSequential.use_session)
    def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
        return self.transformer.h.use_session(session)

    @docstring_from(RemoteSequential.inference_session)
    def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
        return self.transformer.h.inference_session(**kwargs)

    @docstring_from(transformers.GenerationMixin.generate.__doc__)
    def generate(
        self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
    ):
        self._fix_generate_kwargs(kwargs)

        if session is not None:
            # If a session specified explicitly, use it
            context_manager = self.use_session(session)
        elif self.active_session is not None:
            # If there's an active session, don't do anything
            context_manager = contextlib.nullcontext(self.active_session)
        else:
            # If there's no active session, create a new one

            max_length = kwargs.get("max_length")
            max_new_tokens = kwargs.get("max_new_tokens")
            assert (max_length is None) != (
                max_new_tokens is None
            ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"

            if max_length is not None:
                session_max_length = max_length
            else:
                session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
            context_manager = self.inference_session(max_length=session_max_length)

        with context_manager as session:
            # Prepend the tokens from the previous .generate() call
            n_prev_tokens = session.output_ids.shape[1] if session.output_ids is not None else 0
            if n_prev_tokens > 0:
                if kwargs.get("num_beams", 1) > 1:
                    logger.warning(
                        "Beam search will not work properly in the resumed petals.InferenceSession "
                        "since intermediate beam entries are lost"
                    )

                if inputs is not None:
                    inputs = torch.cat([session.output_ids, inputs], dim=1)
                else:
                    inputs = session.output_ids

                # Don't actually run all previous tokens through the transformer,
                # but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
                _skipped_tokens.set(max(0, n_prev_tokens - 1))

            result = super().generate(inputs, *args, **kwargs)

            sequences = result.sequences if isinstance(result, ModelOutput) else result
            # Save tokens from this .generate() call
            session.output_ids = sequences
            # Crop the last tokens from the previous call
            sequences = sequences[:, n_prev_tokens:].clone()
            if isinstance(result, ModelOutput):
                result.sequences = sequences
            else:
                result = sequences

        return result

    @staticmethod
    def _fix_generate_kwargs(kwargs: dict) -> dict:
        # Suppress inappropriate "Both max_new_tokens and max_length" HF warning
        if "max_length" in kwargs and kwargs["max_length"] is None:
            del kwargs["max_length"]

        # Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0
        do_sample = kwargs.get("do_sample")
        if isinstance(do_sample, int):
            kwargs["do_sample"] = bool(do_sample)

        return kwargs

    @staticmethod
    def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
        return dataclasses.replace(past_key_values, hypo_ids=beam_idx)
```

```diff
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from contextlib import contextmanager
+from contextvars import ContextVar
 from typing import Optional, Union
 
 import torch
@@ -11,7 +13,6 @@ from petals.client.inference_session import InferenceSession
 from petals.client.routing import RemoteSequenceManager
 from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from petals.data_structures import UID_DELIMITER
-from petals.utils.misc import DUMMY
 
 logger = get_logger(__name__)
@@ -46,11 +47,52 @@ class RemoteSequential(nn.Module):
             sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
         self.sequence_manager = sequence_manager
 
-    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
+        self._active_session = ContextVar("active_session", default=None)
+
+    def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
         assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
-        assert inputs.shape[1] <= 2048, "The sequence length is capped at 2048 tokens in this version"
-        outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
-        return outputs
+        if self.active_session is None:
+            assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
+            return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
+        else:
+            return self.active_session.step(inputs, prompts, **kwargs)
+
+    @property
+    def active_session(self) -> Optional[InferenceSession]:
+        """
+        If called inside `with model.inference_session(...):` or `with model.use_session(...):`,
+        returns an active InferenceSession. Otherwise, returns None.
+        """
+
+        return self._active_session.get()
+
+    @property
+    def position(self) -> int:
+        """Returns the prefix length (in tokens) in the active inference session or zero if no session is active."""
+
+        return self.active_session.position if self.active_session is not None else 0
+
+    @contextmanager
+    def use_session(self, session: Optional[InferenceSession]) -> InferenceSession:
+        """Inside this context, forward() will use an _existing_ InferenceSession provided as the argument."""
+
+        token = self._active_session.set(session)
+        try:
+            yield session
+        finally:
+            self._active_session.reset(token)
+
+    @contextmanager
+    def inference_session(self, **kwargs) -> InferenceSession:
+        """
+        Inside this context, forward() will use a _new_ InferenceSession created with given parameters.
+
+        :param max_length: Maximal expected length of inference results. Servers use this parameter
+                           to calculate the size of attention caches allocated to this client.
+        """
+
+        with InferenceSession(self.sequence_manager, **kwargs) as session, self.use_session(session):
+            yield session
 
     def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
         return RemoteSequential(
@@ -65,8 +107,5 @@ class RemoteSequential(nn.Module):
     def __len__(self):
         return len(self.sequence_manager)
 
-    def inference_session(self, **kwargs) -> InferenceSession:
-        return InferenceSession(self.sequence_manager, **kwargs)
-
     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
```

```diff
@@ -230,7 +230,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
-        if is_dummy(prompts):
+        if prompts is None or is_dummy(prompts):
             prompt_batches = [DUMMY] * len(input_batches)
         else:
             prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
```

```diff
@@ -10,7 +10,7 @@ from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassifi
 from petals.client.from_pretrained import FromPretrainedMixin
 from petals.client.lm_head import LMHead
 from petals.client.ptune import PTuneMixin
-from petals.client.remote_generation import RemoteGenerationMixin
+from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
 from petals.client.remote_sequential import RemoteSequential
 from petals.models.bloom.config import DistributedBloomConfig
@@ -39,16 +39,15 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
+        past_key_values: Optional[RemotePastKeyValues] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
+        head_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ):
-        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
-
-        for k, v in kwargs.items():
-            if not (v is None or v is False):
-                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
-
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
@@ -59,21 +58,34 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
+        # The causal mask will be added on the server-side
+        assert (
+            attention_mask is None or (attention_mask == 1).all()
+        ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
+        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
+        assert not output_attentions, f"{output_attentions=} is not supported"
+        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
+        assert return_dict is None or return_dict, f"{return_dict=} is not supported"
+
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+        else:
+            prompts = intermediate_prompts = None
 
         hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
-            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
-        else:
-            hidden_states = self.h(hidden_states)
+        hidden_states = self.h(
+            hidden_states,
+            prompts=intermediate_prompts,
+            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
+        )
 
         # Remove prefix
         if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
@@ -84,7 +96,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         hidden_states = hidden_states.view(output_shape)
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
-            past_key_values=None,
+            past_key_values=RemotePastKeyValues(),
             hidden_states=None,
             attentions=None,
         )
```

```diff
@@ -10,7 +10,7 @@ from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassifi
 from petals.client.from_pretrained import FromPretrainedMixin
 from petals.client.lm_head import LMHead
 from petals.client.ptune import PTuneMixin
-from petals.client.remote_generation import RemoteGenerationMixin
+from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
 from petals.client.remote_sequential import RemoteSequential
 from petals.models.llama.config import DistributedLlamaConfig
@@ -39,16 +39,15 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[RemotePastKeyValues] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> BaseModelOutputWithPast:
-        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
-
-        for k, v in kwargs.items():
-            if not (v is None or v is False):
-                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
-
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
@@ -59,21 +58,36 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
+        # The causal mask will be added on the server-side
+        assert (
+            attention_mask is None or (attention_mask == 1).all()
+        ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert (
+            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
+        ), f"Non-consecutive position_ids are not supported, {position_ids=}"
+        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
+        assert not output_attentions, f"{output_attentions=} is not supported"
+        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
+        assert return_dict is None or return_dict, f"{return_dict=} is not supported"
+
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+        else:
+            prompts = intermediate_prompts = None
 
         hidden_states = inputs_embeds
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
-            hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
-        else:
-            hidden_states = self.layers(hidden_states)
+        hidden_states = self.layers(
+            hidden_states,
+            prompts=intermediate_prompts,
+            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
+        )
 
         # Remove prefix
         if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
@@ -84,7 +98,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         hidden_states = hidden_states.view(output_shape)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=None,
+            past_key_values=RemotePastKeyValues(),
             hidden_states=None,
             attentions=None,
         )
```

```diff
@@ -196,7 +196,7 @@ async def iterate_rpc_inference(
                 hypo_ids,
                 points=point_per_piece,
                 requested_uids=requested_uids,
-                type="short_inference" if can_merge_pools else "inference",
+                type="inference",
             )
 
             # A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.
```

```diff
@@ -14,9 +14,7 @@ class TaskPrioritizerBase(ABC):
 class DummyTaskPrioritizer(TaskPrioritizerBase):
     def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
-        # Inference steps (especially short ones) go first since they are more latency-sensitive
-        if kwargs.get("type") == "short_inference":
-            return 1.0
+        # Inference steps go first since they are more latency-sensitive
         if kwargs.get("type") == "inference":
-            return 2.0
-        return 3.0  # Forward, backward
+            return 1.0
+        return 2.0  # Forward, backward
```

@@ -1,128 +0,0 @@

The custom decoding algorithms module (imported above as `petals.utils.generation_algorithms`) is deleted. Its former contents:

```python
from abc import ABC, abstractmethod
from typing import Tuple

import torch

TokenIds = torch.Tensor
HypoIds = torch.Tensor


class DecodingAlgorithm(ABC):
    """
    An abstract class for decoding algorithms. Describes the base function of those algorithms:
    they have to select new tokens and provide the corresponding hypotheses.
    """

    @abstractmethod
    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        """
        :param logits: A tensor of shape (batch_size, seq_length, vocab_size)
        :return: A tuple of selected token ids and corresponding hypotheses.
            The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
        """
        pass


class GreedyAlgorithm(DecodingAlgorithm):
    """
    The simplest algorithm for decoding. It selects the most probable token.
    """

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        """
        Returns the most probable token. The second returned object is always a range of integers
        from 0 to batch_size - 1.
        """
        return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))


class SamplingAlgorithm(DecodingAlgorithm):
    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        """
        :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
        :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
        :return: A tuple of selected token ids and corresponding hypotheses.
            The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size).
        """
        logits[indices_to_remove] = -float("Inf")
        probs = torch.softmax(logits / self.temperature, -1)
        return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
        return self.sample(logits, indices_to_remove)


class TopKAlgorithm(SamplingAlgorithm):
    def __init__(self, top_k: int, temperature: float = 1.0) -> None:
        super().__init__(temperature=temperature)
        self.top_k = top_k

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
        return self.sample(logits, indices_to_remove)


class NucleusAlgorithm(SamplingAlgorithm):
    def __init__(self, top_p: float, temperature: float = 1.0) -> None:
        super().__init__(temperature=temperature)
        self.top_p = top_p

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
        probs = torch.softmax(sorted_logits / self.temperature, -1)
        cumulative_probs = torch.cumsum(probs, dim=-1)
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        return self.sample(logits, indices_to_remove)


class BeamSearchAlgorithm(DecodingAlgorithm):
    def __init__(self, num_beams: int, batch_size: int) -> None:
        self.num_beams = num_beams
        self.batch_size = batch_size

        self._batch_beams = [list() for _ in range(batch_size)]

    def __call__(self, logits: torch.Tensor):
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        probs = torch.log_softmax(sorted_logits, -1)

        if len(self._batch_beams[0]) > 0:
            for batch_idx in range(self.batch_size):
                new_beams = []
                cur_beams = self._batch_beams[batch_idx]
                for beam_idx in range(len(cur_beams)):
                    probs_idx = batch_idx + beam_idx * self.batch_size
                    new_beam = cur_beams[beam_idx]
                    for hypo_idx in range(self.num_beams):
                        new_beams.append(
                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
                        )
                self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
        else:
            for batch_idx in range(self.batch_size):
                for beam_idx in range(self.num_beams):
                    self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))

        return_hypos = []
        return_tokens = []
        for batch_idx in range(self.batch_size):
            cur_beam = self._batch_beams[batch_idx]
            return_hypos.append(list())
            return_tokens.append(list())
            for beam in cur_beam:
                beam_idx = beam[1] // self.num_beams
                hypo_idx = batch_idx + beam_idx * self.batch_size
                token_idx = beam[1] % self.num_beams
                return_hypos[-1].append(hypo_idx)
                return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
        return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
        return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]

        return torch.tensor(return_tokens), torch.tensor(return_hypos)
```

@@ -1,51 +0,0 @@

The custom decoding constraints module (imported above as `petals.utils.generation_constraints`) is deleted. Its former contents:

```python
from abc import ABC

import torch


class ABCBloomConstraint(ABC):
    """
    Base class of all kind of decoding constraints. It can be used to implement a new constraint.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        """
        This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
        :param tokens_id: The token id of the last chosen token.
        :param logits: The logits from the Bloom model.
        :param hypo_ids: The hypothesis ids of the last tokens.
        """
        pass


class EosConstraint(ABCBloomConstraint):
    """
    This constraint repeats the EOS token if it was generated on the previous step.
    Args:
        prefix: The prefix of the sequence.
        eos_token_id: The id of the end of sentence token.
        pad_token_id: The id of the padding token.
        min_logits: The minimum logits that can be generated. Default: -1e6.
    """

    def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
        self.eos_token_id = eos_token_id
        self.min_logits = min_logits
        self.past_tokens = None

        self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        if self.past_tokens is not None:
            mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
            logits += self.min_logits * mask
            logits[mask[:, 0], self.eos_token_id] = 0

        if tokens_id is not None:
            self.past_tokens = tokens_id
            self.wait_until_starting -= 1

        return logits
```

```diff
@@ -5,5 +5,13 @@ DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter par
 DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
 
 
-def is_dummy(tensor: torch.Tensor):
+def is_dummy(tensor: torch.Tensor) -> bool:
     return tensor.numel() == 0
+
+
+def docstring_from(source):
+    def add_docstring(dest):
+        dest.__doc__ = source.__doc__
+        return dest
+
+    return add_docstring
```
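
For reference, a minimal sketch of how the new `docstring_from` helper is used (it copies the source object's docstring onto the decorated function, as done for `active_session`, `use_session`, and `inference_session` above):

```python
from petals.utils.misc import docstring_from

def source():
    """Original docstring that should be reused."""

@docstring_from(source)
def dest():
    ...

assert dest.__doc__ == "Original docstring that should be reused."
```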

```diff
@@ -3,7 +3,6 @@ import pytest
 import torch
 import transformers
 from hivemind import get_logger
-from transformers.generation import BeamSearchScorer, GenerationMixin as HfGenerationMixin
 
 from petals import AutoDistributedModelForCausalLM
 from test_utils import *
```

The generation tests are rewritten to compare against `transformers.AutoModelForCausalLM.generate()` directly instead of calling `HfGenerationMixin.greedy_search()`, `.sample()`, and `.beam_search()` on the distributed model, so the old `BeamSearchScorer`-based comparison and the skipped sampling test are gone. Shared `model` and `ref_model` fixtures replace the per-test `from_pretrained()` calls. `test_full_model_exact_match` now takes these fixtures plus a single `atol`, activates the adapter via `model.config.active_adapter = ADAPTER_NAME` when `use_peft` is set, uses `model.config.hidden_size` for the empty step tensors, and always compares against the reference model (the old `REF_NAME` branch with embedding resizing and the dummy attention mask is removed). The new fixtures and generation tests:

```python
@pytest.fixture
def model():
    return AutoDistributedModelForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
    )


@pytest.fixture
def ref_model():
    return transformers.AutoModelForCausalLM.from_pretrained(
        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )


def make_generate_calls(model, inputs, *, max_new_tokens, multiple_calls=False, **kwargs):
    if not multiple_calls:
        return model.generate(inputs, max_new_tokens=max_new_tokens, **kwargs)

    with model.inference_session(max_length=inputs.shape[1] + max_new_tokens) as sess:
        return torch.cat(
            [
                # Sessions provided both explicitly and implicitly should work
                model.generate(inputs, max_new_tokens=1, **kwargs, session=sess),
                model.generate(None, max_new_tokens=max_new_tokens - 2, **kwargs),
                model.generate(None, max_new_tokens=1, **kwargs),
            ],
            dim=1,
        )


@pytest.mark.forked
def test_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
    inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]

    options = dict(max_new_tokens=max_new_tokens, do_sample=False)
    for multiple_calls in [False, True]:
        for inputs in [inputs_single, inputs_batch]:
            outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)
            ref_outputs = ref_model.generate(inputs, **options)
            assert torch.allclose(
                outputs, ref_outputs
            ), f"Greedy generation is not identical to HF with {multiple_calls=}, {inputs.shape=}"


@pytest.mark.forked
def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):
    inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]

    for options in [
        dict(do_sample=True, temperature=0.5, top_k=5, top_p=0.9),
        dict(do_sample=True, temperature=0.5, repetition_penalty=1.2),
    ]:
        options.update(max_new_tokens=max_new_tokens)
        for multiple_calls in [False, True]:
            for inputs in [inputs_single, inputs_batch]:
                torch.manual_seed(0)
                outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)

                torch.manual_seed(0)
                ref_outputs = ref_model.generate(inputs, **options)

                assert torch.allclose(
                    outputs, ref_outputs
                ), f"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}"


@pytest.mark.forked
def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    options = dict(max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
    outputs = make_generate_calls(model, inputs, **options)
    ref_outputs = ref_model.generate(inputs, **options)
    assert torch.allclose(outputs, ref_outputs), f"Beam search results are not identical to HF"
```
