From de2475f31ce544a976ad01aa4ddeb212ac782baf Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Sun, 20 Aug 2023 19:18:36 +0400
Subject: [PATCH] Make client compatible with transformers' GenerationMixin (#464)

This PR drops the custom generation code and introduces compatibility with
`transformers.GenerationMixin` instead. This includes support for more sampling
options (`top_p`, `top_k`, `repetition_penalty` requested in #460) and beam
search - all of which now behaves identically to running the model with
transformers locally. Most features (excluding beam search and other rarely
used options) are also compatible with resuming existing sessions.

### Breaking changes

If `.generate()` or forward passes are run inside an `.inference_session()`
context, they now use the opened session by default. So, these snippets are now
equivalent:

```python
# Using the default session
with model.inference_session(max_length=100):
    output_ids = model.generate(input_ids, max_new_tokens=3)

# Explicitly specifying a session
with model.inference_session(max_length=100) as sess:
    output_ids = model.generate(input_ids, max_new_tokens=3, session=sess)
```

Previously, the first snippet created a new session, which is not what most
people expected (such code was likely to introduce a bug; this is now fixed).
---
 .github/workflows/run-tests.yaml           |   1 +
 src/petals/client/from_pretrained.py       |  14 +-
 src/petals/client/inference_session.py     |  12 +-
 src/petals/client/lm_head.py               |   4 +-
 src/petals/client/ptune.py                 |   4 +-
 src/petals/client/remote_generation.py     | 427 ++++++---------
 src/petals/client/remote_sequential.py     |  55 ++-
 src/petals/client/sequential_autograd.py   |   2 +-
 src/petals/models/bloom/model.py           |  42 +-
 src/petals/models/llama/model.py           |  44 ++-
 src/petals/server/block_functions.py       |   2 +-
 src/petals/server/task_prioritizer.py      |   8 +-
 src/petals/utils/generation_algorithms.py  | 128 ------
 src/petals/utils/generation_constraints.py |  51 ---
 src/petals/utils/misc.py                   |  10 +-
 tests/test_full_model.py                   | 205 ++++------
 16 files changed, 332 insertions(+), 677 deletions(-)
 delete mode 100644 src/petals/utils/generation_algorithms.py
 delete mode 100644 src/petals/utils/generation_constraints.py

diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml
index 735fd2a..f4a41f2 100644
--- a/.github/workflows/run-tests.yaml
+++ b/.github/workflows/run-tests.yaml
@@ -41,6 +41,7 @@ jobs:
           pip install .[dev]
       - name: Test
         run: |
+          set -x  # Print executed commands
           export MODEL_NAME="${{ matrix.model }}"
           export REF_NAME="${{ matrix.model }}"
           export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
diff --git a/src/petals/client/from_pretrained.py b/src/petals/client/from_pretrained.py
index b8d02c0..f2c88d2 100644
--- a/src/petals/client/from_pretrained.py
+++ b/src/petals/client/from_pretrained.py
@@ -3,7 +3,7 @@ import json
 import os
 import re
 import tempfile
-import threading
+from contextvars import ContextVar
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -47,18 +47,16 @@ class FromPretrainedMixin:
     )
 
 
-_shard_config = threading.local()
-_shard_config.ignored_keys = None
+_ignored_keys = ContextVar("ignored_keys", default=None)
 
 
 @contextlib.contextmanager
 def ignore_keys(patterns: List[str]):
+    token = _ignored_keys.set(patterns)
     try:
-        prev_patterns = _shard_config.ignored_keys
-        _shard_config.ignored_keys = patterns
         yield
     finally:
-        _shard_config.ignored_keys = prev_patterns
+        _ignored_keys.reset(token)
 
 
 def 
