@@ -3,10 +3,11 @@ from __future__ import annotations
 
 from collections import Counter
 from itertools import chain
-from typing import Any, Dict, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple
 
 import torch
 from hivemind import BatchTensorDescriptor, TensorDescriptor
+from hivemind.moe.expert_uid import ExpertUID
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 from tensor_parallel import TensorParallel
@@ -15,7 +16,7 @@ from transformers import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomAttention
 
 from petals.data_structures import InferenceMetadata
-from petals.server.memory_cache import MemoryCache
+from petals.server.memory_cache import Handle, MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
 
@@ -39,7 +40,7 @@ class TransformerBackend(ModuleBackend):
         device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
             self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
-        )
+        )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
             self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
         )
@@ -79,22 +80,20 @@ class TransformerBackend(ModuleBackend):
             cache_tensors.extend((keys, values))
         return cache_tensors
 
+    @torch.inference_mode()
     def inference_step(
         self,
         hidden_states: torch.Tensor,
         hypo_ids: torch.LongTensor,
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
-        with torch.inference_mode():
-            assert (
-                hidden_states.ndim == 3
-            ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-            with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
-                self._reorder_cache_inplace(cache_tensors, hypo_ids)
-                layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
-                hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
-                self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
-                return (hidden_states,)
+        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
+            self._reorder_cache_inplace(cache_tensors, hypo_ids)
+            layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
+            hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+            self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
+            return (hidden_states,)
 
     def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
         """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
@@ -139,3 +138,39 @@ class TransformerBackend(ModuleBackend):
         dummy = torch.tensor([])
         for p in self.module.parameters():
             p.data = dummy
+
+
+def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
+    """Replace each backend's rpc_inference pool with a combined pool that runs multiple blocks in one call"""
+    assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
+    first_pool = next(iter(backends.values())).inference_pool
+    merged_pool = PrioritizedTaskPool(
+        _MergedInferenceStep(backends),
+        max_batch_size=first_pool.max_batch_size,
+        device=first_pool.device,
+        name="merged_inference",
+    )
+    for backend in backends.values():
+        assert not backend.inference_pool.is_alive()
+        backend.inference_pool = merged_pool
+
+
+class _MergedInferenceStep:
+    def __init__(self, backends: Dict[ExpertUID, TransformerBackend]):
+        self.backends = backends
+
+    def __call__(
+        self,
+        hidden_states: torch.Tensor,
+        hypo_ids: torch.LongTensor,
+        inference_infos: Sequence[InferenceMetadata],
+        *optional_prompts: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, ...]:
+        assert len(inference_infos) == len(
+            optional_prompts
+        ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
+        for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
+            if optional_prompt is not None:
+                hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
+            (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
+        return (hidden_states,)
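
Usage sketch (not part of the patch above): merge_inference_pools_inplace is meant to be called on a server's backends after they are constructed but before their pools start, so that one rpc_inference call can run several consecutive blocks. In the sketch below, `make_backend` and the block UIDs are hypothetical stand-ins; everything else uses only names introduced in this diff.

    # Hypothetical setup: `make_backend` stands in for however the server
    # builds its TransformerBackend instances; the UIDs are made up.
    backends = {
        "block.0": make_backend(0),
        "block.1": make_backend(1),
    }

    # Must run before the pools are started: merge_inference_pools_inplace
    # asserts that no backend.inference_pool is_alive() yet.
    merge_inference_pools_inplace(backends)

    # Every backend now shares one PrioritizedTaskPool whose step function,
    # _MergedInferenceStep, adds each optional prompt to the hidden states
    # and chains inference_step across the requested blocks in one call.
    first = next(iter(backends.values()))
    assert all(b.inference_pool is first.inference_pool for b in backends.values())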