benchmark

yozh-dev-branch
Just Heuristic 1 year ago
parent fa5ac6e3b4
commit c2e3c13241

@@ -45,26 +45,7 @@ def load_pretrained_block(
cache_dir = DEFAULT_CACHE_DIR
block = WrappedBloomBlock(config)
state_dict = _load_state_dict(
converted_model_name_or_path,
block_index,
config,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
max_disk_space=max_disk_space,
)
if torch_dtype == "auto":
with torch.no_grad():
for name, param in block.named_parameters():
assert name in state_dict, f"{name} not in state dict"
param.data = param.data.to(state_dict[name].dtype)
else:
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
block = block.to(dtype=torch_dtype)
report = block.load_state_dict(state_dict, strict=True)
logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, DEBUG NOT ACTUAL WEIGHTS!")
return block
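For context, a minimal usage sketch of this loader as it works outside the benchmark branch (the checkpoint name, block index, and dtype below are illustrative, not taken from this commit):

import torch
from transformers import BloomConfig

# Hypothetical example: load one converted BLOOM block and cast it explicitly.
config = BloomConfig.from_pretrained("bigscience/bloom-petals")  # illustrative checkpoint name
block = load_pretrained_block(
    "bigscience/bloom-petals",   # converted_model_name_or_path
    block_index=0,
    config=config,
    torch_dtype=torch.bfloat16,  # an explicit dtype must appear in DTYPE_MAP.values(), per the assert above
)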

@@ -1,9 +1,5 @@
PUBLIC_INITIAL_PEERS = [
"/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
"/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
"/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
"/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
]
# The reachability API is currently used only when connecting to the public swarm
REACHABILITY_API_URL = "http://health.petals.ml"
REACHABILITY_API_URL = "REMOVED"
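For context, these constants are what clients and servers use to join the public swarm; a rough sketch of how the peers list is consumed (assuming it is passed straight to hivemind's DHT, which is not shown in this diff):

import hivemind

# Hypothetical: connect to the swarm via the bootstrap multiaddrs listed above.
dht = hivemind.DHT(initial_peers=PUBLIC_INITIAL_PEERS, start=True)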

@@ -6,6 +6,7 @@ from enum import Enum
from typing import Any, Dict, Tuple
from hivemind import PeerID
from hivemind.moe.expert_uid import ExpertUID
from petals.server.memory_cache import Handle
@@ -48,5 +49,6 @@ RPCInfo = Dict[str, Any]
@dataclasses.dataclass(frozen=True)
class InferenceMetadata:
uid: ExpertUID
prefix_length: int
cache_handles: Tuple[Handle, ...]
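With the new uid field, one metadata record is built per requested block rather than per request; a minimal construction sketch (the uid, prefix length, and cache handles are made-up values):

# Hypothetical values for illustration only; real Handle values come from MemoryCache allocation.
metadata = InferenceMetadata(uid="bloom.3", prefix_length=128, cache_handles=(17, 18))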

@@ -3,10 +3,11 @@ from __future__ import annotations
from collections import Counter
from itertools import chain
from typing import Any, Dict, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple
import torch
from hivemind import BatchTensorDescriptor, TensorDescriptor
from hivemind.moe.expert_uid import ExpertUID
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
from tensor_parallel import TensorParallel
@@ -15,7 +16,7 @@ from transformers import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention
from petals.data_structures import InferenceMetadata
from petals.server.memory_cache import MemoryCache
from petals.server.memory_cache import Handle, MemoryCache
from petals.server.task_pool import PrioritizedTaskPool
from petals.utils.misc import is_dummy
@@ -39,7 +40,7 @@ class TransformerBackend(ModuleBackend):
device = self.module.devices[self.module.output_device_index]
self.inference_pool = PrioritizedTaskPool(
self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
)
) # note: inference_pools may be merged later, see merge_inference_pools_inplace
self.forward_pool = PrioritizedTaskPool(
self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
)
@@ -69,6 +70,21 @@ class TransformerBackend(ModuleBackend):
for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
@staticmethod
def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
"""Replace each backend's rpc_inference pools with a combined pool that runs multiple blocks in one call"""
assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
first_pool = next(iter(backends.values())).inference_pool
merged_pool = PrioritizedTaskPool(
_MergedInferenceStep(backends),
max_batch_size=first_pool.max_batch_size,
device=first_pool.device,
name="merged_inference",
)
for backend in backends.values():
assert not backend.inference_pool.is_alive()
backend.inference_pool = merged_pool
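The intended call pattern: once each per-block pool exists but before the runtime starts them, the container merges them so a single pool serves all blocks (a sketch, assuming backends maps block uids to TransformerBackend instances, as with the blocks dict in the ModuleContainer hunk further down):

# Hypothetical: backends is Dict[ExpertUID, TransformerBackend]; pools were created but not started.
TransformerBackend.merge_inference_pools_inplace(backends)
assert len({id(b.inference_pool) for b in backends.values()}) == 1  # all blocks now share one pool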
def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
"""Create tensor descriptors for attention cache tensors used during inference_step"""
head_dim = self.config.hidden_size // self.config.n_head
@@ -79,22 +95,20 @@ class TransformerBackend(ModuleBackend):
cache_tensors.extend((keys, values))
return cache_tensors
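The descriptors above amount to one key vector and one value vector of hidden_size elements per token per block, which is what cache_bytes_per_token sums up. A back-of-the-envelope sketch with illustrative BLOOM-176B-like numbers (the config values are assumptions, not taken from this commit):

import torch

hidden_size, dtype = 14336, torch.float16                         # illustrative config
bytes_per_token = 2 * hidden_size * torch.finfo(dtype).bits // 8  # keys + values for one block
print(bytes_per_token)  # 57344, i.e. roughly 56 KiB of attention cache per token per block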
@torch.inference_mode()
def inference_step(
self,
hidden_states: torch.Tensor,
hypo_ids: torch.LongTensor,
inference_info: InferenceMetadata,
) -> Tuple[torch.Tensor, ...]:
with torch.inference_mode():
assert (
hidden_states.ndim == 3
), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
self._reorder_cache_inplace(cache_tensors, hypo_ids)
layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
return (hidden_states,)
assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
self._reorder_cache_inplace(cache_tensors, hypo_ids)
layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
return (hidden_states,)
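The refactor above only moves torch.inference_mode from a with-block to a decorator, which is behavior-preserving; a minimal illustration of the equivalence:

import torch

@torch.inference_mode()
def step(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # runs with autograd and version tracking disabled

def step_with_context(x: torch.Tensor) -> torch.Tensor:
    with torch.inference_mode():  # same effect, expressed as a context manager
        return x * 2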
def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
"""If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
@@ -139,3 +153,24 @@ class TransformerBackend(ModuleBackend):
dummy = torch.tensor([])
for p in self.module.parameters():
p.data = dummy
class _MergedInferenceStep:
def __init__(self, backends: Dict[ExpertUID, TransformerBackend]):
self.backends = backends
def __call__(
self,
hidden_states: torch.Tensor,
hypo_ids: torch.LongTensor,
inference_infos: Sequence[InferenceMetadata],
*optional_prompts: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, ...]:
assert len(inference_infos) == len(
optional_prompts
), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
if optional_prompt is not None:
hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
(hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
return (hidden_states,)
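The merged step thus expects one InferenceMetadata and one optional prompt per block, in matching order; a call sketch with hypothetical names and shapes:

# Hypothetical: three consecutive blocks, with a deep prompt only for the first one.
merged_step = _MergedInferenceStep(backends)
(hidden_states,) = merged_step(
    hidden_states,                            # [batch_size, seq_len, hid_size]
    hypo_ids,                                 # [batch_size], int64
    (info_block3, info_block4, info_block5),  # one InferenceMetadata per block
    prompt_block3, None, None,                # one optional prompt per block
)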

@@ -141,10 +141,11 @@ class TransformerConnectionHandler(ConnectionHandler):
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
if prompts is None or is_dummy(prompts) or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
if prompts is None or is_dummy(prompts):
prompts = [None] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
prompts = [prompt if not is_dummy(prompt) else None for prompt in prompts]
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
@@ -156,33 +157,26 @@ class TransformerConnectionHandler(ConnectionHandler):
f" exceeds pre-allocated maximum {max_length}"
)
# run request tensors through all requested modules, update caches
for backend, backend_cache_handles, prompt in zip(requested_backends, cache_handles, prompts):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
if hidden_states.numel() == 0:
continue # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
metadata = InferenceMetadata(prefix_length, tuple(backend_cache_handles))
assert isinstance(
hidden_states, torch.Tensor
), f"hidden states must be tensor, got {type(hidden_states)}"
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
assert isinstance(
backend.inference_pool, PrioritizedTaskPool
), "petals support only prioritized pools"
priority = self._prioritizer.prioritize(
hidden_states,
hypo_ids,
points=point_per_piece / len(requested_backends),
backend=backend,
type="inference",
)
(hidden_states,) = await backend.inference_pool.submit_task(
hidden_states, hypo_ids, metadata, priority=priority
priority = self._prioritizer.prioritize(
hidden_states,
hypo_ids,
points=point_per_piece,
requested_uids=requested_uids,
type="inference",
)
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles))
for uid, handles in zip(requested_uids, cache_handles)
)
if hidden_states.numel() == 0:
pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
else:
assert hidden_states.ndim == 3, "hidden states must be a single 3d tensor"
(hidden_states,) = await self.module_backends[requested_uids[0]].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
)
# serialize and send last layer outputs
@@ -444,7 +438,6 @@ async def _rpc_forward(
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
# Serialize the overall output
return hidden_states

@@ -453,11 +453,12 @@ class ModuleContainer(threading.Thread):
joining_announcer.stop.set()
joining_announcer.join()
TransformerBackend.merge_inference_pools_inplace(blocks)
return cls(
dht,
blocks,
throughput=throughput,
device=device,
update_period=update_period,
expiration=expiration,
**kwargs,
@@ -476,7 +477,6 @@ class ModuleContainer(threading.Thread):
request_timeout: float,
session_timeout: float,
step_timeout: float,
device: Union[str, torch.device],
start: bool,
**kwargs,
):
@@ -495,7 +495,7 @@ class ModuleContainer(threading.Thread):
)
for _ in range(num_handlers)
]
self.runtime = Runtime(self.module_backends, device=None, **kwargs)
self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
# note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved them as needed.
self.online_announcer = ModuleAnnouncerThread(
list(self.module_backends.keys()),
@@ -633,3 +633,11 @@ class ModuleAnnouncerThread(threading.Thread):
)
if self.stop.wait(self.update_period):
break
class RuntimeWithDeduplicatedPools(Runtime):
"""A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pools = tuple(set(self.pools))
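Deduplication is needed because, after merge_inference_pools_inplace, every backend reports the same merged pool, and the stock Runtime would otherwise try to start that pool once per backend; a toy illustration of the set-based fix (stand-in objects, assuming self.pools is the flat collection Runtime gathers from its backends):

# Toy stand-ins for task pools, to show the effect of tuple(set(...)) deduplication.
class _Pool:
    def __init__(self, name):
        self.name = name

merged_pool, fwd_a, fwd_b = _Pool("merged_inference"), _Pool("a_forward"), _Pool("b_forward")
pools = [merged_pool, fwd_a, merged_pool, fwd_b]  # the merged pool appears once per backend
pools = tuple(set(pools))                          # ...but is kept (and later started) only once
assert sum(p is merged_pool for p in pools) == 1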

@@ -16,4 +16,6 @@ class DummyTaskPrioritizer(TaskPrioritizerBase):
"""Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
return 0.0
if kwargs.get("type") == "inference":
return 1.0 # inference steps go first since they are more latency-sensitive
return 2.0 # forward, backward
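With this change the dummy prioritizer is no longer constant: inference requests get a smaller value than forward/backward, and the comment above indicates smaller values are served first by the prioritized pools. A usage sketch (tensor contents do not affect the result):

import torch

prioritizer = DummyTaskPrioritizer()
prioritizer.prioritize(torch.zeros(1), points=0.0, type="inference")  # -> 1.0, dequeued first
prioritizer.prioritize(torch.zeros(1), points=0.0, type="forward")    # -> 2.0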
