Compare commits
No commits in common. 'main' and 'v2.0.0.post1' have entirely different histories.
main...v2.0.0.post1
petals/client/__init__.py
@@ -1,4 +1,4 @@
from petals.client.config import ClientConfig
from petals.client.inference_session import InferenceSession
from petals.client.remote_sequential import RemoteSequential
from petals.client.routing import NoSpendingPolicy, RemoteSequenceManager, SpendingPolicyBase
from petals.client.routing.sequence_manager import RemoteSequenceManager
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
petals/client/config.py
@@ -1,35 +0,0 @@
import dataclasses
import os
from typing import Optional, Sequence, Union

from hivemind import PeerID

from petals.constants import PUBLIC_INITIAL_PEERS

_max_retries = os.getenv("PETALS_MAX_RETRIES")
DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None


@dataclasses.dataclass
class ClientConfig:
    initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS)  # a list of initial peers for hivemind DHT
    dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)
    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers

    show_route: Union[str, bool] = "inference"  # show chosen route through servers. one of [False, "inference", True]
    allowed_servers: Optional[Sequence[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
    blocked_servers: Optional[Sequence[Union[PeerID, str]]] = None  # if defined, do not use these servers
    use_server_to_server: bool = True  # Use direct server-to-server communication

    connect_timeout: float = 5  # timeout for opening a connection
    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
    update_period: float = 60  # refresh DHT information once in this many seconds

    max_retries: Optional[int] = DEFAULT_MAX_RETRIES  # max number of retries before an exception (default: inf)
    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
    max_backoff: float = 60  # limit maximal sleep time between retries to this value
    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
    active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)

    max_pinged: int = 3  # max servers to ping from each sequence side, per update
    ping_timeout: float = 2  # max time to wait for pings, per update
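For context, a brief usage sketch (not part of the diff): ClientConfig is a plain dataclass, and DEFAULT_MAX_RETRIES is read from the PETALS_MAX_RETRIES environment variable once, at import time, so the variable must be set before petals is imported. The field names below come from the dataclass above; the values are illustrative.

import os

os.environ["PETALS_MAX_RETRIES"] = "10"  # must be set before petals is imported

from petals.client.config import ClientConfig

config = ClientConfig(
    show_route=False,        # silence route logging
    request_timeout=5 * 60,  # allow slower forward/backward/inference requests
    update_period=30,        # refresh DHT information twice as often
)
print(config.max_retries)  # 10, via DEFAULT_MAX_RETRIES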
petals/client/remote_generation.py
@@ -1,164 +1,349 @@
import contextlib
import dataclasses
from contextvars import ContextVar
from typing import Any, ContextManager, Dict, List, Optional, Tuple
from typing import List, Optional

import torch
import transformers
from hivemind.utils.logging import get_logger
from torch import Tensor
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation.utils import ModelOutput

from petals.client.inference_session import InferenceSession
from petals.client.remote_sequential import RemoteSequential
from petals.utils.misc import DUMMY, docstring_from
from petals.utils.generation_algorithms import (
    BeamSearchAlgorithm,
    DecodingAlgorithm,
    GreedyAlgorithm,
    NucleusAlgorithm,
    SamplingAlgorithm,
    TopKAlgorithm,
)
from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint

logger = get_logger(__name__)


class RemotePastKeyValues(Cache):
    """only keeps the number of seen tokens. pretends to be a legit cache"""

    def __init__(self) -> None:
        super().__init__()
        self.seen_tokens = 0
        self.hypo_ids: Optional[torch.LongTensor] = None

    def __getitem__(self, _index: int) -> List[torch.Tensor]:
        return [DUMMY]  # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        return self.seen_tokens

    def get_max_length(self) -> Optional[int]:
        return None

    def update_seen(self, new_seen: int) -> None:
        self.seen_tokens += new_seen

    def reorder_cache(self, beam_idx):
        raise NotImplementedError("Beam search reordering is not implemented yet")


_skipped_tokens = ContextVar("skipped_tokens", default=0)


class _SkipTokensMixin:
    # This override is used in RemoteGenerationMixin but has to be defined in a class not named "GenerationMixin"
    # due to how transformers.PreTrainedModel.can_generate() works
    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
        input_ids = input_ids[:, _skipped_tokens.get() :]
        _skipped_tokens.set(0)
        return super().prepare_inputs_for_generation(input_ids, **kwargs)


class RemoteGenerationMixin(_SkipTokensMixin):
class RemoteGenerationMixin:
    """
    This class is an upgrade to `transformers.GenerationMixin` that:

    - Is designed to be compatible with most `transformers.GenerationMixin` strategies and options
    - Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and
      you don't have to rerun the prefix through all the servers to generate each new token
    - Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation
      by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or
      accept prompts from a user in a chat bot (multiple calls like `.generate(new_prompts, ...)`).
    - If there is no active session, `.generate()` will create a new InferenceSession with proper `max_length`.
      Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that.
    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
    The class can be used for:
        - *greedy decoding*.
        - *multinomial, top-k and top-p sampling*.
        - *beam-search decoding*

    This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used instead of it.
    However, it has some differences for remote usage.
    """

    @docstring_from(RemoteSequential.active_session)
    @property
    def active_session(self) -> Optional[InferenceSession]:
        return self.transformer.h.active_session
    def inference_session(self, **kwargs) -> InferenceSession:
        """
        Returns an inference session for the model's RemoteSequential module.

    @docstring_from(RemoteSequential.use_session)
    def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
        return self.transformer.h.use_session(session)
        :param max_length: Maximal expected length of inference results. Servers use this parameter
                           to calculate the size of attention caches allocated to this client.
        """

    @docstring_from(RemoteSequential.inference_session)
    def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
        return self.transformer.h.inference_session(**kwargs)

    @docstring_from(transformers.GenerationMixin.generate.__doc__)
    @torch.inference_mode()
    def generate(
        self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
    ):
        self._fix_generate_kwargs(kwargs)
        if inputs is None:
            inputs = kwargs.pop("input_ids", None)

        if session is not None:
            # If a session is specified explicitly, use it
            context_manager = self.use_session(session)
        elif self.active_session is not None:
            # If there's an active session, don't do anything
            context_manager = contextlib.nullcontext(self.active_session)
        self,
        inputs: Optional[torch.Tensor] = None,
        *,
        do_sample: Optional[bool] = None,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        num_beams: Optional[int] = 1,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        max_length: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        decoding_algorithm: Optional[DecodingAlgorithm] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        num_return_sequences: Optional[int] = None,
        session: Optional[InferenceSession] = None,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head.

        :param inputs: The input tokens to the model.
        :param do_sample: Whether to sample from the model predictions or take the argmax.
        :param temperature: The temperature to use for sampling.
        :param top_k: The number of results to return.
        :param top_p: The cumulative probability of results to return.
        :param num_beams: The number of beams to use for beam search.
        :param bos_token_id: The id of the beginning-of-sentence token.
        :param eos_token_id: The id of the end-of-sentence token.
        :param pad_token_id: The id of the padding token.
        :param max_length: The maximum number of tokens in the output (including input tokens).
        :param max_new_tokens: The maximum number of tokens to generate.
        :param decoding_algorithm: The decoding algorithm to use.
        :param provided_constraints: A list of constraints to use.
        :param num_return_sequences: How many hypotheses from the beam will be in the output.
        """

        prefix_length = 0 if inputs is None else inputs.size(1)
        prefix_length += self.config.pre_seq_len

        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
        if max_length is not None and max_new_tokens is None:
            max_new_tokens = max_length - prefix_length
            assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
        elif max_length is None and max_new_tokens is not None:
            max_length = prefix_length + max_new_tokens

        resuming_session = session is not None and session.last_token_id is not None
        if num_beams > 1 and resuming_session:
            raise NotImplementedError(
                "Resuming inference session in .generate() along with beam search is not supported yet"
            )

        if inputs is not None:
            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
            if resuming_session:
                inputs = torch.cat([session.last_token_id, inputs], dim=1)
        else:
            # If there's no active session, create a new one

            max_length = kwargs.get("max_length")
            max_new_tokens = kwargs.get("max_new_tokens")
            assert (max_length is None) != (
                max_new_tokens is None
            ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"

            session_max_length = self.transformer.config.pre_seq_len
            if max_length is not None:
                session_max_length += max_length
            if resuming_session:
                inputs = session.last_token_id
            else:
                session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
            context_manager = self.inference_session(max_length=session_max_length)

                assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
                inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
        batch_size = inputs.size(0)

        if decoding_algorithm is None:
            if do_sample:
                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
            elif num_beams is not None and num_beams > 1:
                decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
            else:
                if top_k is not None or top_p is not None:
                    logger.warning("You passed top_k or top_p but did not pass do_sample=True. Running greedy sampling")
                decoding_algorithm = GreedyAlgorithm()

        if num_beams > 1:
            inputs = torch.cat([inputs] * num_beams, dim=0)
            if batch_size > 1:
                # TODO: resolve padding problem
                logger.warning(
                    f"You set batch_size {batch_size} within beam search generation. "
                    f"Be careful, results on sequences with different lengths may be padded the wrong way"
                )

        if num_return_sequences is None:
            num_return_sequences = 1

        assert num_return_sequences <= num_beams, (
            "You want more sequences than the beam has."
            f" Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
        )

        constraints = self._get_constraints(
            inputs=inputs,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            provided_constraints=provided_constraints,
        )

        if session is None:
            context_manager = self.inference_session(max_length=max_length)
        else:
            context_manager = contextlib.nullcontext(session)  # Doesn't actually enter the session or exit from it
        with context_manager as session:
            # Prepend the tokens from the previous .generate() call
            n_prev_tokens = session.output_ids.shape[1] if session.output_ids is not None else 0
            if n_prev_tokens > 0:
                if kwargs.get("num_beams", 1) > 1:
                    logger.warning(
                        "Beam search will not work properly in the resumed petals.InferenceSession "
                        "since intermediate beam entries are lost"
                    )

                if inputs is not None:
                    inputs = torch.cat([session.output_ids, inputs], dim=1)
                else:
                    inputs = session.output_ids

                # Don't actually run all previous tokens through the transformer,
                # but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
                _skipped_tokens.set(max(0, n_prev_tokens - 1))

            if self._supports_cache_class and "past_key_values" not in kwargs:
                past_key_values = RemotePastKeyValues()
                past_key_values.update_seen(session.position)
                kwargs["past_key_values"] = past_key_values

            result = super().generate(inputs, *args, **kwargs)

            sequences = result.sequences if isinstance(result, ModelOutput) else result
            # Save tokens from this .generate() call
            session.output_ids = sequences
            # Crop the last tokens from the previous call
            sequences = sequences[:, n_prev_tokens:].clone()
            if isinstance(result, ModelOutput):
                result.sequences = sequences
            outputs = []
            # Find samples with padded inputs.
            # They will be changed before all of the samples have the right length.
            if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
                outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
            else:
                result = sequences

        return result

    @staticmethod
    def _fix_generate_kwargs(kwargs: dict):
        # Suppress inappropriate "Both max_new_tokens and max_length" HF warning
        if "max_length" in kwargs and kwargs["max_length"] is None:
            del kwargs["max_length"]

        # Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0
        do_sample = kwargs.get("do_sample")
        if isinstance(do_sample, int):
            kwargs["do_sample"] = bool(do_sample)

    @staticmethod
    def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
        return dataclasses.replace(past_key_values, hypo_ids=beam_idx)
                outputs += [inputs]
            last_token_id = None
            seq_idx = outputs[0].size(1)
            hypo_ids = torch.arange(outputs[0].size(0))
            while True:
                hidden_state = self.transformer.word_embeddings(outputs[-1])
                intermediate_prompts = None
                if self.config.pre_seq_len > 0 and len(outputs) == 1:
                    prompts, intermediate_prompts = self.transformer.get_prompt(hidden_state.size(0))
                    hidden_state = torch.cat([prompts, hidden_state], dim=1)
                hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)

                hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]

                hidden_state = self.transformer.ln_f(hidden_state)
                lm_logits = self.lm_head(hidden_state)

                for constraint in constraints:
                    lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
                last_token_id, hypo_ids = decoding_algorithm(lm_logits)

                # If some samples were padded, change only these samples
                if seq_idx < inputs.size(1):
                    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
                    last_token_id = (~pad_token_mask) * inputs[
                        :, seq_idx : seq_idx + 1
                    ] + pad_token_mask * last_token_id

                # TODO: refactor outputs
                if num_beams > 1:
                    for i in range(len(outputs), 1, -1):
                        outputs[i - 1] = outputs[i - 1][hypo_ids]

                outputs.append(last_token_id)
                session.last_token_id = last_token_id
                seq_idx += 1
                if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
                    break

            outputs = torch.cat(outputs, dim=-1)

            if resuming_session:
                outputs = outputs[:, 1:]
            if num_beams > 1:
                pre_return_idx = [
                    torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
                ]
                return_idx = torch.cat(pre_return_idx, dim=0)
                outputs = outputs[return_idx]

        return outputs

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses greedy search.

        :param input_ids: The input tokens to the model.
        :param max_length: The maximum length of the sequence to generate.
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end-of-sentence token.
        :param provided_constraints: A list of constraints to use.
        """
        return self.generate(
            inputs=input_ids,
            max_new_tokens=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoding_algorithm=GreedyAlgorithm(),
            provided_constraints=provided_constraints,
        )

    def sample(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
        If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.

        :param input_ids: The input tokens to the model.
        :param temperature: The temperature to use for sampling.
        :param top_k: The number of samples to use for top_k sampling.
        :param top_p: The probability threshold for nucleus (top_p) sampling.
        :param max_length: The maximum length of the sequence to generate.
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end-of-sentence token.
        :param provided_constraints: A list of constraints to use.
        """

        return self.generate(
            inputs=input_ids,
            max_new_tokens=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
            provided_constraints=provided_constraints,
        )

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        num_beams: int = 1,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses beam search.

        :param input_ids: The input tokens to the model.
        :param num_beams: The number of beams to use.
        :param max_length: The maximum length of the sequence to generate.
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end-of-sentence token.
        :param provided_constraints: A list of constraints to use.
        """
        decoding_algorithm = BeamSearchAlgorithm(
            num_beams=num_beams,
            batch_size=input_ids.size(0),
        )
        return self.generate(
            inputs=input_ids,
            num_beams=num_beams,
            max_new_tokens=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoding_algorithm=decoding_algorithm,
            provided_constraints=provided_constraints,
        )

    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
    ) -> torch.LongTensor:
        raise NotImplementedError

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
    ) -> torch.LongTensor:
        raise NotImplementedError

    def _choose_sample_algorithm(
        self,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> DecodingAlgorithm:
        if (top_k is not None) and (top_p is not None):
            raise ValueError("You have to provide only top_k or top_p for sampling")
        if top_k is not None:
            return TopKAlgorithm(top_k, temperature)
        elif top_p is not None:
            return NucleusAlgorithm(top_p, temperature)
        else:
            return SamplingAlgorithm(temperature)

    def _get_constraints(
        self,
        inputs: Optional[torch.Tensor] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
    ) -> List[ABCBloomConstraint]:
        constraints = []
        constraints.extend(provided_constraints)
        constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
        return constraints
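For context, a sketch of the interactive-generation pattern that the main-branch docstring above describes: multiple .generate() calls inside one inference_session(), so the servers keep the attention caches between calls. The model name is illustrative; any Petals-hosted checkpoint works the same way.

from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

MODEL_NAME = "bigscience/bloom-560m-petals"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoDistributedModelForCausalLM.from_pretrained(MODEL_NAME)

prefix = tokenizer("A human talks to an AI assistant.", return_tensors="pt")["input_ids"]
with model.inference_session(max_length=128) as session:
    outputs = model.generate(prefix, max_new_tokens=1)       # the first call runs the prefix
    for _ in range(16):
        outputs = model.generate(None, max_new_tokens=1)     # later calls reuse the active session
        print(tokenizer.decode(outputs[0, -1]), end="", flush=True)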
petals/client/routing/__init__.py
@@ -1,2 +1 @@
from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
"""Client-side functions responsible for choosing the best server, """
petals/constants.py
@@ -1,18 +1,13 @@
import torch

PUBLIC_INITIAL_PEERS = [
    # IPv4 DNS addresses
    "/dns/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
    "/dns/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
    # IPv6 DNS addresses
    "/dns6/bootstrap1.petals.dev/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
    "/dns6/bootstrap2.petals.dev/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
    # Reserved IPs
    "/ip4/159.89.214.152/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
    "/ip4/159.203.156.48/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
    "/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
    "/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
    "/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
    "/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
]

# The reachability API is currently used only when connecting to the public swarm
REACHABILITY_API_URL = "https://health.petals.dev"
REACHABILITY_API_URL = "http://health.petals.ml"

DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
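For context, a sketch of how these constants are consumed (not part of the diff): clients pass the bootstrap multiaddrs as `initial_peers` when joining the swarm. PUBLIC_INITIAL_PEERS is already the default, and a private swarm would substitute its own addresses. The model name is illustrative.

from petals import AutoDistributedModelForCausalLM
from petals.constants import PUBLIC_INITIAL_PEERS

model = AutoDistributedModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m-petals",      # placeholder checkpoint
    initial_peers=PUBLIC_INITIAL_PEERS,  # the default, shown here explicitly
)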
petals/dht_utils.py
@@ -1,9 +1,124 @@
import warnings
"""
Utilities for declaring and retrieving active model layers using a shared DHT.
"""
from __future__ import annotations

warnings.warn(
    "petals.dht_utils has been moved to petals.utils.dht. This alias will be removed in Petals 2.2.0+",
    DeprecationWarning,
    stacklevel=2,
)
import math
from functools import partial
from typing import Dict, List, Optional, Sequence, Union

from petals.utils.dht import *
from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger

from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo

logger = get_logger(__name__)


def declare_active_modules(
    dht: DHT,
    uids: Sequence[ModuleUID],
    server_info: ServerInfo,
    expiration_time: DHTExpiration,
    wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
    """
    Declare that your node serves the specified modules; update timestamps if declared previously

    :param uids: a list of module ids to declare
    :param wait: if True, waits for the declaration to finish, otherwise runs in background
    :param throughput: specify your performance in terms of compute throughput
    :param expiration_time: declared modules will be visible for this many seconds
    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
    """
    if isinstance(uids, str):
        uids = [uids]
    if not isinstance(uids, list):
        uids = list(uids)
    for uid in uids:
        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid

    return dht.run_coroutine(
        partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
        return_future=not wait,
    )


async def _declare_active_modules(
    dht: DHT,
    node: DHTNode,
    uids: List[ModuleUID],
    server_info: ServerInfo,
    expiration_time: DHTExpiration,
) -> Dict[ModuleUID, bool]:
    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
    return await node.store_many(
        keys=uids,
        subkeys=[dht.peer_id.to_base58()] * len(uids),
        values=[server_info.to_tuple()] * len(uids),
        expiration_time=expiration_time,
        num_workers=num_workers,
    )


def get_remote_module_infos(
    dht: DHT,
    uids: Sequence[ModuleUID],
    expiration_time: Optional[DHTExpiration] = None,
    active_adapter: Optional[str] = None,
    *,
    latest: bool = False,
    return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
    return dht.run_coroutine(
        partial(
            _get_remote_module_infos,
            uids=uids,
            active_adapter=active_adapter,
            expiration_time=expiration_time,
            latest=latest,
        ),
        return_future=return_future,
    )


async def _get_remote_module_infos(
    dht: DHT,
    node: DHTNode,
    uids: List[ModuleUID],
    active_adapter: Optional[str],
    expiration_time: Optional[DHTExpiration],
    latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
    if latest:
        assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
        expiration_time = math.inf
    elif expiration_time is None:
        expiration_time = get_dht_time()
    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
    found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)

    modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
    for i, uid in enumerate(uids):
        metadata = found[uid]
        if metadata is None or not isinstance(metadata.value, dict):
            if metadata is not None:
                logger.warning(f"Incorrect metadata for {uid}: {metadata}")
            continue
        servers = {}
        for peer_id, server_info in metadata.value.items():
            try:
                peer_id = PeerID.from_base58(peer_id)
                server_info = ServerInfo.from_tuple(server_info.value)

                if active_adapter and active_adapter not in server_info.adapters:
                    logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
                    continue

                servers[peer_id] = server_info
            except (TypeError, ValueError) as e:
                logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
        if servers:
            modules[i] = RemoteModuleInfo(uid, servers)
    return modules
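For context, a sketch of querying the DHT with the helpers above (on main, the same names live in petals.utils.dht, and importing petals.dht_utils raises the DeprecationWarning shown). The module UIDs are hypothetical; real ones follow the "dht_prefix + UID_DELIMITER + block index" pattern.

import hivemind

from petals.constants import PUBLIC_INITIAL_PEERS
from petals.utils.dht import get_remote_module_infos

dht = hivemind.DHT(initial_peers=PUBLIC_INITIAL_PEERS, client_mode=True, start=True)
uids = ["bloom-560m-petals.0", "bloom-560m-petals.1"]  # hypothetical module UIDs
for info in get_remote_module_infos(dht, uids, latest=True):
    if info is not None:
        print(info.uid, "is served by", len(info.servers), "server(s)")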
petals/models/falcon/__init__.py
@@ -1,15 +0,0 @@
from petals.models.falcon.block import WrappedFalconBlock
from petals.models.falcon.config import DistributedFalconConfig
from petals.models.falcon.model import (
    DistributedFalconForCausalLM,
    DistributedFalconForSequenceClassification,
    DistributedFalconModel,
)
from petals.utils.auto_config import register_model_classes

register_model_classes(
    config=DistributedFalconConfig,
    model=DistributedFalconModel,
    model_for_causal_lm=DistributedFalconForCausalLM,
    model_for_sequence_classification=DistributedFalconForSequenceClassification,
)
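For context, a sketch of what the registration above enables (not part of the diff): the auto classes in petals.utils.auto_config can then resolve Falcon checkpoints to these distributed classes. The model name is illustrative.

from petals import AutoDistributedConfig

config = AutoDistributedConfig.from_pretrained("tiiuae/falcon-7b")  # placeholder checkpoint
print(type(config).__name__)  # expected: DistributedFalconConfig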
petals/models/falcon/block.py
@@ -1,480 +0,0 @@
"""
Falcon intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
See commit history for authorship.
"""
import math
from functools import partial
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.falcon.modeling_falcon import (
    FalconAttention,
    FalconConfig,
    FalconDecoderLayer,
    FalconLinear,
    FalconMLP,
    FalconModel,
    LayerNorm,
    build_alibi_tensor,
    dropout_add,
    rotate_half,
)

KVCache = Tuple[torch.Tensor, torch.Tensor]
INFERENCE_MAX_LENGTH = 8192


def apply_rotary(query, key, cos, sin):
    return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)


class OptimizedFalconRotaryEmbedding(nn.Module):
    def __init__(self, head_dim: int, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.head_dim = head_dim
        self.seq_len_cached = -1

        self.cuda_graph = None
        self.input_surface = None
        self.static_outputs = None

    def _optimized_apply_rotary(self, query, key, cos, sin):
        if self.cuda_graph is None:
            self.cuda_graph = torch.cuda.CUDAGraph()
            self.input_surface = (query, key, cos, sin)

            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                for _ in range(3):
                    apply_rotary(*self.input_surface)
            torch.cuda.current_stream().wait_stream(s)

            with torch.cuda.graph(self.cuda_graph):
                self.static_outputs = apply_rotary(*self.input_surface)

        inputs = (query, key, cos, sin)
        for static_input, data in zip(self.input_surface, inputs):
            static_input.copy_(data)
        self.cuda_graph.replay()
        return tuple(o.detach() for o in self.static_outputs)

    def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
        total_length = seq_len + past_key_values_length
        if self.seq_len_cached == -1:
            # warm up the cache
            total_length = max(INFERENCE_MAX_LENGTH, total_length)

        if total_length > self.seq_len_cached:
            with torch.inference_mode(False):
                self.seq_len_cached = total_length
                t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
                freqs = torch.einsum("i,j->ij", t, self.inv_freq)
                emb = torch.cat((freqs, freqs), dim=-1).to(device)

                if dtype in [torch.float16, torch.bfloat16]:
                    emb = emb.float()

                self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False)
                self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False)

        return (
            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
        )

    def forward(self, query, key, past_key_values_length=0):
        batch, seq_len, head_dim = query.shape
        cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
        if seq_len == 1 and torch.is_inference_mode_enabled() and query.device.type == "cuda":
            return self._optimized_apply_rotary(query, key, cos, sin)
        else:
            return apply_rotary(query, key, cos, sin)
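For context, _optimized_apply_rotary above follows a standard CUDA-graph recipe: warm up on a side stream, capture the ops once into a graph, then serve later calls by copying fresh data into the captured input buffers and replaying. A minimal, generic sketch of the same recipe (the function being captured is illustrative; requires a CUDA device):

import torch

static_x = torch.randn(16, 64, device="cuda")
graph = torch.cuda.CUDAGraph()

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):  # warm-up runs before capture
        torch.sin(static_x)
torch.cuda.current_stream().wait_stream(s)

with torch.cuda.graph(graph):  # record the kernels without executing them
    static_out = torch.sin(static_x)

def fast_sin(x: torch.Tensor) -> torch.Tensor:
    static_x.copy_(x)  # feed fresh inputs into the captured buffer
    graph.replay()     # rerun the recorded kernels
    return static_out.detach()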
def split_heads(
    fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    batch, seq_len, _ = fused_qkv.shape
    qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)
    query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3)
    key = torch.broadcast_to(key, query.shape)
    value = torch.broadcast_to(value, query.shape)

    query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
    return query, key, value


class OptimizedFalconAttention(FalconAttention):
    def __init__(self, config: FalconConfig):
        nn.Module.__init__(self)

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.split_size = self.hidden_size
        self.hidden_dropout = config.hidden_dropout

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.maybe_rotary = OptimizedFalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.beta = self.inv_norm_factor
        if config.new_decoder_architecture:
            qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
        elif config.multi_query:
            qkv_out_dim = self.hidden_size + 2 * self.head_dim
        else:
            qkv_out_dim = 3 * self.hidden_size
        self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
        self.new_decoder_architecture = config.new_decoder_architecture
        self.multi_query = config.multi_query
        self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
        self.attention_dropout = nn.Dropout(config.attention_dropout)
        self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1

        if self.new_decoder_architecture:
            self._split_heads = partial(
                split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
            )
            self.split_graph = None
            self.input_surface = None
            self.static_outputs = None

    def _optimized_split_heads(self, fused_qkv):
        if self.split_graph is None:
            self.split_graph = torch.cuda.CUDAGraph()
            self.input_surface = fused_qkv

            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                for _ in range(3):
                    self._split_heads(fused_qkv)
            torch.cuda.current_stream().wait_stream(s)

            with torch.cuda.graph(self.split_graph):
                self.static_outputs = self._split_heads(self.input_surface)

        self.input_surface.copy_(fused_qkv)
        self.split_graph.replay()
        return tuple(o.detach() for o in self.static_outputs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        assert not output_attentions

        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]

        if (
            self.new_decoder_architecture
            and hidden_states.size(1) == 1
            and torch.is_inference_mode_enabled()
            and hidden_states.device.type == "cuda"
        ):
            query_layer, key_layer, value_layer = self._optimized_split_heads(fused_qkv)
        else:
            # 3 x [batch_size, seq_length, num_heads, head_dim]
            (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

        num_kv_heads = self.num_heads
        batch_size, query_length, _, _ = query_layer.shape

        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
        key_layer = key_layer.transpose(1, 2).reshape(
            batch_size * num_kv_heads,
            query_length,
            self.head_dim,
        )
        value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)

        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)

        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
            #  - key: [batch_size * self.num_heads, kv_length, head_dim]
            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
            key_layer = torch.cat((past_key, key_layer), dim=1)
            value_layer = torch.cat((past_value, value_layer), dim=1)

        _, kv_length, _ = key_layer.shape
        if use_cache:
            present = (key_layer, value_layer)
        else:
            present = None

        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
        key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
        value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)

        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)

        if alibi is None:
            attn_output = F.scaled_dot_product_attention(
                query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False
            )

            attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
            attn_output = attn_output.permute(0, 2, 1, 3)
            attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)

            output_tensor = self.dense(attn_output)

            return output_tensor, present
        else:
            matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)

            # change view to [batch_size, num_heads, q_length, kv_length]
            attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)

            # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
            input_dtype = attention_scores.dtype
            # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
            if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
                attention_scores = attention_scores.to(torch.float32)
            # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
            # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
            # equivalent and more performant, but there might be a numerical difference. If you're reading this
            # and you'd like to experiment and maybe file a PR, feel free!
            attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
            attention_logits *= self.inv_norm_factor
            attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
            # [batch_size, num_heads, q_length, kv_length]
            attention_probs = self.attention_dropout(attention_probs)

            if head_mask is not None:
                attention_probs = attention_probs * head_mask

            # change view [batch_size, num_heads, q_length, kv_length]
            attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)

            # matmul: [batch_size * num_heads, q_length, head_dim]
            context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)

            # change view [batch_size, q_length, num_heads * head_dim]
            context_layer = self._merge_heads(context_layer)

            output_tensor = self.dense(context_layer)

            if output_attentions:
                return output_tensor, present, attention_probs
            else:
                return output_tensor, present


class OptimizedFalconDecoderLayer(FalconDecoderLayer):
    def __init__(self, config: FalconConfig):
        nn.Module.__init__(self)
        hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.mlp = FalconMLP(config)
        self.hidden_dropout = config.hidden_dropout
        self.config = config

        self.self_attention = OptimizedFalconAttention(config)

        if self.config.alibi or not config.new_decoder_architecture:
            if config.new_decoder_architecture:
                # The layer norm before self-attention
                self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
                # The layer norm before the MLP
                self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
            else:
                self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
                if not config.parallel_attn:
                    self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        else:
            self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

            self.ln_graph = None
            self.static_input = None
            self.static_outputs = None

    def _optimized_apply_ln(self, hidden_states):
        if self.ln_graph is None:
            self.ln_graph = torch.cuda.CUDAGraph()
            self.static_input = hidden_states

            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                for _ in range(3):
                    self.ln_attn(hidden_states)
                    self.ln_mlp(hidden_states)
            torch.cuda.current_stream().wait_stream(s)

            with torch.cuda.graph(self.ln_graph):
                ln_attn_output = self.ln_attn(hidden_states)
                ln_mlp_output = self.ln_mlp(hidden_states)
                self.static_outputs = (ln_attn_output, ln_mlp_output)

        self.static_input.copy_(hidden_states)
        self.ln_graph.replay()
        return tuple(o.detach() for o in self.static_outputs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        residual = hidden_states

        if self.config.new_decoder_architecture:
            if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
                attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)
            else:
                attention_layernorm_out = self.ln_attn(hidden_states)
                mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            attention_layernorm_out = self.input_layernorm(hidden_states)

        attn_outputs = self.self_attention(
            attention_layernorm_out,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attention_output = attn_outputs[0]

        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                mlp_layernorm_out = attention_layernorm_out
            else:
                residual = dropout_add(
                    attention_output, residual, self.config.attention_dropout, training=self.training
                )
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        outputs = attn_outputs[1:]

        mlp_output = self.mlp(mlp_layernorm_out)

        if self.config.new_decoder_architecture or self.config.parallel_attn:
            mlp_output += attention_output

        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)

        if use_cache:
            outputs = (output,) + outputs
        else:
            outputs = (output,) + outputs[1:]

        return outputs  # hidden_states, present, attentions


class WrappedFalconBlock(OptimizedFalconDecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[KVCache] = None,
        use_cache: bool = False,
        **kwargs,
    ):
        assert attention_mask is None

        batch_size, seq_length = hidden_states.shape[:2]

        if layer_past is not None:
            layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
        past_length = 0 if layer_past is None else layer_past[0].shape[1]
        seq_length_with_past = seq_length + past_length

        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        if alibi is None and self.config.alibi:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)

        outputs = super().forward(
            hidden_states,
            *args,
            attention_mask=attention_mask,
            alibi=alibi,
            layer_past=layer_past,
            use_cache=use_cache,
            **kwargs,
        )

        if use_cache:
            present_key_value = outputs[-1]
            present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
            outputs = outputs[:-1] + (present_key_value,)

        return outputs

    def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
        key_states, value_states = key_value

        key_states = key_states.permute(0, 2, 1)
        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]

        if self.config.new_decoder_architecture:
            key_states = self._expand_states(key_states)
            value_states = self._expand_states(value_states)

        return (key_states, value_states)

    def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
        key_states, value_states = key_value

        if self.config.new_decoder_architecture:
            key_states = self._collapse_states(key_states)
            value_states = self._collapse_states(value_states)

        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
        key_states = key_states.permute(0, 2, 1)

        return (key_states, value_states)

    def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
        batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
        batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads

        state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
        state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1)  # No copy
        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)  # Involves a copy
        return state

    def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
        batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
        batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads

        state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
        state = state[:, :, 0]
        state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
        return state
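For context, a shape sketch of the cache conversion above (not part of the diff): Bloom-style caches store keys as [batch * num_kv_heads, head_dim, seq_len] while the Falcon code expects [batch * num_kv_heads, seq_len, head_dim], so a permute(0, 2, 1) maps between them; _expand_states then repeats each KV head num_key_value_groups times for the new decoder architecture. The dimensions below are illustrative.

import torch

batch, num_kv_heads, groups, seq_len, head_dim = 2, 8, 4, 16, 64
bloom_key = torch.randn(batch * num_kv_heads, head_dim, seq_len)

falcon_key = bloom_key.permute(0, 2, 1)
assert falcon_key.shape == (batch * num_kv_heads, seq_len, head_dim)

expanded = (
    falcon_key.view(batch, num_kv_heads, 1, seq_len, head_dim)
    .expand(-1, -1, groups, -1, -1)                              # no copy
    .reshape(batch * num_kv_heads * groups, seq_len, head_dim)   # involves a copy
)
assert expanded.shape[0] == batch * num_kv_heads * groups  # num_attention_heads = num_kv_heads * groups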
petals/models/falcon/config.py
@@ -1,48 +0,0 @@
import os
from typing import Optional, Union

from hivemind import get_logger
from transformers.models.falcon import FalconConfig
from transformers.models.falcon.modeling_falcon import FalconAttention

from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.models.falcon.block import WrappedFalconBlock
from petals.utils.auto_config import DefaultRevisionMixin

logger = get_logger(__name__)


class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
    block_class = WrappedFalconBlock
    attn_class = FalconAttention
    block_prefix = "transformer.h"

    @property
    def num_key_value_groups(self) -> int:
        if self.new_decoder_architecture:
            return self.num_attention_heads // self.num_kv_heads
        if self.multi_query:
            return self.num_attention_heads
        return 1

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        if "180B" in model_name_or_path.upper():
            logger.info("Make sure you follow the Falcon-180B license: https://bit.ly/falcon-180b-license")

        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            dht_prefix = str(model_name_or_path)
            dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
            dht_prefix = dht_prefix.replace(".", "-")
            logger.info(f"Using DHT prefix: {dht_prefix}")

        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
        config = result[0] if isinstance(result, tuple) else result
        if config.pad_token_id is None:
            config.pad_token_id = 0
        return result
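For context, the prefix normalization in from_pretrained above reduces a repo id to its repo name and replaces dots, so that blocks of the same model hosted from different accounts merge under one DHT prefix. A one-line sketch (the repo id is illustrative):

model_name_or_path = "tiiuae/falcon-rw-1.3b"  # illustrative repo id
dht_prefix = model_name_or_path.split("/")[-1].replace(".", "-")
print(dht_prefix)  # falcon-rw-1-3b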
@ -1,154 +0,0 @@
from typing import Optional

import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.falcon import (
    FalconForCausalLM,
    FalconForSequenceClassification,
    FalconModel,
    FalconPreTrainedModel,
)

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.falcon.config import DistributedFalconConfig
from petals.utils.auto_config import DefaultRevisionMixin

logger = get_logger(__name__)


class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel):
    """FalconModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."]

    config_class = DistributedFalconConfig

    def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None):
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.h) == 0
        config.num_hidden_layers = n_layer

        self.h = RemoteSequential(config, dht=dht)

        self.requires_grad_(False)  # Forbid accumulating grads for embeddings and layernorm
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[RemotePastKeyValues] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # The causal mask will be added on the server-side
        assert (
            attention_mask is None or (attention_mask == 1).all()
        ), f"Custom attention masks are not supported, {attention_mask=}"
        assert (
            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
        ), f"Non-consecutive position_ids are not supported, {position_ids=}"
        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
        assert not output_attentions, f"{output_attentions=} is not supported"
        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
        assert return_dict is None or return_dict, f"{return_dict=} is not supported"

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
        if use_prompts:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
        else:
            prompts = intermediate_prompts = None

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        hidden_states = self.h(
            hidden_states,
            prompts=intermediate_prompts,
            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
        )

        # Remove prefix
        if use_prompts:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=RemotePastKeyValues(),
            hidden_states=None,
            attentions=None,
        )

    @property
    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return nn.Identity()


class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedFalconConfig

    def __init__(self, config: DistributedFalconConfig):
        FalconPreTrainedModel.__init__(self, config)
        self.transformer = DistributedFalconModel(config)
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head


class DistributedFalconForSequenceClassification(
    DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification
):
    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedFalconConfig

    def __init__(self, config: DistributedFalconConfig):
        FalconPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.transformer = DistributedFalconModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
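A hedged usage sketch for the classes above, assuming a reachable Petals swarm that serves the named checkpoint (the model name is illustrative):

from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "tiiuae/falcon-180B"  # hypothetical example; any swarm-served Falcon repo
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)  # transformer layers run on remote servers
print(tokenizer.decode(outputs[0]))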
@ -1,15 +0,0 @@
from petals.models.mixtral.block import WrappedMixtralBlock
from petals.models.mixtral.config import DistributedMixtralConfig
from petals.models.mixtral.model import (
    DistributedMixtralForCausalLM,
    DistributedMixtralForSequenceClassification,
    DistributedMixtralModel,
)
from petals.utils.auto_config import register_model_classes

register_model_classes(
    config=DistributedMixtralConfig,
    model=DistributedMixtralModel,
    model_for_causal_lm=DistributedMixtralForCausalLM,
    model_for_sequence_classification=DistributedMixtralForSequenceClassification,
)
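This registration is what lets the AutoDistributed* factories dispatch on config.model_type; a short sketch, assuming Hub access (the repo name is illustrative):

from petals.utils.auto_config import AutoDistributedConfig

config = AutoDistributedConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")  # hypothetical repo
assert type(config).__name__ == "DistributedMixtralConfig"  # resolved via register_model_classes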
@ -1,114 +0,0 @@
import json
from typing import Optional, Tuple

import torch
from transformers import MixtralConfig
from transformers.cache_utils import DynamicCache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel


class WrappedMixtralBlock(MixtralDecoderLayer):
    def __init__(self, config: MixtralConfig, layer_idx: int):
        super().__init__(config, layer_idx)

        self._attn_implementation = config._attn_implementation
        self.sliding_window = config.sliding_window
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        **kwargs
    ):
        batch_size, seq_length, _ = hidden_states.shape

        seq_length_with_past = seq_length
        past_key_values_length = 0

        past_key_value = layer_past

        if past_key_value is not None:
            past_key_values_length = past_key_value[0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
            _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
            past_key_value = DynamicCache()
            past_key_value.key_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[0]]
            past_key_value.value_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[1]]
            past_key_value._seen_tokens = past_key_values_length

        if self._attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._attn_implementation == "sdpa":
            # output_attentions=True can not be supported when using SDPA, and we fall back on
            # the manual implementation that requires a 4D causal mask in all cases.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                hidden_states,
                past_key_values_length,
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                hidden_states,
                past_key_values_length,
                sliding_window=self.sliding_window,
            )

        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=hidden_states.device
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)

        outputs = super().forward(
            hidden_states,
            *args,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            **kwargs
        )

        if use_cache:
            present_key_value = outputs[-1]
            present_key_value = present_key_value[self.layer_idx]
            present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
            outputs = outputs[:-1] + (present_key_value,)

        return outputs

    def _reorder_cache_from_bloom(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        # TODO: Move to mixin
        key_states, value_states = key_value
        key_states = key_states.permute(0, 2, 1)
        key_states = key_states.view(
            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        value_states = value_states.view(*key_states.shape)
        return (key_states, value_states)

    def _reorder_cache_to_bloom(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        # TODO: Move to mixin
        key_states, value_states = key_value
        value_states = value_states.view(
            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        key_states = key_states.view(*value_states.shape)
        key_states = key_states.permute(0, 2, 1)
        return (key_states, value_states)
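To make the two cache layouts concrete, here is a standalone sketch of the round trip that _reorder_cache_from_bloom and _reorder_cache_to_bloom perform, using plain tensors (all shapes are illustrative):

import torch

batch_size, num_kv_heads, head_dim, seq_len = 2, 8, 64, 5

# Bloom-style server layout: keys are [batch * heads, head_dim, seq_len]
bloom_keys = torch.randn(batch_size * num_kv_heads, head_dim, seq_len)

# "from bloom": swap the last two dims, then expand to the 4D HF layout
hf_keys = bloom_keys.permute(0, 2, 1).reshape(batch_size, num_kv_heads, seq_len, head_dim)

# "to bloom": flatten back and swap again; the round trip is lossless
restored = hf_keys.reshape(batch_size * num_kv_heads, seq_len, head_dim).permute(0, 2, 1)
assert torch.equal(restored, bloom_keys)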
@ -1,36 +0,0 @@
import os
from typing import Optional, Union

from hivemind import get_logger
from transformers.models.mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralAttention

from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.models.mixtral.block import WrappedMixtralBlock

logger = get_logger(__name__)


class DistributedMixtralConfig(MixtralConfig, ClientConfig, PTuneConfig, LMHeadConfig):
    block_class = WrappedMixtralBlock
    attn_class = MixtralAttention
    block_prefix = "model.layers"

    num_key_value_groups = 1

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            dht_prefix = str(model_name_or_path)
            dht_prefix = dht_prefix.replace(".", "-")
            logger.info(f"Using DHT prefix: {dht_prefix}")
        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
        config = result[0] if isinstance(result, tuple) else result
        if config.pad_token_id is None:
            config.pad_token_id = 0
        return result
@ -1,178 +0,0 @@
from typing import Optional

import torch
import torch.nn as nn
from hivemind import DHT
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral import (
    MixtralForCausalLM,
    MixtralForSequenceClassification,
    MixtralModel,
    MixtralPreTrainedModel,
)

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.mixtral.config import DistributedMixtralConfig
from petals.utils.auto_config import DefaultRevisionMixin

logger = get_logger(__name__)


class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, MixtralModel):
    """MixtralModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]

    config_class = DistributedMixtralConfig

    def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None):
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.layers) == 0
        config.num_hidden_layers = n_layer

        self.layers = RemoteSequential(config, dht=dht)

        self.requires_grad_(False)  # Forbid accumulating grads for embeddings and layernorm
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[RemotePastKeyValues] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # The causal mask will be added on the server-side
        assert (
            attention_mask is None or (attention_mask == 1).all()
        ), f"Custom attention masks are not supported, {attention_mask=}"
        assert (
            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
        ), f"Non-consecutive position_ids are not supported, {position_ids=}"
        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
        assert not output_attentions, f"{output_attentions=} is not supported"
        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
        assert return_dict is None or return_dict, f"{return_dict=} is not supported"
        assert not output_router_logits, f"{output_router_logits=} is not supported"

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
        if use_prompts:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
        else:
            prompts = intermediate_prompts = None

        hidden_states = inputs_embeds
        output_shape = input_shape + (hidden_states.size(-1),)

        if past_key_values is None:
            past_key_values = RemotePastKeyValues()
        past_key_values.update_seen(hidden_states.size(1))

        hidden_states = self.layers(
            hidden_states,
            prompts=intermediate_prompts,
            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
        )

        # Remove prefix
        if use_prompts:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )

    @property
    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
        return self.embed_tokens

    @property
    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin in tests
        return nn.Identity()

    @property
    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
        return self.layers

    @property
    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin in tests
        return self.norm


class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):
    _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedMixtralConfig

    def __init__(self, config: DistributedMixtralConfig):
        MixtralPreTrainedModel.__init__(self, config)
        self.model = DistributedMixtralModel(config)
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    @property
    def transformer(self) -> DistributedMixtralModel:  # For compatibility with RemoteGenerationMixin
        return self.model


class DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification):
    _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedMixtralConfig

    def __init__(self, config: DistributedMixtralConfig):
        MixtralPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.model = DistributedMixtralModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def transformer(self) -> DistributedMixtralModel:  # For compatibility with RemoteGenerationMixin
        return self.model
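The compatibility properties above exist so that generation code can address every architecture through one set of attribute names; a self-contained toy sketch of the same aliasing pattern:

import torch.nn as nn

class ToyModel(nn.Module):
    """Mimics DistributedMixtralModel: new-style attributes plus legacy aliases."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(4, 4))  # stands in for RemoteSequential
        self.norm = nn.LayerNorm(4)

    @property
    def h(self):  # legacy name expected by generation utilities
        return self.layers

    @property
    def ln_f(self):  # legacy name for the final layernorm
        return self.norm

toy = ToyModel()
assert toy.h is toy.layers and toy.ln_f is toy.norm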
@ -1,230 +0,0 @@
"""
This module implements server-side computations on served blocks: forward, backward and inference; used by handler
"""
from __future__ import annotations

from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union

import torch
from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.expert_uid import ExpertUID
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger
from hivemind.utils.nested import nested_flatten

from petals.data_structures import Handle, InferenceMetadata
from petals.server.backend import TransformerBackend
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import TaskPrioritizerBase
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import unpack_args_kwargs

# We prioritize short inference requests and make them use a *merged* inference pool,
# so they are processed without interruptions and extra overheads
# TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward
MAX_SHORT_INFERENCE_TOKENS = 128
MAX_NF4_SHORT_INFERENCE_TOKENS = 1

logger = get_logger(__name__)


async def run_rpc_forward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
    args_structure: Any = None,
) -> torch.Tensor:
    """
    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream

    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
    """
    if args_structure is not None:
        # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
    hidden_states, prompts, *_ = flat_tensors

    dtype = requested_backends[0].dtype
    # Parse input tensors and cast dtypes
    hidden_states = hidden_states.to(dtype)
    assert hidden_states.ndim == 3
    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

    # Run a chain of requested backends
    for backend, prompt in zip(requested_backends, prompts):
        if not is_dummy(prompt):
            hidden_states[:, : prompt.shape[1]] += prompt

        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "Petals supports only prioritized pools"
        priority = prioritizer.prioritize(
            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
        )
        (hidden_states,) = await backend.forward_pool.submit_task(
            hidden_states,
            active_adapter,
            priority=priority,
        )
        assert isinstance(hidden_states, torch.Tensor)
        assert (
            hidden_states.ndim == 3
        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

    return hidden_states


async def run_rpc_backward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
    args_structure: Any = None,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    if args_structure is not None:
        # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
    inputs, grad_outputs, prompts, *_ = flat_tensors

    # Cast inputs & grad outputs to backend dtype
    inputs = inputs.to(requested_backends[0].dtype)
    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

    # Run a forward chain to collect intermediate inputs
    # Note that we do not forward for the last module since we do not need its output
    inter_inputs = []
    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
        if not is_dummy(prompt):
            inputs[:, : prompt.shape[1]] += prompt
        inter_inputs.append(inputs)
        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "Petals supports only prioritized pools"
        priority = prioritizer.prioritize(
            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
        )
        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)

        assert isinstance(inputs, torch.Tensor)

    if not is_dummy(prompts[-1]):
        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
    inter_inputs.append(inputs)

    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
    grad_prompts_reversed = []
    # Run a chain of requested backends
    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
        assert isinstance(backend.inference_pool, PrioritizedTaskPool), "Petals supports only prioritized pools"
        priority = prioritizer.prioritize(
            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
        )
        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)

        assert isinstance(grad_outputs, torch.Tensor)
        if not is_dummy(prompt):
            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))

    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape


async def iterate_rpc_inference(
    requested_uids: Sequence[ExpertUID],
    requested_backends: Sequence[TransformerBackend],
    active_adapter: Optional[str],
    input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
    cache_handles: Sequence[Sequence[Handle]],
    *,
    max_length: int,
    prioritizer: TaskPrioritizerBase,
    points: int,
    quant_type: QuantType,
    args_structure: Any = None,
) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool, Dict]]:
    assert len(cache_handles) == len(requested_backends)

    prefix_length = 0
    point_per_piece = points / max_length if max_length > 0 else 0.0

    async for request, step_metadata in input_iterator:
        flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
        if args_structure is not None:
            # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
            flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)

        hidden_states, prompts, hypo_ids, *_ = flat_tensors
        batch_size, length_increment, _ = hidden_states.shape

        # Cast inputs to backend dtype
        hidden_states = hidden_states.to(requested_backends[0].dtype)
        assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"

        # Parse deep prompts (optional argument)
        has_prompts = prompts is not None and not is_dummy(prompts)
        if not has_prompts:
            prompts = [None] * len(requested_backends)
        else:
            prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
            prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]

        if not (len(requested_backends) == len(prompts)):
            raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")

        if prefix_length + length_increment > max_length:
            raise ValueError(
                f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                f" exceeds pre-allocated maximum {max_length}"
            )

        merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
        can_merge_pools = batch_size * length_increment <= merge_max_tokens
        priority = prioritizer.prioritize(
            hidden_states,
            hypo_ids,
            points=point_per_piece,
            requested_uids=requested_uids,
            type="inference",
        )

        # A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.
        # when user wants to pre-allocate cache or check that server *can* allocate that cache.
        if hidden_states.numel() > 0:
            assert hidden_states.ndim == 3, "hidden states must be a single 3d tensor"
            if can_merge_pools:
                inference_infos = tuple(
                    InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
                    for uid, handles in zip(requested_uids, cache_handles)
                )
                (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
                    hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
                )
            else:
                for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
                    inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
                    (hidden_states,) = await backend.inference_pool.submit_task(
                        hidden_states, hypo_ids, inference_infos, prompt, priority=priority
                    )

        # Serialize and send last layer outputs
        output_tensors = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
        ]
        can_push = not has_prompts
        yield output_tensors, can_push, step_metadata

        # Prepare for next step
        prefix_length += length_increment
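A tiny self-contained sketch of the deep-prompt arithmetic shared by all three functions above: each backend adds its optional prompt to the first prompt_len positions of the hidden states before running the block:

import torch

batch, seq_len, hid = 2, 16, 8
prompt_len = 4

hidden_states = torch.zeros(batch, seq_len, hid)
prompt = torch.ones(batch, prompt_len, hid)

# Same in-place update as `hidden_states[:, : prompt.shape[1]] += prompt` above
hidden_states[:, :prompt_len] += prompt
assert hidden_states[:, :prompt_len].eq(1).all() and hidden_states[:, prompt_len:].eq(0).all()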
@ -1,76 +0,0 @@
import torch
from torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten


def make_inference_graphed_callable(callable: callable, sample_args, num_warmup_iters=3):
    """Similar to torch.cuda.make_graphed_callables, but takes only one function and does not build a graph for the backward pass"""
    assert not isinstance(callable, torch.nn.Module)
    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
        raise RuntimeError(
            "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
        )

    flatten_arg, _ = _tree_flatten(sample_args)
    flatten_sample_args = tuple(flatten_arg)
    assert all(
        isinstance(arg, torch.Tensor) for arg in flatten_arg
    ), "In the beta API, sample_args for each callable must contain only Tensors. Other types are not allowed."

    len_user_args = len(sample_args)
    static_input_surface = flatten_sample_args

    graph = torch.cuda.CUDAGraph()

    # Warmup
    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
    # from ending up in any captures.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(num_warmup_iters):
            outputs, _ = _tree_flatten(callable(*sample_args))
        del outputs
    torch.cuda.current_stream().wait_stream(s)

    # Capture forward graph
    with torch.cuda.graph(graph):
        outputs = callable(*sample_args)

    flatten_outputs, output_unflatten_spec = _tree_flatten(outputs)
    static_outputs = tuple(flatten_outputs)

    def make_graphed_function(
        graph,
        len_user_args,
        output_unflatten_spec,
        static_input_surface,
        static_outputs,
    ):
        def replay_graph(*inputs):
            # At this stage, only the user args may (potentially) be new tensors.
            for i in range(len_user_args):
                if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
                    static_input_surface[i].copy_(inputs[i])
            graph.replay()
            assert isinstance(static_outputs, tuple)
            return tuple(o.detach() for o in static_outputs)

        def functionalized(*user_args):
            # Runs the autograd function with inputs == all inputs to the graph that might require grad
            # (explicit user args + module parameters)
            # Assumes module params didn't change since capture.
            flatten_user_args, _ = _tree_flatten(user_args)
            out = replay_graph(*flatten_user_args)
            return _tree_unflatten(out, output_unflatten_spec)

        return functionalized

    # Put together the final graphed callable
    graphed = make_graphed_function(
        graph,
        len_user_args,
        output_unflatten_spec,
        static_input_surface,
        static_outputs,
    )
    return graphed
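A hedged usage sketch, assuming this helper is importable as petals.utils.cuda_graphs and that a CUDA device supporting graph capture is available (the module and shapes are arbitrary examples):

import torch
from petals.utils.cuda_graphs import make_inference_graphed_callable

if torch.cuda.is_available():
    linear = torch.nn.Linear(16, 16).cuda().eval()
    sample_args = (torch.randn(1, 16, device="cuda"),)

    # The bound forward function is wrapped, not the nn.Module itself,
    # since the helper asserts its argument is not a Module.
    graphed_forward = make_inference_graphed_callable(linear.forward, sample_args)

    out = graphed_forward(torch.randn(1, 16, device="cuda"))  # replays the captured graph
    assert out.shape == (1, 16)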
@ -1,153 +0,0 @@
"""
Utilities for declaring and retrieving active model layers using a shared DHT.
"""
from __future__ import annotations

import math
from functools import partial
from typing import Dict, List, Optional, Sequence, Union

from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger

from petals.data_structures import (
    CHAIN_DELIMITER,
    UID_DELIMITER,
    ModuleUID,
    RemoteModuleInfo,
    RemoteSpanInfo,
    ServerInfo,
    ServerState,
    parse_uid,
)

logger = get_logger(__name__)


def declare_active_modules(
    dht: DHT,
    uids: Sequence[ModuleUID],
    server_info: ServerInfo,
    expiration_time: DHTExpiration,
    wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
    """
    Declare that your node serves the specified modules; update timestamps if declared previously

    :param uids: a list of module ids to declare
    :param wait: if True, awaits for declaration to finish, otherwise runs in background
    :param server_info: server information (including throughput and state) to be announced for each module
    :param expiration_time: declared modules will be visible for this many seconds
    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
    """
    if isinstance(uids, str):
        uids = [uids]
    if not isinstance(uids, list):
        uids = list(uids)
    for uid in uids:
        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid

    return dht.run_coroutine(
        partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
        return_future=not wait,
    )


async def _declare_active_modules(
    dht: DHT,
    node: DHTNode,
    uids: List[ModuleUID],
    server_info: ServerInfo,
    expiration_time: DHTExpiration,
) -> Dict[ModuleUID, bool]:
    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
    return await node.store_many(
        keys=uids,
        subkeys=[dht.peer_id.to_base58()] * len(uids),
        values=[server_info.to_tuple()] * len(uids),
        expiration_time=expiration_time,
        num_workers=num_workers,
    )


def get_remote_module_infos(
    dht: DHT,
    uids: Sequence[ModuleUID],
    expiration_time: Optional[DHTExpiration] = None,
    active_adapter: Optional[str] = None,
    *,
    latest: bool = False,
    return_future: bool = False,
) -> Union[List[RemoteModuleInfo], MPFuture]:
    return dht.run_coroutine(
        partial(
            _get_remote_module_infos,
            uids=uids,
            active_adapter=active_adapter,
            expiration_time=expiration_time,
            latest=latest,
        ),
        return_future=return_future,
    )


async def _get_remote_module_infos(
    dht: DHT,
    node: DHTNode,
    uids: List[ModuleUID],
    active_adapter: Optional[str],
    expiration_time: Optional[DHTExpiration],
    latest: bool,
) -> List[RemoteModuleInfo]:
    if latest:
        assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
        expiration_time = math.inf
    elif expiration_time is None:
        expiration_time = get_dht_time()
    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
    found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)

    modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids]
    for module_info in modules:
        metadata = found[module_info.uid]
        if metadata is None or not isinstance(metadata.value, dict):
            if metadata is not None:
                logger.warning(f"Incorrect metadata for {module_info.uid}: {metadata}")
            continue

        for peer_id, server_info in metadata.value.items():
            try:
                peer_id = PeerID.from_base58(peer_id)
                server_info = ServerInfo.from_tuple(server_info.value)

                if active_adapter and active_adapter not in server_info.adapters:
                    logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
                    continue

                module_info.servers[peer_id] = server_info
            except (TypeError, ValueError) as e:
                logger.warning(f"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}")
    return modules


def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]:
    block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0
    num_blocks = len(module_infos)

    spans = {}
    for block_idx, module_info in enumerate(module_infos):
        for peer_id, server_info in sorted(module_info.servers.items()):
            if server_info.state.value < min_state.value:
                continue

            if peer_id not in spans or spans[peer_id].state.value < server_info.state.value:
                spans[peer_id] = RemoteSpanInfo(
                    peer_id=peer_id, start=block_idx, end=block_idx + 1, server_info=server_info
                )
                if server_info.start_block is not None and server_info.end_block is not None:
                    spans[peer_id].start = max(server_info.start_block - block_offset, 0)
                    spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks)
            elif spans[peer_id].state == server_info.state:
                spans[peer_id].end = max(spans[peer_id].end, block_idx + 1)
    return spans
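For orientation, a sketch of the key layout these helpers read and write, using plain Python values (the UIDs, peer ids, and payload below are made up):

# The DHT conceptually holds one dictionary per module UID; each announcing
# server writes its own subkey. UIDs look like f"{dht_prefix}{UID_DELIMITER}{block_index}",
# which is why dots in model names are replaced when building the prefix.
dht_contents = {
    "falcon-40b-hf.0": {
        "12D3KooWPeerA": "ServerInfo.to_tuple() payload",  # expires after expiration_time
        "12D3KooWPeerB": "ServerInfo.to_tuple() payload",
    },
}
# declare_active_modules() issues store_many() over such subkeys, and
# get_remote_module_infos() reads them back and parses each ServerInfo.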
@ -0,0 +1,128 @@
from abc import ABC, abstractmethod
from typing import Tuple

import torch

TokenIds = torch.Tensor
HypoIds = torch.Tensor


class DecodingAlgorithm(ABC):
    """
    An abstract class for decoding algorithms. Describes the base function of those algorithms:
    they have to select new tokens and provide the corresponding hypotheses.
    """

    @abstractmethod
    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        """
        :param logits: A tensor of shape (batch_size, seq_length, vocab_size)
        :return: A tuple of selected token ids and corresponding hypotheses.
            The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
        """
        pass


class GreedyAlgorithm(DecodingAlgorithm):
    """
    The simplest algorithm for decoding. It selects the most probable token.
    """

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        """
        Returns the most probable token. The second returned object is always a range of integers
        from 0 to batch_size - 1.
        """
        return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))


class SamplingAlgorithm(DecodingAlgorithm):
    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        """
        :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
        :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
        :return: A tuple of selected token ids and corresponding hypotheses.
            The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size).
        """
        logits[indices_to_remove] = -float("Inf")
        probs = torch.softmax(logits / self.temperature, -1)
        return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
        return self.sample(logits, indices_to_remove)


class TopKAlgorithm(SamplingAlgorithm):
    def __init__(self, top_k: int, temperature: float = 1.0) -> None:
        super().__init__(temperature=temperature)
        self.top_k = top_k

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
        return self.sample(logits, indices_to_remove)


class NucleusAlgorithm(SamplingAlgorithm):
    def __init__(self, top_p: float, temperature: float = 1.0) -> None:
        super().__init__(temperature=temperature)
        self.top_p = top_p

    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
        probs = torch.softmax(sorted_logits / self.temperature, -1)
        cumulative_probs = torch.cumsum(probs, dim=-1)

        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        return self.sample(logits, indices_to_remove)


class BeamSearchAlgorithm(DecodingAlgorithm):
    def __init__(self, num_beams: int, batch_size: int) -> None:
        self.num_beams = num_beams
        self.batch_size = batch_size

        self._batch_beams = [list() for _ in range(batch_size)]

    def __call__(self, logits: torch.Tensor):
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        probs = torch.log_softmax(sorted_logits, -1)

        if len(self._batch_beams[0]) > 0:
            for batch_idx in range(self.batch_size):
                new_beams = []
                cur_beams = self._batch_beams[batch_idx]
                for beam_idx in range(len(cur_beams)):
                    probs_idx = batch_idx + beam_idx * self.batch_size
                    new_beam = cur_beams[beam_idx]
                    for hypo_idx in range(self.num_beams):
                        new_beams.append(
                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
                        )
                self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
        else:
            for batch_idx in range(self.batch_size):
                for beam_idx in range(self.num_beams):
                    self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))

        return_hypos = []
        return_tokens = []
        for batch_idx in range(self.batch_size):
            cur_beam = self._batch_beams[batch_idx]
            return_hypos.append(list())
            return_tokens.append(list())
            for beam in cur_beam:
                beam_idx = beam[1] // self.num_beams
                hypo_idx = batch_idx + beam_idx * self.batch_size
                token_idx = beam[1] % self.num_beams
                return_hypos[-1].append(hypo_idx)
                return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
        return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
        return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]

        return torch.tensor(return_tokens), torch.tensor(return_hypos)
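These classes operate on plain tensors, so they can be exercised standalone; a quick sketch with random logits for a single decoding step:

import torch
from petals.utils.generation_algorithms import GreedyAlgorithm, NucleusAlgorithm, TopKAlgorithm

logits = torch.randn(2, 50)  # (batch_size, vocab_size)

token_ids, hypo_ids = GreedyAlgorithm()(logits)
assert token_ids.shape == (2, 1) and hypo_ids.tolist() == [0, 1]

token_ids, _ = TopKAlgorithm(top_k=5, temperature=0.7)(logits.clone())
token_ids, _ = NucleusAlgorithm(top_p=0.9)(logits.clone())  # sample() mutates logits, hence .clone()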
@ -0,0 +1,51 @@
from abc import ABC

import torch


class ABCBloomConstraint(ABC):
    """
    Base class for all kinds of decoding constraints. It can be used to implement a new constraint.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        """
        This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
        :param tokens_id: The token id of the last chosen token.
        :param logits: The logits from the Bloom model.
        :param hypo_ids: The hypothesis ids of the last tokens.
        """
        pass


class EosConstraint(ABCBloomConstraint):
    """
    This constraint forces the model to keep repeating the EOS token once it was generated on a previous step.
    Args:
        prefix: The prefix of the sequence.
        eos_token_id: The id of the end of sentence token.
        pad_token_id: The id of the padding token.
        min_logits: The minimum logits that can be generated. Default: -1e8.
    """

    def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
        self.eos_token_id = eos_token_id
        self.min_logits = min_logits
        self.past_tokens = None

        self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        if self.past_tokens is not None:
            mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
            logits += self.min_logits * mask
            logits[mask[:, 0], self.eos_token_id] = 0

        if tokens_id is not None:
            self.past_tokens = tokens_id
            self.wait_until_starting -= 1

        return logits
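A small standalone check of the EOS-pinning behavior, assuming the file above is importable as petals.utils.generation_constraints (token ids are arbitrary):

import torch
from petals.utils.generation_constraints import EosConstraint

eos_id, pad_id = 2, 0
prefix = torch.tensor([[5, 6, 7]])  # no left-padding, so masking can start immediately
constraint = EosConstraint(prefix, eos_token_id=eos_id, pad_token_id=pad_id)

constraint(torch.tensor([[eos_id]]), torch.zeros(1, 10), None)  # records that EOS was produced
forced = constraint(torch.tensor([[eos_id]]), torch.zeros(1, 10), None)
assert forced.argmax(-1).item() == eos_id  # every later step is pinned back to EOS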
@ -1,49 +0,0 @@
from typing import Any, Dict, List, Tuple

import torch
from hivemind import nested_flatten, nested_pack

# TODO: Move functions to hivemind


def _mark_masked_tensor(index: int) -> bytes:
    return b"__T" + str(index).encode()


def _is_masked_tensor(item: Any) -> bool:
    return isinstance(item, bytes) and item.startswith(b"__T")


def _get_tensor_index(item: bytes) -> int:
    return int(item[3:])


def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
    """
    Check the function's arguments and pack all tensors into a separate flattened list.
    :returns: a flattened list of tensors, plus the args and kwargs structure with tensors masked out
    """
    masked_flat_values, flat_tensors, tensor_to_index = [], [], {}
    for value in nested_flatten((args, kwargs)):
        if isinstance(value, torch.Tensor):
            tensor_index = tensor_to_index.setdefault(value, len(flat_tensors))
            if tensor_index == len(flat_tensors):
                flat_tensors.append(value)
            masked_flat_values.append(_mark_masked_tensor(tensor_index))
        else:
            masked_flat_values.append(value)
    return flat_tensors, nested_pack(masked_flat_values, (args, kwargs))


def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any):
    """
    Restore arguments packed by the `pack_args_kwargs` function.
    :returns: list of args and dict of kwargs
    """
    return nested_pack(
        (
            value if not _is_masked_tensor(value) else flat_tensors[_get_tensor_index(value)]
            for value in nested_flatten(args_structure)
        ),
        args_structure,
    )
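A quick round-trip check of the two functions above:

import torch
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

x, y = torch.zeros(2), torch.ones(3)
flat_tensors, structure = pack_args_kwargs(x, 42, scale=y)
assert len(flat_tensors) == 2  # x and y, deduplicated by identity

args, kwargs = unpack_args_kwargs(flat_tensors, structure)
assert torch.equal(args[0], x) and args[1] == 42 and torch.equal(kwargs["scale"], y)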
@ -0,0 +1,25 @@
import argparse
from datetime import datetime

from huggingface_hub import delete_repo, list_models

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Remove old testing models from HF hub")
    parser.add_argument("--author", type=str, default="bloom-testing", help="Hub author whose test models are scanned")
    parser.add_argument("--seconds_since_last_updated", type=int, default=7 * 24 * 60 * 60)
    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for the Hugging Face Hub")
    parser.add_argument("--dry_run", action="store_true")

    args = parser.parse_args()

    for model in list_models(author=args.author, full=True):
        last_modified = datetime.strptime(model.lastModified, "%Y-%m-%dT%H:%M:%S.%fZ")

        if model.modelId.endswith("-main") or "/test-" not in model.modelId:
            continue  # remove only test models

        if (datetime.now() - last_modified).total_seconds() > args.seconds_since_last_updated:
            if args.dry_run:
                print(f"{model.modelId} can be deleted")
            else:
                delete_repo(repo_id=model.modelId, token=args.use_auth_token)
Binary file not shown.
@ -1,184 +0,0 @@
|
||||
import asyncio
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio # make sure the module exists; otherwise the test will be skipped
|
||||
import torch
|
||||
from hivemind import TensorDescriptor
|
||||
|
||||
from petals.server.memory_cache import AllocationFailed, MemoryCache
|
||||
from petals.utils.misc import get_size_in_bytes
|
||||
|
||||
|
||||
def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None):
|
||||
if dtype is None:
|
||||
dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool))
|
||||
elem_size_bytes = get_size_in_bytes(dtype)
|
||||
descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype))
|
||||
return descr
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_timeout():
|
||||
cache = MemoryCache(max_size_bytes=1024, max_alloc_timeout=0.5)
|
||||
cache.runtime_pid += 1 # pretend we're another process
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0):
|
||||
pass
|
||||
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(100), timeout=999):
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(128), _make_tensor_descriptor(32), timeout=1):
|
||||
t_start = time.perf_counter()
|
||||
with pytest.raises(AllocationFailed):
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0.1):
|
||||
pass
|
||||
assert 0.1 < time.perf_counter() - t_start < 0.2, "wait time exceeds alloc timeout"
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")):
|
||||
pass
|
||||
|
||||
t_start = time.perf_counter()
|
||||
with pytest.raises(AllocationFailed):
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(384), timeout=1.0): # exceeds max timeout
|
||||
pass
|
||||
assert 0.5 < time.perf_counter() - t_start < 0.6, "wait time exceeds max alloc timeout"
|
||||
|
||||
# test memory allocation when another task frees the memory
|
||||
async def _klog_the_cache():
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
|
||||
pass
|
||||
|
||||
large_alloc_task = asyncio.create_task(_klog_the_cache())
|
||||
|
||||
t_start = time.perf_counter()
|
||||
await asyncio.sleep(0.05) # wait for large alloc to enqueue
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")): # exceeds max timeout
|
||||
pass # this memory should allocate once the background task clears the queue
|
||||
assert 0.2 < time.perf_counter() - t_start < 0.3, "memory should be allocated after background task clears"
|
||||
with pytest.raises(AllocationFailed):
|
||||
await large_alloc_task
|
||||
|
||||
# test that zero-timeout allocation fails instantaneously even if someone else is awaiting alloc
|
||||
large_alloc_task = asyncio.create_task(_klog_the_cache())
|
||||
t_start = time.perf_counter()
|
||||
await asyncio.sleep(0.05) # wait for large alloc to enqueue
|
||||
with pytest.raises(AllocationFailed):
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
|
||||
pass # this memory should allocate once the background task clears the queue
|
||||
assert time.perf_counter() - t_start < 0.1, "zero-timeout task should fail (or succeed) instantaneously"
|
||||
with pytest.raises(AllocationFailed):
|
||||
await large_alloc_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unlimited_timeout():
|
||||
cache = MemoryCache(max_size_bytes=1024)
|
||||
cache.runtime_pid += 1 # pretend we're another process
|
||||
t_start = time.perf_counter()
|
||||
|
||||
async def _klog_the_cache():
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
alloc_task = asyncio.create_task(_klog_the_cache())
|
||||
await asyncio.sleep(0.1)
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=float("inf")):
|
||||
await alloc_task
|
||||
assert 0.5 < time.perf_counter() - t_start < 0.6, "memory should be allocated after background task clears"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_usage():
|
||||
cache = MemoryCache(max_size_bytes=2048)
|
||||
alloc_event, dealloc_a_event, dealloc_bcd_event, dealloc_e_event, dealloc_f_event = (mp.Event() for _ in range(5))
|
||||
pipe_receiver, pipe_sender = mp.Pipe(duplex=False)
|
||||
with pytest.raises(AssertionError):
|
||||
async with cache.allocate_cache(_make_tensor_descriptor(123), timeout=1):
|
||||
pass # fails because cache must be allocated from another process
|
||||
|
||||
descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8)) # 768 bytes
|
||||
descr_b = TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64)) # 8 bytes
|
||||
descr_c = TensorDescriptor.from_tensor(torch.empty((33,), dtype=torch.bool)) # 33 bytes
|
||||
descr_d = TensorDescriptor.from_tensor(torch.empty((0,), dtype=torch.int64)) # 0 bytes
|
||||
descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16)) # 1536 bytes
|
||||
descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8)) # 1792 bytes
|
||||
|
||||
async def _allocate_and_wait(dealloc_event, *descrs, timeout=None):
|
||||
loop = asyncio.get_event_loop()
|
||||
async with cache.allocate_cache(*descrs, timeout=timeout) as handles:
|
||||
pipe_sender.send(handles)
|
||||
await loop.run_in_executor(None, dealloc_event.wait)
|
||||
|
||||
async def _allocate_af():
|
||||
alloc_event.wait()
|
||||
allocate_a_task = asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a))
|
||||
await allocate_a_task
|
||||
allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f)) # klogs the cache
|
||||
await allocate_f_task
|
||||
|
||||
alloc_process1 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_af()), daemon=True)
|
||||
alloc_process1.start()
|
||||
|
||||
async def _allocate_bcde():
|
||||
alloc_event.wait()
|
||||
await asyncio.sleep(0.1) # ensure that the other tensor is always allocated (and sent through pipe) first
|
||||
allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d))
|
||||
allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e)) # doesn't fit
|
||||
await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)
|
||||
|
||||
alloc_process2 = mp.context.ForkProcess(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
|
||||
alloc_process2.start()
|
||||
assert cache.current_size_bytes == 0
|
||||
alloc_event.set()
|
||||
(handle_a,) = pipe_receiver.recv()
|
||||
|
||||
handle_b, handle_c, handle_d = pipe_receiver.recv()
|
||||
|
||||
with cache.use_cache(handle_a) as (tensor_a,):
|
||||
assert tensor_a.dtype == torch.uint8
|
||||
tensor_a[2:5] = torch.tensor((42, 43, 44))
|
||||
|
||||
with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, tensor_d):
|
||||
assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0
|
||||
assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0
|
||||
tensor_a += 1
|
||||
tensor_b[...] = -1.337
|
||||
assert cache.current_size_bytes == 809 # this checks a,b,c,d are allocated but b still awaits memory
|
||||
|
||||
dealloc_bcd_event.set()
|
||||
await asyncio.sleep(0.1)
|
||||
assert cache.current_size_bytes == 768 # only tensor a should be allocated
|
||||
with pytest.raises(KeyError):
|
||||
with cache.use_cache(handle_a, handle_b):
|
||||
pass # one of handles (c) is deallocated
|
||||
with pytest.raises(KeyError):
|
||||
with cache.use_cache(handle_d):
|
||||
pass # handle_d is deallocated correctly, even though it is never used
|
||||
with cache.use_cache(handle_a) as (tensor_a,):
|
||||
assert tuple(tensor_a[2:5]) == (43, 44, 45)
|
||||
|
||||
dealloc_a_event.set()
|
||||
(handle_e,) = pipe_receiver.recv() # e can finally be allocated
|
||||
await asyncio.sleep(0.1)
|
||||
assert cache.current_size_bytes == 1536 # tensor e should finally be able to allocate

with pytest.raises(KeyError):
    with cache.use_cache(handle_a):
        pass  # tensor a is no longer allocated
with cache.use_cache(handle_e) as (tensor_e,):
    assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8)

dealloc_e_event.set()
await asyncio.sleep(0.1)
assert cache.current_size_bytes == 1792  # only tensor f is still allocated
dealloc_f_event.set()

alloc_process1.join()
alloc_process2.join()
await asyncio.sleep(0.1)
assert cache.current_size_bytes == 0
assert alloc_process1.exitcode == 0, "allocation process 1 failed or did not finish, see stderr for details"
assert alloc_process2.exitcode == 0, "allocation process 2 failed or did not finish, see stderr for details"
@ -1,224 +0,0 @@

from typing import Optional, Tuple

import pytest
import torch
from transformers.cache_utils import DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

from petals.server.block_utils import get_model_block
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, convert_block
from test_utils import MODEL_NAME

KVCache = Tuple[torch.Tensor, torch.Tensor]
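# Inferred from the reorder helpers below: the KV cache travels in the BLOOM
# layout, keys as [batch_size * num_heads, head_dim, seq_len] and values as
# [batch_size * num_heads, seq_len, head_dim]; the wrappers translate between
# this layout and each model's native one.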


class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[KVCache] = None,
        use_cache: bool = False,
        **kwargs,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        if layer_past is not None:
            layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
        past_length = 0 if layer_past is None else layer_past[0].shape[1]
        seq_length_with_past = seq_length + past_length

        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        if alibi is None and self.config.alibi:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)

        outputs = super().forward(
            hidden_states,
            *args,
            attention_mask=attention_mask,
            alibi=alibi,
            layer_past=layer_past,
            use_cache=use_cache,
            **kwargs,
        )

        if use_cache:
            present_key_value = outputs[-1]
            present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
            outputs = outputs[:-1] + (present_key_value,)

        return outputs
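# This wrapper reproduces the Falcon block interface without the optimized
# code path: it rebuilds alibi and the causal mask on every call and converts
# the KV cache layout on the way in and out.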

    def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
        key_states, value_states = key_value

        key_states = key_states.permute(0, 2, 1)
        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]

        if self.config.new_decoder_architecture:
            key_states = self._expand_states(key_states)
            value_states = self._expand_states(value_states)

        return (key_states, value_states)

    def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
        key_states, value_states = key_value

        if self.config.new_decoder_architecture:
            key_states = self._collapse_states(key_states)
            value_states = self._collapse_states(value_states)

        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
        key_states = key_states.permute(0, 2, 1)

        return (key_states, value_states)

    def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
        batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
        batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads

        state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
        state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1)  # No copy
        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)  # Involves a copy
        return state
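# expand() adds a stride-0 view, so the replicas within each KV group share
# memory until reshape() materializes num_kv_heads * num_key_value_groups
# separate heads.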

    def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
        batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
        batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads

        state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
        state = state[:, :, 0]
        state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
        return state
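# Keeping only replica [:, :, 0] is lossless here because _expand_states
# produced identical copies within each key/value group.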


class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        batch_size, seq_length, _ = hidden_states.shape

        seq_length_with_past = seq_length
        past_key_values_length = 0

        past_key_value = layer_past
        if past_key_value is not None:
            past_key_values_length = past_key_value[0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
        elif use_cache:
            past_key_value = DynamicCache()

        if position_ids is None:
            device = hidden_states.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # prepare the causal attention mask
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
            )

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )

        outputs = super().forward(
            hidden_states,
            *args,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            **kwargs,
        )

        if use_cache:
            present_key_value = outputs[-1]
            present_key_value = self._reorder_cache_from_llama_to_bloom(
                present_key_value, batch_size, seq_length_with_past
            )
            outputs = outputs[:-1] + (present_key_value,)

        return outputs
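# As with the Falcon wrapper above, this block accepts and returns
# BLOOM-layout caches while delegating the actual attention math to the stock
# LlamaDecoderLayer implementation.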

    def _reorder_cache_from_bloom_to_llama(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> DynamicCache:
        key_states, value_states = key_value
        key_states = key_states.permute(0, 2, 1)
        key_states = key_states.view(
            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        value_states = value_states.view(*key_states.shape)
        past_key_values = ((key_states, value_states),)
        return DynamicCache.from_legacy_cache(past_key_values)

    def _reorder_cache_from_llama_to_bloom(
        self, key_value: DynamicCache, batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        key_states, value_states = key_value.to_legacy_cache()[0]
        value_states = value_states.view(
            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        key_states = key_states.view(*value_states.shape)
        key_states = key_states.permute(0, 2, 1)
        return (key_states, value_states)
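# Round trip: BLOOM keys [b*h, head_dim, seq] -> Llama [b, h, seq, head_dim]
# and back, so a BLOOM-layout cache can drive a Llama block unchanged.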


@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
@pytest.mark.forked
def test_optimized_block(device):
    if device == "cuda:0" and not torch.cuda.is_available():
        pytest.skip("CUDA tests can be run only in CUDA-enabled setups")

    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)

    tensor_parallel_devices = (device,)
    dtype = torch.bfloat16
    quant_type = QuantType.NONE

    block_idx = 1
    block = get_model_block(config, layer_idx=block_idx).to(dtype)
    block = convert_block(block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)

    if config.model_type == "falcon":
        unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
    elif config.model_type == "llama":
        unopt_block = UnoptimizedWrappedLlamaBlock(config, layer_idx=0).to(dtype)
    else:
        pytest.skip(f"This test is not applicable to {config.model_type} models")

    unopt_block = convert_block(
        unopt_block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
    )

    unopt_block.load_state_dict(block.state_dict())

    cache = unopt_cache = None

    with torch.inference_mode():
        for length in [10, 1, 1, 1]:
            dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype)
            block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)
            unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
            assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length
            assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length
            assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length
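# The [10, 1, 1, 1] schedule mimics autoregressive inference: one prefill pass
# followed by three single-token decode steps, comparing outputs and KV caches
# at every step.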