# petals/src/petals/server/backend.py

from __future__ import annotations

from collections import Counter
from itertools import chain
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import peft
import torch
from hivemind import BatchTensorDescriptor, TensorDescriptor
from hivemind.moe.expert_uid import ExpertUID
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
from tensor_parallel import TensorParallel
from tensor_parallel.tensor_parallel import PerDeviceTensors
from transformers import PretrainedConfig

from petals.data_structures import InferenceMetadata
from petals.server.memory_cache import MemoryCache
from petals.server.task_pool import PrioritizedTaskPool
from petals.utils.misc import is_dummy

logger = get_logger(__name__)

class TransformerBackend(ModuleBackend):
    """A wrapper for a transformer block that can process requests for forward, backward and inference"""

    def __init__(
        self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs
    ):
        super().__init__(*args, **kwargs)
        assert isinstance(self.module, TensorParallel)
        self.config = config
        self.memory_cache = memory_cache
        for name, param in self.module.named_parameters():
            assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
        for name, buf in self.module.named_buffers():
            assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"

        max_batch_size = self.forward_pool.max_batch_size
        device = self.module.devices[self.module.output_device_index]
        self.inference_pool = PrioritizedTaskPool(
            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
        )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
        self.forward_pool = PrioritizedTaskPool(
            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
        )
        self.backward_pool = PrioritizedTaskPool(
            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
        )

        self.dtype = backend_dtype
        self.shard_num_heads = []
        for shard in self.module.module_shards:
            for submodule in shard.modules():
                if isinstance(submodule, config.attn_class):
                    self.shard_num_heads.append(submodule.num_heads)
        assert len(self.shard_num_heads) == len(self.module.devices)
        assert sum(self.shard_num_heads) == config.num_attention_heads

        self.inference_schema = (
            (
                *self.args_schema,
                BatchTensorDescriptor((), dtype=self.dtype),
                BatchTensorDescriptor((), dtype=torch.int64),
            ),
            self.kwargs_schema,
        )

        self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
        for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
            self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
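        # Example (hypothetical numbers): a single-shard float16 block with hidden_size=8192
        # caches keys + values of 2 * 8192 = 16384 elements per token, i.e.
        # 16384 * (16 bits // 8) = 32 KiB of attention cache per token on that device.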

    def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
        """Create tensor descriptors for attention cache tensors used during inference_step"""
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        cache_tensors = []
        for device, num_heads in zip(self.module.devices, self.shard_num_heads):
            keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
            values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
            cache_tensors.extend((keys, values))
        return cache_tensors
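
    # Layout note: the descriptors above interleave keys and values per device shard,
    # i.e. [keys_shard0, values_shard0, keys_shard1, values_shard1, ...]. Keys are stored
    # transposed as [batch, num_heads, head_dim, max_length], while values are
    # [batch, num_heads, max_length, head_dim]; _select_layer_past relies on this asymmetry.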

    def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
        *inputs, active_adapter = inputs
        if not self.load_adapter_(active_adapter):
            raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
        return super().forward(*inputs)

    def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
        *inputs, active_adapter = inputs
        if not self.load_adapter_(active_adapter):
            raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
        return super().backward(*inputs)

    @torch.inference_mode()
    def inference_step(
        self,
        hidden_states: torch.Tensor,
        hypo_ids: torch.LongTensor,
        inference_info: InferenceMetadata,
    ) -> Tuple[torch.Tensor, ...]:
        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
        if not self.load_adapter_(inference_info.active_adapter):
            raise KeyError(f"Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
            self._reorder_cache_inplace(cache_tensors, hypo_ids)
            layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
            hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
            self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
            return (hidden_states,)
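
    # Cache protocol: each step (1) reorders cache rows if the client reordered beams,
    # (2) exposes the first prefix_length cached tokens as layer_past, (3) runs the block
    # with use_cache=True, and (4) writes the newly produced keys/values back into the cache.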

    def _reorder_cache_inplace(self, cache_tensors: Sequence[torch.Tensor], hypo_ids: torch.Tensor):
        """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
        if not is_dummy(hypo_ids):
            for cache_tensor in cache_tensors:
                cache_tensor[...] = cache_tensor[hypo_ids.to(cache_tensor.device)]  # in-place reorder cache by hypo ids
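
    # Example (hypothetical values): during beam search with batch_size=3, hypo_ids = tensor([0, 0, 2])
    # makes cache rows 0 and 1 both continue from the old row 0, while row 2 keeps its own past.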

    def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:
        """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
        key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2])
        for i in range(len(key_cache)):
            key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
            value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]  # [batch * num_heads, kv_length, head_dim]
        layer_past = tuple(chain(*zip(key_cache, value_cache)))
        return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past
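
    # chain(*zip(key_cache, value_cache)) re-interleaves the per-shard lists back into
    # (keys_shard0, values_shard0, keys_shard1, values_shard1, ...). With multiple shards,
    # PerDeviceTensors keeps each shard's past on its own device rather than gathering it.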

    def _update_cache_inplace(
        self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int
    ):
        """Writes new key/value tensors back into cache, works in-place"""
        _batch_size_times_num_heads, head_dim, new_length = new_kvs[0].shape
        for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):
            new_key = new_key.view(*cache_key.shape[:3], new_length)
            cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
        for cache_value, new_value in zip(cache_tensors[1::2], new_kvs[1::2]):
            new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim)
            cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
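
    # Only positions [prefix_length, new_length) are written back: the prefix is already cached,
    # so each step copies just the keys/values of the newly generated tokens.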

    def get_pools(self) -> Sequence[PrioritizedTaskPool]:
        return self.forward_pool, self.backward_pool, self.inference_pool

    def get_info(self) -> Dict[str, Any]:
        """Get module parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
        return dict(super().get_info(), inference_schema=self.inference_schema)

    def shutdown(self):
        # Break the cyclic references, otherwise TransformerBackend may not be garbage-collected
        self.forward_pool = self.backward_pool = self.inference_pool = None

        # Explicitly free the GPU memory. This is not necessary at the time this code is written,
        # but may help to avoid future issues when the module is not garbage-collected for some reason
        dummy = torch.tensor([])
        for p in self.module.parameters():
            p.data = dummy

    def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
        """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
        adapter_was_loaded = False
        for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
            if isinstance(layer, (peft.tuners.lora.Linear, peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
                layer.active_adapter = active_adapter  # empty string for no adapter
                if active_adapter in layer.lora_A.keys():
                    adapter_was_loaded = True
        return adapter_was_loaded or not active_adapter
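
    # Usage sketch (hypothetical adapter name): backend.load_adapter_("my-lora-adapter") activates
    # that LoRA adapter on every peft-wrapped linear layer, returning True iff some layer has it;
    # backend.load_adapter_("") or backend.load_adapter_(None) deactivates adapters and returns True.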


def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
    """Replace each backend's rpc_inference pools with a combined pool that runs multiple blocks in one call"""
    assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
    first_pool = next(iter(backends.values())).inference_pool
    merged_pool = PrioritizedTaskPool(
        _MergedInferenceStep(backends),
        max_batch_size=first_pool.max_batch_size,
        device=first_pool.device,
        name="merged_inference",
    )
    for backend in backends.values():
        assert not backend.inference_pool.is_alive()
        backend.inference_pool = merged_pool
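
# Note: after merging, every backend shares one pool, so a single rpc_inference call can run
# several consecutive blocks hosted on this server without re-queueing between blocks
# (see _MergedInferenceStep below).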


class _MergedInferenceStep:
    def __init__(self, backends: Dict[ExpertUID, TransformerBackend]):
        self.backends = backends

    @torch.inference_mode()
    def __call__(
        self,
        hidden_states: torch.Tensor,
        hypo_ids: torch.LongTensor,
        inference_infos: Sequence[InferenceMetadata],
        *optional_prompts: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, ...]:
        assert len(inference_infos) == len(
            optional_prompts
        ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
        for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
            if optional_prompt is not None:
                hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
            (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
        return (hidden_states,)
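

# Note: optional_prompts supports deep prompt tuning -- each block's trainable prompt is added
# onto the first prompt_length positions of hidden_states before that block's inference_step runs.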