patched_get_checkpoint_shard_files( @@ -66,7 +64,7 @@ def patched_get_checkpoint_shard_files( ) -> Tuple[List[str], dict]: """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys.""" - should_ignore_keys = _shard_config.ignored_keys is not None + should_ignore_keys = _ignored_keys.get() is not None tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext() with tempdir_ctx as tempdir: if should_ignore_keys: @@ -77,7 +75,7 @@ def patched_get_checkpoint_shard_files( index["weight_map"] = { param_name: filename for param_name, filename in index["weight_map"].items() - if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys) + if all(re.search(pattern, param_name) is None for pattern in _ignored_keys.get()) } n_loaded_shards = len(set(index["weight_map"].values())) logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}") diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 7f467b6..c9d0a97 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -230,7 +230,7 @@ class InferenceSession: self._server_sessions = [] self._position = 0 self._max_length = max_length - self.last_token_id = None + self.output_ids = None @property def num_blocks(self) -> int: @@ -377,3 +377,13 @@ class InferenceSession: def __del__(self): self.close() + + @property + def last_token_id(self) -> Optional[torch.Tensor]: # Backward compatibility with Petals < 2.1.0 + return self.output_ids[:, -1:] if self.output_ids is not None else None + + @last_token_id.setter + def last_token_id(self, value: torch.Tensor): # Backward compatibility with Petals < 2.1.0 + if self.output_ids is None: + raise RuntimeError("Can't override `last_token_id` since the session has not stepped yet") + self.output_ids[:, -1:] = value diff --git a/src/petals/client/lm_head.py b/src/petals/client/lm_head.py index 938d6da..cbea89d 100644 --- a/src/petals/client/lm_head.py +++ b/src/petals/client/lm_head.py @@ -70,8 +70,8 @@ class LMHead(nn.Module): if not self._bf16_warning_shown: if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total: logger.warning( - "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. " - "Consider loading the model with torch_dtype='float32'" + "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. 
" + "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)" ) self._bf16_warning_shown = True diff --git a/src/petals/client/ptune.py b/src/petals/client/ptune.py index 684cc23..f3995d6 100644 --- a/src/petals/client/ptune.py +++ b/src/petals/client/ptune.py @@ -76,9 +76,9 @@ def force_non_empty_weights(): [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515 """ + possibly_patched_register_parameter = nn.Module.register_parameter + nn.Module.register_parameter = _original_register_parameter try: - possibly_patched_register_parameter = nn.Module.register_parameter - nn.Module.register_parameter = _original_register_parameter yield finally: nn.Module.register_parameter = possibly_patched_register_parameter diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py index 9c2b51f..793b573 100644 --- a/src/petals/client/remote_generation.py +++ b/src/petals/client/remote_generation.py @@ -1,349 +1,142 @@ import contextlib -from typing import List, Optional +import dataclasses +from contextvars import ContextVar +from typing import ContextManager, List, Optional import torch +import transformers from hivemind.utils.logging import get_logger +from transformers.generation.utils import ModelOutput from petals.client.inference_session import InferenceSession -from petals.utils.generation_algorithms import ( - BeamSearchAlgorithm, - DecodingAlgorithm, - GreedyAlgorithm, - NucleusAlgorithm, - SamplingAlgorithm, - TopKAlgorithm, -) -from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint +from petals.client.remote_sequential import RemoteSequential +from petals.utils.misc import DUMMY, docstring_from logger = get_logger(__name__) -class RemoteGenerationMixin: - """ - A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`]. - The class exposes can be used for: - - *greedy decoding*. - - *multinomial, top-k and top-p sampling*. - - *beam-search decoding* - - This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. - However, it has some differences for remote usage. - """ - - def inference_session(self, **kwargs) -> InferenceSession: - """ - Returns an inference session for the model's RemoteSequential module. +@dataclasses.dataclass(frozen=True) +class RemotePastKeyValues: + """A mock class representing the fact that `past_key_values` do exist but are stored on remote servers.""" - :param max_length: Maximal expected length of inference results. Servers use this parameter - to calculate the size of attention caches allocated to this client. 
- """ + hypo_ids: Optional[torch.LongTensor] = None - return self.transformer.h.inference_session(**kwargs) + def __getitem__(self, _index: int) -> List[torch.Tensor]: + return [DUMMY] # For compatibility with BloomForCausalLM.prepare_inputs_for_generation() - @torch.inference_mode() - def generate( - self, - inputs: Optional[torch.Tensor] = None, - *, - do_sample: Optional[bool] = None, - temperature: float = 1.0, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - num_beams: Optional[int] = 1, - bos_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - max_length: Optional[int] = None, - max_new_tokens: Optional[int] = None, - decoding_algorithm: Optional[DecodingAlgorithm] = None, - provided_constraints: List[ABCBloomConstraint] = [], - num_return_sequences: Optional[int] = None, - session: Optional[InferenceSession] = None, - ) -> torch.LongTensor: - """ - Generates sequences of token ids for models with a language modeling head. - :param inputs: The input tokens to the model. - :param do_sample: Whether to sample from the model predictions or take the argmax. - :param temperature: The temperature to use for sampling. - :param top_k: The number of results to return. - :param top_p: The cumulative probability of results to return. - :param num_beams: The number of beams to use for beam search. - :param bos_token_id: The id of the beginning of sentence token. - :param eos_token_id: The id of the end of sentence token. - :param pad_token_id: The id of the padding token. - :param max_length: The maximum number of tokens in the output (including input tokens). - :param max_new_tokens: The maximum number of tokens to generate. - :param decoding_algorithm: The decoding algorithm to use. - :param provided_constraints: A list of constraints to use. - :param num_return_sequences: How many hypothesis from the beam will be in output. 
- """ +_skipped_tokens = ContextVar("skipped_tokens", default=0) - prefix_length = 0 if inputs is None else inputs.size(1) - prefix_length += self.config.pre_seq_len - bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id +class _SkipTokensMixin: + # This override is used in RemoteGenerationMixin by has to be defined in a class not named as "GenerationMixin" + # due to how transformers.PreTrainedModel.can_generate() works + def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict: + input_ids = input_ids[:, _skipped_tokens.get() :] + _skipped_tokens.set(0) + return super().prepare_inputs_for_generation(input_ids, **kwargs) - assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)" - if max_length is not None and max_new_tokens is None: - max_new_tokens = max_length - prefix_length - assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}" - elif max_length is None and max_new_tokens is not None: - max_length = prefix_length + max_new_tokens - resuming_session = session is not None and session.last_token_id is not None - if num_beams > 1 and resuming_session: - raise NotImplementedError( - "Resuming inference session in .generate() along with beam search is not supported yet" - ) +class RemoteGenerationMixin(_SkipTokensMixin): + """ + This class is an upgrade to `transformers.GenerationMixin` that: + + - Designed to be compatible with most `transformers.GenerationMixin` strategies and options + - Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and + you don't have to rerun the prefix through all the servers to generate each new token + - Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation + by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or + accept prompts from a user in a chat bot (multiple calls like `.generate(new_prompts, ...)`). + - If there is no active session, `.generate()` will create a new InferenceSession with proper `max_length`. + Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that. 
+ """ - if inputs is not None: - assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]" - if resuming_session: - inputs = torch.cat([session.last_token_id, inputs], dim=1) - else: - if resuming_session: - inputs = session.last_token_id - else: - assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs" - inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device) - batch_size = inputs.size(0) + @docstring_from(RemoteSequential.active_session) + @property + def active_session(self) -> Optional[InferenceSession]: + return self.transformer.h.active_session - if decoding_algorithm is None: - if do_sample: - decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p) - elif num_beams is not None and num_beams > 1: - decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size) - else: - if top_k is not None or top_p is not None: - logger.warning("You passed top_k or top_p but did not pass do_sample=True. Running greedy sampling") - decoding_algorithm = GreedyAlgorithm() + @docstring_from(RemoteSequential.use_session) + def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]: + return self.transformer.h.use_session(session) - if num_beams > 1: - inputs = torch.cat([inputs] * num_beams, dim=0) - if batch_size > 1: - # TODO: resolve padding problem - logger.warning( - f"You set batch_size {batch_size} within beam search generation. " - f"Be careful, results on sequences with different length may be padded wrong way" - ) + @docstring_from(RemoteSequential.inference_session) + def inference_session(self, **kwargs) -> ContextManager[InferenceSession]: + return self.transformer.h.inference_session(**kwargs) - if num_return_sequences is None: - num_return_sequences = 1 + @docstring_from(transformers.GenerationMixin.generate.__doc__) + def generate( + self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs + ): + self._fix_generate_kwargs(kwargs) + + if session is not None: + # If a session specified explicitly, use it + context_manager = self.use_session(session) + elif self.active_session is not None: + # If there's an active session, don't do anything + context_manager = contextlib.nullcontext(self.active_session) + else: + # If there's no active session, create a new one - assert num_return_sequences <= num_beams, ( - f"You want more sequences than the beam has." - " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}." 
- ) + max_length = kwargs.get("max_length") + max_new_tokens = kwargs.get("max_new_tokens") + assert (max_length is None) != ( + max_new_tokens is None + ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches" - constraints = self._get_constraints( - inputs=inputs, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - provided_constraints=provided_constraints, - ) + if max_length is not None: + session_max_length = max_length + else: + session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens + context_manager = self.inference_session(max_length=session_max_length) - if session is None: - context_manager = self.inference_session(max_length=max_length) - else: - context_manager = contextlib.nullcontext(session) # Doesn't actually enter session or exit from it with context_manager as session: - outputs = [] - # Find samples with padded inputs. - # They will be changed before all of the samples have right length. - if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs - outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]] + # Prepend the tokens from the previous .generate() call + n_prev_tokens = session.output_ids.shape[1] if session.output_ids is not None else 0 + if n_prev_tokens > 0: + if kwargs.get("num_beams", 1) > 1: + logger.warning( + "Beam search will not work properly in the resumed petals.InferenceSession " + "since intermediate beam entries are lost" + ) + + if inputs is not None: + inputs = torch.cat([session.output_ids, inputs], dim=1) + else: + inputs = session.output_ids + + # Don't actually run all previous tokens through the transformer, + # but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty) + _skipped_tokens.set(max(0, n_prev_tokens - 1)) + + result = super().generate(inputs, *args, **kwargs) + + sequences = result.sequences if isinstance(result, ModelOutput) else result + # Save tokens from this .generate() call + session.output_ids = sequences + # Crop the last tokens from the previous call + sequences = sequences[:, n_prev_tokens:].clone() + if isinstance(result, ModelOutput): + result.sequences = sequences else: - outputs += [inputs] - last_token_id = None - seq_idx = outputs[0].size(1) - hypo_ids = torch.arange(outputs[0].size(0)) - while True: - hidden_state = self.transformer.word_embeddings(outputs[-1]) - intermediate_prompts = None - if self.config.pre_seq_len > 0 and len(outputs) == 1: - prompts, intermediate_prompts = self.transformer.get_prompt(hidden_state.size(0)) - hidden_state = torch.cat([prompts, hidden_state], dim=1) - hidden_state = self.transformer.word_embeddings_layernorm(hidden_state) + result = sequences - hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1] + return result - hidden_state = self.transformer.ln_f(hidden_state) - lm_logits = self.lm_head(hidden_state) + @staticmethod + def _fix_generate_kwargs(kwargs: dict) -> dict: + # Suppress inappropriate "Both max_new_tokens and max_length" HF warning + if "max_length" in kwargs and kwargs["max_length"] is None: + del kwargs["max_length"] - for constraint in constraints: - lm_logits = constraint(last_token_id, lm_logits, hypo_ids) - last_token_id, hypo_ids = decoding_algorithm(lm_logits) + # Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0 + do_sample = kwargs.get("do_sample") + if isinstance(do_sample, int): + kwargs["do_sample"] = bool(do_sample) - # If some 
samples were padded, change only these samples - if seq_idx < inputs.size(1): - pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id - last_token_id = (~pad_token_mask) * inputs[ - :, seq_idx : seq_idx + 1 - ] + pad_token_mask * last_token_id - - # TODO: refactor outputs - if num_beams > 1: - for i in range(len(outputs), 1, -1): - outputs[i - 1] = outputs[i - 1][hypo_ids] - - outputs.append(last_token_id) - session.last_token_id = last_token_id - seq_idx += 1 - if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens: - break - - outputs = torch.cat(outputs, dim=-1) - - if resuming_session: - outputs = outputs[:, 1:] - if num_beams > 1: - pre_return_idx = [ - torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size) - ] - return_idx = torch.cat(pre_return_idx, dim=0) - outputs = outputs[return_idx] - - return outputs - - def greedy_search( - self, - input_ids: torch.LongTensor, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - provided_constraints: List[ABCBloomConstraint] = [], - ) -> torch.LongTensor: - """ - Generates sequences of token ids for models with a language modeling head. Uses greedy search. - - :param input_ids: The input tokens to the model. - :param max_length: The maximum length of the sequence to generate. - :param pad_token_id: The id of the padding token. - :param eos_token_id: The id of the end of sentence token. - :param provided_constraints: A list of constraints to use. - """ - return self.generate( - inputs=input_ids, - max_new_tokens=max_length, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - decoding_algorithm=GreedyAlgorithm(), - provided_constraints=provided_constraints, - ) - - def sample( - self, - input_ids: torch.LongTensor, - temperature: float = 1.0, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - provided_constraints: List[ABCBloomConstraint] = [], - ) -> torch.LongTensor: - """ - Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling. - If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling. - - :param: input_ids: The input tokens to the model. - :param: temperature: The temperature to use for sampling. - :param: top_k: The number of samples to use for top_k sampling. - :param: top_p: The probability of using top_p sampling. - :param: max_length: The maximum length of the sequence to generate. - :param: pad_token_id: The id of the padding token. - :param: eos_token_id: The id of the end of sentence token. - :param: provided_constraints: A list of constraints to use. - """ - - return self.generate( - inputs=input_ids, - max_new_tokens=max_length, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p), - provided_constraints=provided_constraints, - ) - - def beam_search( - self, - input_ids: torch.LongTensor, - num_beams: int = 1, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - provided_constraints: List[ABCBloomConstraint] = [], - ) -> torch.LongTensor: - """ - Generates sequences of token ids for models with a language modeling head. Uses beam search. - - :param input_ids: The input tokens to the model. - :param num_beams: The number of beams to use. 
- :param max_length: The maximum length of the sequence to generate. - :param pad_token_id: The id of the padding token. - :param eos_token_id: The id of the end of sentence token. - :param provided_constraints: A list of constraints to use. - """ - decoding_algorithm = BeamSearchAlgorithm( - num_beams=num_beams, - batch_size=input_ids.size(0), - ) - return self.generate( - inputs=input_ids, - num_beams=num_beams, - max_new_tokens=max_length, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - decoding_algorithm=decoding_algorithm, - provided_constraints=provided_constraints, - ) - - def beam_sample( - self, - input_ids: torch.LongTensor, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - provided_constraints: List[ABCBloomConstraint] = [], - ) -> torch.LongTensor: - raise NotImplementedError - - def group_beam_search( - self, - input_ids: torch.LongTensor, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[int] = None, - provided_constraints: List[ABCBloomConstraint] = [], - ) -> torch.LongTensor: - raise NotImplementedError - - def _choose_sample_algorithm( - self, - temperature: float = 1.0, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - ) -> DecodingAlgorithm: - if (top_k is not None) and (top_p is not None): - raise ValueError("You have to provide only top_k or top_p for sampling") - if top_k is not None: - return TopKAlgorithm(top_k, temperature) - elif top_p is not None: - return NucleusAlgorithm(top_p, temperature) - else: - return SamplingAlgorithm(temperature) + return kwargs - def _get_constraints( - self, - inputs: Optional[torch.Tensor] = None, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - provided_constraints: List[ABCBloomConstraint] = [], - ) -> List[ABCBloomConstraint]: - constraints = [] - constraints.extend(provided_constraints) - constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id)) - return constraints + @staticmethod + def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues: + return dataclasses.replace(past_key_values, hypo_ids=beam_idx) diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index 1df4a42..c6d2833 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -1,5 +1,7 @@ from __future__ import annotations +from contextlib import contextmanager +from contextvars import ContextVar from typing import Optional, Union import torch @@ -11,7 +13,6 @@ from petals.client.inference_session import InferenceSession from petals.client.routing import RemoteSequenceManager from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction from petals.data_structures import UID_DELIMITER -from petals.utils.misc import DUMMY logger = get_logger(__name__) @@ -46,11 +47,52 @@ class RemoteSequential(nn.Module): sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs) self.sequence_manager = sequence_manager - def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY): + self._active_session = ContextVar("active_session", default=None) + + def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]" - assert inputs.shape[1] <= 2048, "The sequence length is capped at 2048 
tokens in this version" - outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager) - return outputs + if self.active_session is None: + assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}" + return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager) + else: + return self.active_session.step(inputs, prompts, **kwargs) + + @property + def active_session(self) -> Optional[InferenceSession]: + """ + If called inside `with model.inference_session(...):` or `with model.use_session(...):`, + returns an active InferenceSession. Otherwise, returns None. + """ + + return self._active_session.get() + + @property + def position(self) -> int: + """Returns the prefix length (in tokens) in the active inference session or zero if no session is active.""" + + return self.active_session.position if self.active_session is not None else 0 + + @contextmanager + def use_session(self, session: Optional[InferenceSession]) -> InferenceSession: + """Inside this context, forward() will use an _existing_ InferenceSession provided as the argument.""" + + token = self._active_session.set(session) + try: + yield session + finally: + self._active_session.reset(token) + + @contextmanager + def inference_session(self, **kwargs) -> InferenceSession: + """ + Inside this context, forward() will use a _new_ InferenceSession created with given parameters. + + :param max_length: Maximal expected length of inference results. Servers use this parameter + to calculate the size of attention caches allocated to this client. + """ + + with InferenceSession(self.sequence_manager, **kwargs) as session, self.use_session(session): + yield session def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential: return RemoteSequential( @@ -65,8 +107,5 @@ class RemoteSequential(nn.Module): def __len__(self): return len(self.sequence_manager) - def inference_session(self, **kwargs) -> InferenceSession: - return InferenceSession(self.sequence_manager, **kwargs) - def extra_repr(self) -> str: return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}" diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 41bc994..9d965d2 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -230,7 +230,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function): def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager): batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1) input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size) - if is_dummy(prompts): + if prompts is None or is_dummy(prompts): prompt_batches = [DUMMY] * len(input_batches) else: prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1) diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py index e03adca..cf83822 100644 --- a/src/petals/models/bloom/model.py +++ b/src/petals/models/bloom/model.py @@ -10,7 +10,7 @@ from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassifi from petals.client.from_pretrained import FromPretrainedMixin from petals.client.lm_head import LMHead from petals.client.ptune import PTuneMixin -from petals.client.remote_generation import RemoteGenerationMixin +from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues from petals.client.remote_sequential 
import RemoteSequential from petals.models.bloom.config import DistributedBloomConfig @@ -39,16 +39,15 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[RemotePastKeyValues] = None, attention_mask: Optional[torch.Tensor] = None, - **kwargs, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ): - assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now" - - for k, v in kwargs.items(): - if not (v is None or v is False): - logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})") - if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -59,21 +58,34 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): else: raise ValueError("You have to specify either input_ids or inputs_embeds") + # The causal mask will be added on the server-side + assert ( + attention_mask is None or (attention_mask == 1).all() + ), f"Custom attention masks are not supported, {attention_mask=}" + assert head_mask is None, f"Custom head masks are not supported, {head_mask=}" + assert use_cache is None or use_cache, f"{use_cache=} is not supported" + assert not output_attentions, f"{output_attentions=} is not supported" + assert not output_hidden_states, f"{output_hidden_states=} is not supported" + assert return_dict is None or return_dict, f"{return_dict=} is not supported" + if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0: batch_size = inputs_embeds.shape[0] prompts, intermediate_prompts = self.get_prompt(batch_size) inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) + else: + prompts = intermediate_prompts = None hidden_states = self.word_embeddings_layernorm(inputs_embeds) output_shape = input_shape + (hidden_states.size(-1),) - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: - hidden_states = self.h(hidden_states, prompts=intermediate_prompts) - else: - hidden_states = self.h(hidden_states) + hidden_states = self.h( + hidden_states, + prompts=intermediate_prompts, + hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None, + ) # Remove prefix if self.config.tuning_mode and "ptune" in self.config.tuning_mode: @@ -84,7 +96,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): hidden_states = hidden_states.view(output_shape) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=None, + past_key_values=RemotePastKeyValues(), hidden_states=None, attentions=None, ) diff --git a/src/petals/models/llama/model.py b/src/petals/models/llama/model.py index cafb45b..a9dfcc1 100644 --- a/src/petals/models/llama/model.py +++ b/src/petals/models/llama/model.py @@ -10,7 +10,7 @@ from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassifi from petals.client.from_pretrained import FromPretrainedMixin from petals.client.lm_head import 
LMHead from petals.client.ptune import PTuneMixin -from petals.client.remote_generation import RemoteGenerationMixin +from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues from petals.client.remote_sequential import RemoteSequential from petals.models.llama.config import DistributedLlamaConfig @@ -39,16 +39,15 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - **kwargs, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[RemotePastKeyValues] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> BaseModelOutputWithPast: - assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now" - - for k, v in kwargs.items(): - if not (v is None or v is False): - logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})") - if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -59,21 +58,36 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): else: raise ValueError("You have to specify either input_ids or inputs_embeds") + # The causal mask will be added on the server-side + assert ( + attention_mask is None or (attention_mask == 1).all() + ), f"Custom attention masks are not supported, {attention_mask=}" + assert ( + position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all() + ), f"Non-consecutive position_ids are not supported, {position_ids=}" + assert use_cache is None or use_cache, f"{use_cache=} is not supported" + assert not output_attentions, f"{output_attentions=} is not supported" + assert not output_hidden_states, f"{output_hidden_states=} is not supported" + assert return_dict is None or return_dict, f"{return_dict=} is not supported" + if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0: batch_size = inputs_embeds.shape[0] prompts, intermediate_prompts = self.get_prompt(batch_size) inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) + else: + prompts = intermediate_prompts = None hidden_states = inputs_embeds output_shape = input_shape + (hidden_states.size(-1),) - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: - hidden_states = self.layers(hidden_states, prompts=intermediate_prompts) - else: - hidden_states = self.layers(hidden_states) + hidden_states = self.layers( + hidden_states, + prompts=intermediate_prompts, + hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None, + ) # Remove prefix if self.config.tuning_mode and "ptune" in self.config.tuning_mode: @@ -84,7 +98,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): hidden_states = hidden_states.view(output_shape) return BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=None, + past_key_values=RemotePastKeyValues(), hidden_states=None, attentions=None, ) diff --git a/src/petals/server/block_functions.py 
b/src/petals/server/block_functions.py index c682663..2c37566 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -196,7 +196,7 @@ async def iterate_rpc_inference( hypo_ids, points=point_per_piece, requested_uids=requested_uids, - type="short_inference" if can_merge_pools else "inference", + type="inference", ) # A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g. diff --git a/src/petals/server/task_prioritizer.py b/src/petals/server/task_prioritizer.py index 4a575b1..9f39b3c 100644 --- a/src/petals/server/task_prioritizer.py +++ b/src/petals/server/task_prioritizer.py @@ -14,9 +14,7 @@ class TaskPrioritizerBase(ABC): class DummyTaskPrioritizer(TaskPrioritizerBase): def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float: - # Inference steps (especially short ones) go first since they are more latency-sensitive - if kwargs.get("type") == "short_inference": - return 1.0 + # Inference steps go first since they are more latency-sensitive if kwargs.get("type") == "inference": - return 2.0 - return 3.0 # Forward, backward + return 1.0 + return 2.0 # Forward, backward diff --git a/src/petals/utils/generation_algorithms.py b/src/petals/utils/generation_algorithms.py deleted file mode 100644 index d085e8b..0000000 --- a/src/petals/utils/generation_algorithms.py +++ /dev/null @@ -1,128 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Tuple - -import torch - -TokenIds = torch.Tensor -HypoIds = torch.Tensor - - -class DecodingAlgorithm(ABC): - """ - An abstract class for decoding algorithms. Describes the base function of those algorithms: - they have to select new tokens and provide the corresponding hypotheses. - """ - - @abstractmethod - def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: - """ - :param logits: A tensor of shape (batch_size, seq_length, vocab_size) - :return: A tuple of selected token ids and corresponding hypotheses. - The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size) - """ - pass - - -class GreedyAlgorithm(DecodingAlgorithm): - """ - The simplest algorithm for decoding. It selects the most probable token. - """ - - def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: - """ - Returns the most probable token. The second returned object is always a range of integers - from 0 to batch_size - 1. - """ - return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0)) - - -class SamplingAlgorithm(DecodingAlgorithm): - def __init__(self, temperature: float = 1.0): - self.temperature = temperature - - def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]: - """ - :param logits: A tensor of shape (batch_size * num_hypos, vocab_size) - :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size) - :return: A tuple of selected token ids and corresponding hypotheses. - The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size). 
- """ - logits[indices_to_remove] = -float("Inf") - probs = torch.softmax(logits / self.temperature, -1) - return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0)) - - def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: - indices_to_remove = torch.full_like(logits, False, dtype=torch.bool) - return self.sample(logits, indices_to_remove) - - -class TopKAlgorithm(SamplingAlgorithm): - def __init__(self, top_k: int, temperature: float = 1.0) -> None: - super().__init__(temperature=temperature) - self.top_k = top_k - - def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: - indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None] - return self.sample(logits, indices_to_remove) - - -class NucleusAlgorithm(SamplingAlgorithm): - def __init__(self, top_p: float, temperature: float = 1.0) -> None: - super().__init__(temperature=temperature) - self.top_p = top_p - - def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: - sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1) - probs = torch.softmax(sorted_logits / self.temperature, -1) - cumulative_probs = torch.cumsum(probs, dim=-1) - - sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p) - - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - return self.sample(logits, indices_to_remove) - - -class BeamSearchAlgorithm(DecodingAlgorithm): - def __init__(self, num_beams: int, batch_size: int) -> None: - self.num_beams = num_beams - self.batch_size = batch_size - - self._batch_beams = [list() for _ in range(batch_size)] - - def __call__(self, logits: torch.Tensor): - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - probs = torch.log_softmax(sorted_logits, -1) - - if len(self._batch_beams[0]) > 0: - for batch_idx in range(self.batch_size): - new_beams = [] - cur_beams = self._batch_beams[batch_idx] - for beam_idx in range(len(cur_beams)): - probs_idx = batch_idx + beam_idx * self.batch_size - new_beam = cur_beams[beam_idx] - for hypo_idx in range(self.num_beams): - new_beams.append( - (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx) - ) - self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams] - else: - for batch_idx in range(self.batch_size): - for beam_idx in range(self.num_beams): - self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx)) - - return_hypos = [] - return_tokens = [] - for batch_idx in range(self.batch_size): - cur_beam = self._batch_beams[batch_idx] - return_hypos.append(list()) - return_tokens.append(list()) - for beam in cur_beam: - beam_idx = beam[1] // self.num_beams - hypo_idx = batch_idx + beam_idx * self.batch_size - token_idx = beam[1] % self.num_beams - return_hypos[-1].append(hypo_idx) - return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()]) - return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes] - return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes] - - return torch.tensor(return_tokens), torch.tensor(return_hypos) diff --git a/src/petals/utils/generation_constraints.py b/src/petals/utils/generation_constraints.py deleted file mode 100644 index fa9304b..0000000 --- a/src/petals/utils/generation_constraints.py +++ /dev/null @@ -1,51 +0,0 @@ -from abc import ABC - -import torch - - -class ABCBloomConstraint(ABC): - """ - Base 
class of all kind of decoding constraints. It can be used to implement a new constraint. - """ - - def __init__(self) -> None: - pass - - def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor: - """ - This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits. - :param tokens_id: The token id of the last chosen token. - :param logits: The logits from the Bloom model. - :param hypo_ids: The hypothesis ids of the last tokens. - """ - pass - - -class EosConstraint(ABCBloomConstraint): - """ - This constrained repeats EOS token if it was generated on the previous step. - Args: - prefix: The prefix of the sequence. - eos_token_id: The id of the end of sentence token. - pad_token_id: The id of the padding token. - min_logits: The minimum logits that can be generated. Default: -1e6. - """ - - def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None: - self.eos_token_id = eos_token_id - self.min_logits = min_logits - self.past_tokens = None - - self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1) - - def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor: - if self.past_tokens is not None: - mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id) - logits += self.min_logits * mask - logits[mask[:, 0], self.eos_token_id] = 0 - - if tokens_id is not None: - self.past_tokens = tokens_id - self.wait_until_starting -= 1 - - return logits diff --git a/src/petals/utils/misc.py b/src/petals/utils/misc.py index d8068e1..afe9fc4 100644 --- a/src/petals/utils/misc.py +++ b/src/petals/utils/misc.py @@ -5,5 +5,13 @@ DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter par DUMMY_INT64 = torch.empty(0, dtype=torch.int64) -def is_dummy(tensor: torch.Tensor): +def is_dummy(tensor: torch.Tensor) -> bool: return tensor.numel() == 0 + + +def docstring_from(source): + def add_docstring(dest): + dest.__doc__ = source.__doc__ + return dest + + return add_docstring diff --git a/tests/test_full_model.py b/tests/test_full_model.py index dc2f3d7..fafbd62 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -3,7 +3,6 @@ import pytest import torch import transformers from hivemind import get_logger -from transformers.generation import BeamSearchScorer, GenerationMixin as HfGenerationMixin from petals import AutoDistributedModelForCausalLM from test_utils import * @@ -17,18 +16,29 @@ def tokenizer(): return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False) +@pytest.fixture +def model(): + return AutoDistributedModelForCausalLM.from_pretrained( + MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32 + ) + + +@pytest.fixture +def ref_model(): + return transformers.AutoModelForCausalLM.from_pretrained( + REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32 + ) + + @pytest.mark.forked @pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,)) @pytest.mark.parametrize("pass_empty_tensors", (True, False)) -def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3): - model = AutoDistributedModelForCausalLM.from_pretrained( - MODEL_NAME, - initial_peers=INITIAL_PEERS, - torch_dtype=torch.float32, - active_adapter=ADAPTER_NAME if use_peft else None, - ) - config = model.config - assert len(model.transformer.h) == 
model.config.num_hidden_layers +def test_full_model_exact_match(tokenizer, model, ref_model, use_peft, pass_empty_tensors, atol=1e-3): + if use_peft: + model.config.active_adapter = ADAPTER_NAME + + ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME) + ref_model.train(False) test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"] @@ -42,7 +52,7 @@ def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_fo recurrent_outputs = [] with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess: if pass_empty_tensors: - recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size))) + recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size))) for t in range(embs.shape[1]): if t == 4: @@ -53,52 +63,39 @@ def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_fo recurrent_outputs.append(sess.step(embs[:, t : t + 1, :])) if t == 2 and pass_empty_tensors: - recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size))) - recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size))) + recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size))) + recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size))) recurrent_outputs = torch.cat(recurrent_outputs, dim=1) recurrent_outputs = model.transformer.ln_f(recurrent_outputs) recurrent_outputs = model.lm_head(recurrent_outputs) - assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference) - logger.info("Inference is consistent with forward") - - del model, embs, recurrent_outputs - - if REF_NAME: - ref_model = transformers.AutoModelForCausalLM.from_pretrained( - REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32 - ) - if use_peft: - ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME) - ref_model.train(False) - if config.vocab_size < ref_model.config.vocab_size: - ref_model.resize_token_embeddings(config.vocab_size) - logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}") - - dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool) - # note: this creates a dummy mask to make the test compatible with older transformer versions - # prior to https://github.com/huggingface/transformers/pull/17837 - ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float() - assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward) - logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward") - del ref_model, ref_outputs, dummy_mask - else: - logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set") - assert False + assert torch.allclose( + recurrent_outputs, parallel_outputs, rtol=0, atol=atol + ), "Inference differs from forward pass" + + ref_outputs = ref_model.forward(test_inputs).logits.float() + assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol), "Outputs are not identical to HF" + + +def make_generate_calls(model, inputs, *, max_new_tokens, multiple_calls=False, **kwargs): + if not multiple_calls: + return model.generate(inputs, max_new_tokens=max_new_tokens, **kwargs) + + with model.inference_session(max_length=inputs.shape[1] + max_new_tokens) as sess: + return torch.cat( + [ + # Sessions provided both explicitly and implicitly should work + model.generate(inputs, max_new_tokens=1, **kwargs, session=sess), + 
model.generate(None, max_new_tokens=max_new_tokens - 2, **kwargs), + model.generate(None, max_new_tokens=1, **kwargs), + ], + dim=1, + ) @pytest.mark.forked -def test_greedy_generation(tokenizer, max_new_tokens=4): - model = AutoDistributedModelForCausalLM.from_pretrained( - MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32 - ) - inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] - remote_outputs = model.generate( - inputs, - max_new_tokens=max_new_tokens, - ) - hf_outputs = HfGenerationMixin.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens) - assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF" +def test_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4): + inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id @@ -106,85 +103,49 @@ def test_greedy_generation(tokenizer, max_new_tokens=4): "input_ids" ] - remote_outputs_batch = model.generate( - inputs_batch, - max_new_tokens=max_new_tokens, - ) - hf_outputs_batch = HfGenerationMixin.greedy_search( - model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens - ) - assert torch.allclose( - remote_outputs_batch, hf_outputs_batch - ), "Greedy search results are not identical to HF in multibatch mode" + options = dict(max_new_tokens=max_new_tokens, do_sample=False) + for multiple_calls in [False, True]: + for inputs in [inputs_single, inputs_batch]: + outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options) + ref_outputs = ref_model.generate(inputs, **options) + assert torch.allclose( + outputs, ref_outputs + ), f"Greedy generation is not identical to HF with {multiple_calls=}, {inputs.shape=}" @pytest.mark.forked -@pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)]) -@pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers") -def test_sampling(tokenizer, sampling_options, max_new_tokens=4): - torch.manual_seed(0) - - model = AutoDistributedModelForCausalLM.from_pretrained( - MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32 - ) - logits_warper = HfGenerationMixin._get_logits_warper(model, num_beams=1, **sampling_options) - inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] - with torch.random.fork_rng(): - remote_outputs = model.generate( - inputs, - max_new_tokens=max_new_tokens, - do_sample=True, - **sampling_options, - ) - with torch.random.fork_rng(): - hf_outputs = HfGenerationMixin.sample( - model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper - ) - assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF" +def test_sampling(tokenizer, model, ref_model, max_new_tokens=10): + inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[ "input_ids" ] - with torch.random.fork_rng(): - remote_outputs_batch = model.generate( - inputs_batch, - max_new_tokens=max_new_tokens, - do_sample=True, - **sampling_options, - ) - with torch.random.fork_rng(): - hf_outputs_batch = HfGenerationMixin.sample( - model, - input_ids=inputs_batch, - 
max_length=inputs_batch.size(1) + max_new_tokens, - logits_warper=logits_warper, - ) - assert torch.allclose( - remote_outputs_batch, hf_outputs_batch - ), "Sampling results are not identical to HF in multibatch mode" + + for options in [ + dict(do_sample=True, temperature=0.5, top_k=5, top_p=0.9), + dict(do_sample=True, temperature=0.5, repetition_penalty=1.2), + ]: + options.update(max_new_tokens=max_new_tokens) + for multiple_calls in [False, True]: + for inputs in [inputs_single, inputs_batch]: + torch.manual_seed(0) + outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options) + + torch.manual_seed(0) + ref_outputs = ref_model.generate(inputs, **options) + + assert torch.allclose( + outputs, ref_outputs + ), f"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}" @pytest.mark.forked -def test_beam_search_generation(tokenizer, max_new_tokens=4, num_beams=2): - model = AutoDistributedModelForCausalLM.from_pretrained( - MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32 - ) - text = "A cat sat on a mat" - inputs = tokenizer(text, return_tensors="pt")["input_ids"] - remote_outputs = model.generate( - inputs, - max_new_tokens=max_new_tokens, - num_beams=num_beams, - ) - beam_scorer = BeamSearchScorer( - batch_size=inputs.size(0), - num_beams=num_beams, - device=inputs.device, - length_penalty=0, - do_early_stopping=False, - ) - hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"] - hf_outputs = HfGenerationMixin.beam_search( - model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer - ) - assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF" +def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5): + inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] + + options = dict(max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False) + outputs = make_generate_calls(model, inputs, **options) + ref_outputs = ref_model.generate(inputs, **options) + assert torch.allclose(outputs, ref_outputs), f"Beam search results are not identical to HF"