You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

144 lines
6.3 KiB

import contextlib
import dataclasses
from contextvars import ContextVar
from typing import ContextManager, List, Optional
import torch
import transformers
from hivemind.utils.logging import get_logger
from transformers.generation.utils import ModelOutput
from petals.client.inference_session import InferenceSession
from petals.client.remote_sequential import RemoteSequential
from petals.utils.misc import DUMMY, docstring_from
logger = get_logger(__name__)
@dataclasses.dataclass(frozen=True)
class RemotePastKeyValues:
    """A mock class representing the fact that `past_key_values` do exist but are stored on remote servers.

    This has to be a dataclass: `_reorder_cache()` builds an updated copy with `dataclasses.replace()`,
    which raises `TypeError` for plain (non-dataclass) instances. Instances are frozen, so the only way
    to change `hypo_ids` is to create such a copy.
    """

    # Beam indices recorded by `_reorder_cache()`; None until a reordering has been requested
    hypo_ids: Optional[torch.LongTensor] = None

    def __getitem__(self, _index: int) -> List[torch.Tensor]:
        """Return a single-element placeholder list for any index (no real tensors are stored locally)."""
        return [DUMMY]  # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
# Number of leading token columns that `_SkipTokensMixin.prepare_inputs_for_generation()` slices off `input_ids`.
# `RemoteGenerationMixin.generate()` sets it from the number of tokens the inference session already holds;
# the default of 0 means "skip nothing".
_skipped_tokens = ContextVar("skipped_tokens", default=0)
class _SkipTokensMixin:
    """Mixin that drops the already-skipped prefix of `input_ids` before delegating to the parent class."""

    # RemoteGenerationMixin relies on this override, but it must live in a class whose name is not
    # "GenerationMixin" because of how transformers.PreTrainedModel.can_generate() works
    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
        """Slice off the first `_skipped_tokens.get()` columns of `input_ids`, then call the next
        `prepare_inputs_for_generation()` in the MRO with the remaining tokens and the same kwargs."""
        n_skipped = _skipped_tokens.get()
        remaining_ids = input_ids[:, n_skipped:]
        return super().prepare_inputs_for_generation(remaining_ids, **kwargs)
class RemoteGenerationMixin(_SkipTokensMixin):
    """
    This class is an upgrade to `transformers.GenerationMixin` that:

    - Designed to be compatible with most `transformers.GenerationMixin` strategies and options
    - Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and
      you don't have to rerun the prefix through all the servers to generate each new token
    - Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation
      by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or
      accept prompts from a user in a chat bot (multiple calls like `.generate(new_prompts, ...)`).
    - If there is no active session, `.generate()` will create a new InferenceSession with proper `max_length`.
      Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that.
    """

    @property
    def active_session(self) -> Optional[InferenceSession]:
        """The `active_session` of `self.transformer.h`, or None if there is none.

        This must be a property: `generate()` compares `self.active_session` against None and passes it
        to `contextlib.nullcontext()`, so it has to evaluate to the session itself, not to a bound method.
        """
        return self.transformer.h.active_session

    def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
        """Delegate to `self.transformer.h.use_session(session)` and return its context manager."""
        return self.transformer.h.use_session(session)

    def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
        """Delegate to `self.transformer.h.inference_session(**kwargs)` and return its context manager."""
        return self.transformer.h.inference_session(**kwargs)

    def generate(
        self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
    ):
        """Run the parent class's `generate()` inside an inference session, resuming previous output if any.

        :param inputs: prompt token ids; if None, `kwargs["input_ids"]` is used instead (and removed from kwargs)
        :param args: forwarded to the parent `generate()` unchanged
        :param session: if given, this session is used; otherwise the active session is used if there is one,
            and a new session is created only as a last resort
        :param kwargs: forwarded to the parent `generate()` after `_fix_generate_kwargs()` normalization
        :returns: whatever the parent `generate()` returns (a `ModelOutput` or a tensor of sequences), with the
            sequences cropped so that tokens already stored in the session by earlier calls are not repeated
        :raises AssertionError: if a new session has to be created and `max_length` / `max_new_tokens`
            are either both set or both missing
        """
        self._fix_generate_kwargs(kwargs)
        if inputs is None:
            inputs = kwargs.pop("input_ids", None)

        if session is not None:
            # If a session specified explicitly, use it
            context_manager = self.use_session(session)
        elif self.active_session is not None:
            # If there's an active session, don't do anything
            context_manager = contextlib.nullcontext(self.active_session)
        else:
            # If there's no active session, create a new one
            max_length = kwargs.get("max_length")
            max_new_tokens = kwargs.get("max_new_tokens")
            assert (max_length is None) != (
                max_new_tokens is None
            ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"

            # The session length always includes `pre_seq_len` from the model config
            # NOTE(review): presumably the length of trainable prompt prefixes - confirm
            session_max_length = self.transformer.config.pre_seq_len
            if max_length is not None:
                session_max_length += max_length
            else:
                session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
            context_manager = self.inference_session(max_length=session_max_length)

        with context_manager as session:
            # Prepend the tokens from the previous .generate() call
            n_prev_tokens = session.output_ids.shape[1] if session.output_ids is not None else 0
            if n_prev_tokens > 0:
                if kwargs.get("num_beams", 1) > 1:
                    logger.warning(
                        "Beam search will not work properly in the resumed petals.InferenceSession "
                        "since intermediate beam entries are lost"
                    )

                if inputs is not None:
                    inputs = torch.cat([session.output_ids, inputs], dim=1)
                else:
                    inputs = session.output_ids

            # Don't actually run all previous tokens through the transformer,
            # but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty).
            # The value is set on every call and restored afterwards, so a stale skip count from a resumed
            # session can never leak into a later .generate() call that starts from a fresh session.
            skip_token = _skipped_tokens.set(max(0, n_prev_tokens - 1))
            try:
                result = super().generate(inputs, *args, **kwargs)
            finally:
                _skipped_tokens.reset(skip_token)

            sequences = result.sequences if isinstance(result, ModelOutput) else result
            # Save tokens from this .generate() call
            session.output_ids = sequences
            # Crop the last tokens from the previous call
            sequences = sequences[:, n_prev_tokens:].clone()
            if isinstance(result, ModelOutput):
                result.sequences = sequences
            else:
                result = sequences

        return result

    @staticmethod
    def _fix_generate_kwargs(kwargs: dict):
        """Normalize user-provided generation kwargs in place (returns None)."""
        # Suppress inappropriate "Both max_new_tokens and max_length" HF warning
        if "max_length" in kwargs and kwargs["max_length"] is None:
            del kwargs["max_length"]

        # Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0
        do_sample = kwargs.get("do_sample")
        if isinstance(do_sample, int):
            kwargs["do_sample"] = bool(do_sample)

    @staticmethod
    def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
        """Return a copy of `past_key_values` whose `hypo_ids` is `beam_idx`; the original object is not modified."""
        return dataclasses.replace(past_key_values, hypo_ids=beam_idx)