from __future__ import annotations

from collections import Counter
from itertools import chain
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
from hivemind import BatchTensorDescriptor, TensorDescriptor
from hivemind.moe.expert_uid import ExpertUID
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
from tensor_parallel import TensorParallel
from tensor_parallel.tensor_parallel import PerDeviceTensors
from transformers import PretrainedConfig

from petals.data_structures import InferenceMetadata
from petals.server.memory_cache import MemoryCache
from petals.server.task_pool import PrioritizedTaskPool
from petals.utils.misc import get_size_in_bytes, is_dummy
# Module-level logger, created via hivemind's get_logger and keyed by this module's name.
logger = get_logger(__name__)
class TransformerBackend(ModuleBackend):
    """A wrapper for a transformer block that can process requests for forward, backward and inference"""

    # Set per instance in __init__ to the lazily imported `petals.utils.peft` module.
    _peft_module = None

    def __init__(
        self,
        *args,
        block_index: int,
        config: PretrainedConfig,
        memory_cache: MemoryCache,
        backend_dtype: torch.dtype,
        max_chunk_size_bytes: int,
        cache_dir: str,
        max_disk_space: int,
        **kwargs,
    ):
        """
        Wrap a tensor-parallel transformer block and create its forward/backward/inference task pools.

        :param args: positional arguments forwarded to ModuleBackend.__init__
        :param block_index: stored on the instance as-is (not used elsewhere in this class)
        :param config: model config; this class reads attn_class, num_attention_heads, hidden_size
          and num_key_value_groups from it
        :param memory_cache: provides the attention cache tensors used by inference_step
        :param backend_dtype: dtype used for the inference schema and the attention cache descriptors
        :param max_chunk_size_bytes: budget used by _estimate_max_chunk_length to split long inputs
        :param cache_dir: stored on the instance as-is (not used elsewhere in this class)
        :param max_disk_space: stored on the instance as-is (not used elsewhere in this class)
        :param kwargs: keyword arguments forwarded to ModuleBackend.__init__
        """
        # NOTE(review): the parameter layout (everything after *args is keyword-only) was reconstructed
        # from the `super().__init__(*args, **kwargs)` call below -- confirm against callers.

        # Imported here rather than at module level; presumably to avoid an import cycle -- TODO confirm.
        import petals.utils.peft as _peft_module

        self._peft_module = _peft_module

        super().__init__(*args, **kwargs)
        assert isinstance(self.module, TensorParallel)
        self.block_index = block_index
        self.config = config
        self.memory_cache = memory_cache
        self.max_chunk_size_bytes = max_chunk_size_bytes
        self.cache_dir = cache_dir
        self.max_disk_space = max_disk_space

        for name, param in self.module.named_parameters():
            assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
        for name, buf in self.module.named_buffers():
            assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"

        # Reuse the batch size of the pool created by the base class, then replace the base-class
        # pools with prioritized ones that all live on the block's output device.
        max_batch_size = self.forward_pool.max_batch_size
        device = self.module.devices[self.module.output_device_index]
        self.inference_pool = PrioritizedTaskPool(
            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
        )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
        self.forward_pool = PrioritizedTaskPool(
            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
        )
        self.backward_pool = PrioritizedTaskPool(
            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
        )

        self.dtype = backend_dtype
        self.dtype_bytes = get_size_in_bytes(self.dtype)

        # Number of attention heads held by each tensor-parallel shard, in shard order.
        self.shard_num_heads = []
        for shard in self.module.module_shards:
            for submodule in shard.modules():
                if isinstance(submodule, config.attn_class):
                    self.shard_num_heads.append(submodule.num_heads)
        assert len(self.shard_num_heads) == len(self.module.devices)
        assert sum(self.shard_num_heads) == config.num_attention_heads

        # NOTE(review): nesting reconstructed as ((*args_schema, hidden-dtype scalar, int64 scalar), kwargs_schema)
        # -- confirm against the code that consumes `inference_schema`.
        self.inference_schema = (
            (
                *self.args_schema,
                BatchTensorDescriptor((), dtype=self.dtype),
                BatchTensorDescriptor((), dtype=torch.int64),
            ),
            self.kwargs_schema,
        )

        # Bytes of attention cache needed per (batch element, token) on each device.
        self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
        for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
            self.cache_bytes_per_token[descr.device] += descr.numel() * get_size_in_bytes(descr.dtype)

    def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
        """
        Create tensor descriptors for attention cache tensors used during inference_step

        :returns: a flat list [keys_0, values_0, keys_1, values_1, ...] with one (keys, values) pair per device;
          keys are [batch_size, num_kv_heads, head_dim, max_length],
          values are [batch_size, num_kv_heads, max_length, head_dim]
        """
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        cache_tensors = []
        for device, num_heads in zip(self.module.devices, self.shard_num_heads):
            # Heads that share one key/value group need only one cache entry per group.
            num_heads //= self.config.num_key_value_groups
            keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
            values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
            cache_tensors.extend((keys, values))
        return cache_tensors

    def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
        """Run the base-class forward pass; the last element of `inputs` selects the active adapter."""
        *inputs, active_adapter = inputs
        with self._peft_module.using_adapter(active_adapter):
            return super().forward(*inputs)

    def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
        """Run the base-class backward pass; the last element of `inputs` selects the active adapter."""
        *inputs, active_adapter = inputs
        with self._peft_module.using_adapter(active_adapter):
            return super().backward(*inputs)

    @torch.inference_mode()
    def inference_step(
        self,
        hidden_states: torch.Tensor,
        hypo_ids: torch.LongTensor,
        inference_info: InferenceMetadata,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Run the block on new tokens using (and updating) the attention cache.

        :param hidden_states: inputs of shape [batch_size, seq_len, hid_size]
        :param hypo_ids: if not a dummy tensor, indices used to reorder cache entries before the step
        :param inference_info: supplies cache_handles, prefix_length and active_adapter for this step
        :returns: a 1-tuple with the output hidden states
        """
        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
        seq_len = hidden_states.shape[1]

        with self.memory_cache.use_cache(
            *inference_info.cache_handles
        ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
            self._reorder_cache_inplace(cache_tensors, hypo_ids)

            # We chunk the inputs so that peak memory for long sequences fits into `autograd_memory`
            # reserved in `Server._choose_num_blocks()`. This saves us from OOMs if `max_chunk_size_bytes`
            # is at least 4-6x less than `autograd_memory`.
            max_chunk_length = self._estimate_max_chunk_length(hidden_states, inference_info)
            # The output buffer is only pre-allocated when more than one chunk is needed.
            output_hidden_states = torch.empty_like(hidden_states) if seq_len > max_chunk_length else None
            layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
            for offset in range(0, seq_len, max_chunk_length):
                hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]
                output_hidden_states_chunk, new_kvs = self.module.forward(
                    hidden_states_chunk, layer_past=layer_past, use_cache=True
                )
                if seq_len > max_chunk_length:
                    output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk
                else:
                    output_hidden_states = output_hidden_states_chunk  # saves one memcopy
                # Each chunk attends to everything processed so far.
                layer_past = new_kvs

            self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
            return (output_hidden_states,)

    def _estimate_max_chunk_length(self, hidden_states: torch.Tensor, inference_info: InferenceMetadata) -> int:
        """Return how many tokens (at least 1) may be processed at once within `max_chunk_size_bytes`."""
        # We assume that attention logit matrices are the main thing that consumes memory, given that
        # the model uses multi-query attention
        batch_size, seq_length, hidden_size = hidden_states.shape
        worst_case_length = inference_info.prefix_length + seq_length
        attn_bytes_per_token = max(self.shard_num_heads) * batch_size * self.dtype_bytes * worst_case_length
        return max(1, self.max_chunk_size_bytes // attn_bytes_per_token)

    def _reorder_cache_inplace(self, cache_tensors: Sequence[torch.Tensor], hypo_ids: torch.Tensor) -> None:
        """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
        if not is_dummy(hypo_ids):
            for cache_tensor in cache_tensors:
                cache_tensor[...] = cache_tensor[hypo_ids.to(cache_tensor.device)]  # in-place reorder cache by hypo ids

    def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:
        """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
        # cache_tensors alternates keys and values: [keys_0, values_0, keys_1, values_1, ...]
        # keys after flatten + slice:   [batch * num_kv_heads, head_dim, kv_length]
        key_cache = [keys.flatten(0, 1)[:, :, :prefix_length] for keys in cache_tensors[0::2]]
        # values after flatten + slice: [batch * num_kv_heads, kv_length, head_dim]
        value_cache = [values.flatten(0, 1)[:, :prefix_length] for values in cache_tensors[1::2]]
        layer_past = tuple(chain(*zip(key_cache, value_cache)))
        # With several shards, the tensors are wrapped in PerDeviceTensors before being passed to the module.
        return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past

    def _update_cache_inplace(
        self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int
    ) -> None:
        """Writes new key/value tensors back into cache, works in-place"""
        _batch_size_times_num_kv_heads, head_dim, new_length = new_kvs[0].shape
        # Only positions [prefix_length, new_length) are written; earlier positions are left untouched.
        for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):
            new_key = new_key.view(*cache_key.shape[:3], new_length)
            cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
        for cache_value, new_value in zip(cache_tensors[1::2], new_kvs[1::2]):
            new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim)
            cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]

    def get_pools(self) -> Sequence[PrioritizedTaskPool]:
        """Return the forward, backward and inference pools, in that order."""
        return self.forward_pool, self.backward_pool, self.inference_pool

    def get_info(self) -> Dict[str, Any]:
        """Get module parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
        return dict(super().get_info(), inference_schema=self.inference_schema)

    def shutdown(self) -> None:
        """Drop references to the task pools and release the parameter storage of the wrapped module."""
        # Break the cyclic references, otherwise TransformerBackend may be not garbage-collected
        self.forward_pool = self.backward_pool = self.inference_pool = None

        # Explicitly free the GPU memory. This is not necessary at the time this code is written,
        # but may help to avoid future issues when the module is not garbage-collected for some reasons
        dummy = torch.tensor([])
        for p in self.module.parameters():
            p.data = dummy
def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
    """
    Replace each backend's inference pool with one combined pool that runs multiple blocks in one call

    :param backends: non-empty mapping from block uid to its TransformerBackend; every backend's
      `inference_pool` attribute is overwritten with the same merged pool
    """
    assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
    first_pool = next(iter(backends.values())).inference_pool
    # NOTE(review): the pool arguments were reconstructed -- the merged pool reuses the first backend's
    # batch size and device and dispatches through _MergedInferenceStep; confirm the pool name.
    merged_pool = PrioritizedTaskPool(
        _MergedInferenceStep(backends),
        max_batch_size=first_pool.max_batch_size,
        device=first_pool.device,
        name="merged_inference",
    )
    for backend in backends.values():
        # Pools must be swapped before they are started.
        assert not backend.inference_pool.is_alive()
        backend.inference_pool = merged_pool
class _MergedInferenceStep:
    """Callable that runs inference_step of several backends back to back on the same hidden states."""

    def __init__(self, backends: Dict[ExpertUID, TransformerBackend]):
        # Mapping from block uid to the backend whose inference_step is invoked for that uid.
        self.backends = backends

    @torch.inference_mode()
    def __call__(
        self,
        hidden_states: torch.Tensor,
        hypo_ids: torch.LongTensor,
        inference_infos: Sequence[InferenceMetadata],
        *optional_prompts: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, ...]:
        """
        Pass hidden_states through the blocks listed in inference_infos, in order.

        :param hidden_states: input tensor; prompts are added to it in-place along dimension 1
        :param hypo_ids: forwarded unchanged to every backend's inference_step
        :param inference_infos: one entry per block; `uid` selects the backend
        :param optional_prompts: one entry per block; a tensor is added to the leading
          `prompt.shape[1]` positions before that block runs, None means no prompt
        :returns: a 1-tuple with the hidden states produced by the last block
        """
        assert len(inference_infos) == len(
            optional_prompts
        ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
        for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
            if optional_prompt is not None:
                hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
            (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
        return (hidden_states,)