Merge branch 'memcache_touchup' of github.com:bigscience-workshop/petals into memcache_touchup

pull/434/head
Your Name 10 months ago
commit 458cf3339b

@ -11,7 +11,7 @@ from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs
__version__ = "2.0.1.post1"
__version__ = "2.0.1.post2"
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

@ -1,4 +1,4 @@
from petals.client.config import ClientConfig
from petals.client.inference_session import InferenceSession
from petals.client.remote_sequential import RemoteSequential
from petals.client.routing.sequence_manager import RemoteSequenceManager
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
from petals.client.routing import NoSpendingPolicy, RemoteSequenceManager, SpendingPolicyBase

@ -0,0 +1,31 @@
import dataclasses
from typing import Optional, Sequence, Union
from hivemind import PeerID
from petals.constants import PUBLIC_INITIAL_PEERS
@dataclasses.dataclass
class ClientConfig:
initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS) # a list of initial peers for hivemind DHT
dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name)
daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers
show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True]
allowed_servers: Optional[Sequence[Union[PeerID, str]]] = None # if defined, send requests only to these servers
blocked_servers: Optional[Sequence[Union[PeerID, str]]] = None # if defined, do not use these servers
use_server_to_server: bool = True # Use direct server-to-server communication
connect_timeout: float = 5 # timeout for opening a connection
request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests
update_period: float = 60 # refresh DHT information once in this many seconds
max_retries: Optional[int] = None # max number retries before the client raises an exception (default: inf)
min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
max_backoff: float = 60 # limit maximal sleep time between retries to this value
ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo)
max_pinged: int = 3 # max servers to ping from each sequence side, per update
ping_timeout: float = 2 # max time to wait for pings, per update

@ -7,22 +7,18 @@ import uuid
from typing import AsyncIterator, List, Optional, Tuple
import torch
from hivemind import (
MSGPackSerializer,
anext,
deserialize_torch_tensor,
get_logger,
nested_flatten,
serialize_torch_tensor,
)
from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P
from hivemind.proto import runtime_pb2
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig, maybe_log_traceback
from petals.client.config import ClientConfig
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
from petals.utils.packaging import pack_args_kwargs
logger = get_logger(__name__)
@ -36,7 +32,7 @@ class _ServerInferenceSession:
def __init__(
self,
config: SequenceManagerConfig,
config: ClientConfig,
span: RemoteSpanInfo,
uid: ModuleUID,
rpc_info: RPCInfo,
@ -63,7 +59,7 @@ class _ServerInferenceSession:
@classmethod
async def create(
cls,
config: SequenceManagerConfig,
config: ClientConfig,
p2p: P2P,
span: RemoteSpanInfo,
uid: ModuleUID,
@ -128,13 +124,13 @@ class _ServerInferenceSession:
assert prompts.shape[3] == inputs.shape[2]
if hypo_ids is None or is_dummy(hypo_ids):
hypo_ids = DUMMY
hypo_ids = DUMMY_INT64
else:
assert len(hypo_ids) == len(inputs)
assert hypo_ids.dtype == torch.int64
# serialize inputs and put them into the queue
input_tensors = (inputs, prompts, hypo_ids)
input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
request_metadata = dict(session_id=self.session_id, step_id=step_id)
if not self.stepped:
@ -144,13 +140,25 @@ class _ServerInferenceSession:
if next_servers:
request_metadata["next_servers"] = next_servers
request_metadata["args_structure"] = args_structure
# TODO: make possible to use different compression method for different tensors
server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
compression = server_side_inference_schema[0].compression
inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)
# TODO: create more explicit way to check servers schema and client's structure
assert len(input_tensors) >= len(
server_side_inference_schema
), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step"
outputs_serialized = RemoteExpertWorker.run_coroutine(
self._step(
runtime_pb2.ExpertRequest(
uid=self.uid,
tensors=[
serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(input_tensors, nested_flatten(self.rpc_info["inference_schema"]))
for tensor, proto in zip(input_tensors, inference_schema)
],
metadata=MSGPackSerializer.dumps(request_metadata),
)

@ -12,13 +12,14 @@ from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_U
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
from hivemind.utils.streaming import split_for_streaming
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.client.config import ClientConfig
from petals.data_structures import ModuleUID, RPCInfo
async def _forward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
@ -28,7 +29,7 @@ async def _forward_unary(
async def _backward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
@ -38,7 +39,7 @@ async def _backward_unary(
async def _forward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
@ -51,7 +52,7 @@ async def _forward_stream(
async def _backward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
@ -68,7 +69,7 @@ async def run_remote_forward(
stub: StubBase,
rpc_info: RPCInfo,
*inputs: torch.Tensor,
config: SequenceManagerConfig,
config: ClientConfig,
metadata: Optional[bytes] = None,
**kwargs,
) -> Tuple[torch.Tensor, ...]:
@ -84,26 +85,20 @@ async def run_remote_forward(
kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
forward_inputs = (inputs, kwargs)
# Modify forward_schema to support prompts
forward_inputs = tuple(nested_flatten((inputs, kwargs)))
args_schema, kwargs_schema = rpc_info["forward_schema"]
# TODO: rm this assert when support arbitrary number of input tensors
assert len(args_schema) == 1 and len(inputs) == 2
forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
if not nested_compare(forward_inputs, forward_schema_with_prompts):
raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
forward_inputs = nested_flatten(forward_inputs)
compression = args_schema[0].compression
forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs)
inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
# TODO: create more explicit way to check servers schema and client's structure
assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step"
# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
for tensor, proto in zip(inputs, forward_schema)
)
)
@ -119,10 +114,8 @@ async def run_remote_backward(
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
inputs: torch.Tensor,
grad_outputs: List[torch.Tensor],
*extra_tensors: torch.Tensor,
config: SequenceManagerConfig,
*inputs_and_grad_outputs: torch.Tensor,
config: ClientConfig,
metadata: Optional[bytes] = None,
**kwargs,
) -> Sequence[torch.Tensor]:
@ -131,16 +124,14 @@ async def run_remote_backward(
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""
grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
# TODO generalize this
prompts_schema = next(iter(args_schema))
backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
outputs_schema = rpc_info["outputs_schema"]
compression = args_schema[0].compression
backward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in inputs_and_grad_outputs)
# TODO: create more explicit way to check servers schema and client's structure
assert (
len(inputs_and_grad_outputs) >= len(args_schema) + len(outputs_schema) + 1
), "Inputs, grad_outputs and prompt tensors are necessary for a backward step"
# Asynchronous serialization
loop = asyncio.get_running_loop()

@ -6,8 +6,9 @@ import torch
from hivemind import DHT, get_logger
from torch import nn
from petals.client.config import ClientConfig
from petals.client.inference_session import InferenceSession
from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig
from petals.client.routing import RemoteSequenceManager
from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
from petals.data_structures import UID_DELIMITER
from petals.utils.misc import DUMMY
@ -22,7 +23,7 @@ class RemoteSequential(nn.Module):
def __init__(
self,
config: SequenceManagerConfig,
config: ClientConfig,
*,
sequence_manager: Optional[RemoteSequenceManager] = None,
dht: Optional[DHT] = None,

@ -1 +1,2 @@
"""Client-side functions responsible for choosing the best server, """
from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase

@ -7,7 +7,8 @@ import logging
import random
import threading
import time
from typing import Any, Collection, Dict, List, Optional, Sequence, Union
import warnings
from typing import Any, Dict, List, Optional, Sequence, Set, Union
from weakref import WeakMethod
import dijkstar
@ -18,40 +19,27 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger
import petals.dht_utils
from petals.client.config import ClientConfig
from petals.client.routing.sequence_info import RemoteSequenceInfo
from petals.client.routing.spending_policy import NoSpendingPolicy
from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
from petals.server.handler import TransformerConnectionHandler
from petals.utils.dht import get_remote_module_infos
from petals.utils.ping import PingAggregator
from petals.utils.random import sample_up_to
logger = get_logger(__name__)
@dataclasses.dataclass
class SequenceManagerConfig:
initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS) # a list of initial peers for hivemind DHT
dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name)
daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers
show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True]
allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers
use_server_to_server: bool = True # Use direct server-to-server communication
connect_timeout: float = 5 # timeout for opening a connection
request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests
update_period: float = 60 # refresh DHT information once in this many seconds
max_retries: Optional[int] = None # max number retries before the client raises an exception (default: inf)
min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
max_backoff: float = 60 # limit maximal sleep time between retries to this value
ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo)
max_pinged: int = 3 # max servers to ping from each sequence side, per update
ping_timeout: float = 2 # max time to wait for pings, per update
class SequenceManagerConfig(ClientConfig):
def __init__(self, *args, **kwargs):
warnings.warn(
"petals.client.routing.SequenceManagerConfig has been moved to petals.ClientConfig. "
"This alias will be removed in Petals 2.2.0+",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
@dataclasses.dataclass
@ -82,7 +70,7 @@ class RemoteSequenceManager:
def __init__(
self,
config: SequenceManagerConfig,
config: ClientConfig,
block_uids: Sequence[ModuleUID],
*,
dht: Optional[DHT] = None,
@ -116,6 +104,9 @@ class RemoteSequenceManager:
self._thread_start_lock = threading.Lock()
self.policy = NoSpendingPolicy()
self.allowed_servers = self._peer_ids_to_set(config.allowed_servers)
self.blocked_servers = self._peer_ids_to_set(config.blocked_servers)
self.ping_aggregator = PingAggregator(dht)
if state.banned_peers is None:
@ -128,6 +119,23 @@ class RemoteSequenceManager:
self._thread.ready.set() # no need to await the first dht fetch
self._need_latest_infos = True
@staticmethod
def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
if peer_ids is None:
return None
result = set()
for peer_id in peer_ids:
if isinstance(peer_id, PeerID):
result.add(peer_id)
elif isinstance(peer_id, str):
result.add(PeerID.from_base58(peer_id))
else:
raise TypeError(
f"`allowed_servers` and `blocked_servers` have to contain only PeerIDs or strings, but got {type(peer_id)}"
)
return result
def make_sequence(
self,
start_index: int = 0,
@ -333,7 +341,7 @@ class RemoteSequenceManager:
def _update(self):
"""Perform an immediate and synchronous refresh, may take time"""
new_block_infos = petals.dht_utils.get_remote_module_infos(
new_block_infos = get_remote_module_infos(
self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True
)
@ -341,13 +349,13 @@ class RemoteSequenceManager:
if not block_info:
continue
# Apply whitelist, if defined
if self.config.allowed_servers is not None:
block_info.servers = {
peer_id: server_info
for peer_id, server_info in block_info.servers.items()
if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers
}
# Apply allow and block lists
block_info.servers = {
peer_id: server_info
for peer_id, server_info in block_info.servers.items()
if (self.allowed_servers is None or peer_id in self.allowed_servers)
and (self.blocked_servers is None or peer_id not in self.blocked_servers)
}
# Remove temporarily banned peers, unless there are no peers left
valid_servers = {
@ -466,14 +474,21 @@ class RemoteSequenceManager:
return 0
return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
def get_request_metadata(
self, protocol: str, args_structure: Any = None, *args, **kwargs
) -> Optional[Dict[str, Any]]:
"""
:param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
:param args_structure: the structure of flattened tensors from pack_args_kwargs in petals.utils.packaging
:param args: request-specific inputs, typically block uids and input tensors
:param kwargs: additional request context, such as remote peer ID
:returns: msgpack-serialized metadata dict that will be passed alongside a given request
"""
return dict(points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter)
return dict(
points=self.policy.get_points(protocol, *args, **kwargs),
active_adapter=self.config.active_adapter,
args_structure=args_structure,
)
def shutdown(self):
self._thread.shutdown()

@ -12,10 +12,11 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.utils.logging import get_logger
from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import pack_args_kwargs
logger = get_logger(__name__)
@ -67,15 +68,17 @@ async def sequential_forward(
span = sequences.popleft()
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
inputs_and_prompts = [inputs, prompts[span.start : span.end]]
flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end])
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
metadata = sequence_manager.get_request_metadata(
"rpc_forward", args_structure, span_uids, *flat_tensors
)
(outputs,) = await run_remote_forward(
span_uids,
stub,
sequence_manager.rpc_info,
*inputs_and_prompts,
*flat_tensors,
config=sequence_manager.config,
metadata=MSGPackSerializer.dumps(metadata),
)
@ -149,18 +152,21 @@ async def sequential_backward(
inputs = intermediate_inputs.pop()
span = forward_sequences.pop()
grad_outputs_cpu = [grad.cpu() for grad in grad_outputs]
flat_tensors, args_structure = pack_args_kwargs(
inputs, *grad_outputs_cpu, prompts[span.start : span.end]
)
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
metadata = sequence_manager.get_request_metadata(
"rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
"rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id
)
grad_outputs, *span_grad_prompts = await run_remote_backward(
span_uids,
stub,
sequence_manager.rpc_info,
inputs,
grad_outputs,
prompts[span.start : span.end],
*flat_tensors,
config=sequence_manager.config,
metadata=MSGPackSerializer.dumps(metadata),
)

@ -6,8 +6,6 @@ import pydantic
from hivemind import PeerID
from hivemind.moe.expert_uid import ExpertUID
from petals.server.memory_cache import Handle
ModuleUID = str
UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
@ -78,6 +76,8 @@ class RemoteSpanInfo:
RPCInfo = Dict[str, Any]
Handle = int
@dataclasses.dataclass(frozen=True)
class InferenceMetadata:

@ -1,124 +1,9 @@
"""
Utilities for declaring and retrieving active model layers using a shared DHT.
"""
from __future__ import annotations
import warnings
import math
from functools import partial
from typing import Dict, List, Optional, Sequence, Union
warnings.warn(
"petals.dht_utils has been moved to petals.utils.dht. This alias will be removed in Petals 2.2.0+",
DeprecationWarning,
stacklevel=2,
)
from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
logger = get_logger(__name__)
def declare_active_modules(
dht: DHT,
uids: Sequence[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
"""
Declare that your node serves the specified modules; update timestamps if declared previously
:param uids: a list of module ids to declare
:param wait: if True, awaits for declaration to finish, otherwise runs in background
:param throughput: specify your performance in terms of compute throughput
:param expiration_time: declared modules will be visible for this many seconds
:returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
"""
if isinstance(uids, str):
uids = [uids]
if not isinstance(uids, list):
uids = list(uids)
for uid in uids:
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
return dht.run_coroutine(
partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
return_future=not wait,
)
async def _declare_active_modules(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
) -> Dict[ModuleUID, bool]:
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
return await node.store_many(
keys=uids,
subkeys=[dht.peer_id.to_base58()] * len(uids),
values=[server_info.to_tuple()] * len(uids),
expiration_time=expiration_time,
num_workers=num_workers,
)
def get_remote_module_infos(
dht: DHT,
uids: Sequence[ModuleUID],
expiration_time: Optional[DHTExpiration] = None,
active_adapter: Optional[str] = None,
*,
latest: bool = False,
return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
return dht.run_coroutine(
partial(
_get_remote_module_infos,
uids=uids,
active_adapter=active_adapter,
expiration_time=expiration_time,
latest=latest,
),
return_future=return_future,
)
async def _get_remote_module_infos(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
active_adapter: Optional[str],
expiration_time: Optional[DHTExpiration],
latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
if latest:
assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
expiration_time = math.inf
elif expiration_time is None:
expiration_time = get_dht_time()
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
for i, uid in enumerate(uids):
metadata = found[uid]
if metadata is None or not isinstance(metadata.value, dict):
if metadata is not None:
logger.warning(f"Incorrect metadata for {uid}: {metadata}")
continue
servers = {}
for peer_id, server_info in metadata.value.items():
try:
peer_id = PeerID.from_base58(peer_id)
server_info = ServerInfo.from_tuple(server_info.value)
if active_adapter and active_adapter not in server_info.adapters:
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue
servers[peer_id] = server_info
except (TypeError, ValueError) as e:
logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
if servers:
modules[i] = RemoteModuleInfo(uid, servers)
return modules
from petals.utils.dht import *

@ -5,15 +5,15 @@ from hivemind import get_logger
from transformers.models.bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention
from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.bloom.block import WrappedBloomBlock
logger = get_logger(__name__)
class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfig):
block_class = WrappedBloomBlock
attn_class = BloomAttention
block_prefix = "h"

@ -5,15 +5,15 @@ from hivemind import get_logger
from transformers.models.llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.llama.block import WrappedLlamaBlock
logger = get_logger(__name__)
class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfig):
block_class = WrappedLlamaBlock
attn_class = LlamaAttention
block_prefix = "model.layers"

@ -3,20 +3,30 @@ This module implements server-side computations on served blocks: forward, backw
"""
from __future__ import annotations
from typing import AsyncIterator, Optional, Sequence, Tuple, Union
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
import torch
from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.expert_uid import ExpertUID
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger
from hivemind.utils.nested import nested_flatten
from petals.data_structures import InferenceMetadata
from petals.data_structures import Handle, InferenceMetadata
from petals.server.backend import TransformerBackend
from petals.server.memory_cache import Handle
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import TaskPrioritizerBase
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import unpack_args_kwargs
# We prioritize short inference requests and make them use a *merged* inference pool,
# so they are processed without interruptions and extra overheads
# TODO: Increase the NF4 threshold once bitsandbytes ships efficient NF4 kernel for parallel forward
MAX_SHORT_INFERENCE_TOKENS = 128
MAX_NF4_SHORT_INFERENCE_TOKENS = 1
logger = get_logger(__name__)
async def run_rpc_forward(
@ -25,6 +35,7 @@ async def run_rpc_forward(
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
args_structure: Any = None,
) -> torch.Tensor:
"""
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
@ -34,7 +45,11 @@ async def run_rpc_forward(
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
hidden_states, prompts = flat_tensors
if args_structure is not None:
# TODO: kwargs currently is unused, it can be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
hidden_states, prompts, *_ = flat_tensors
dtype = requested_backends[0].dtype
# check parse input tensors and cast dtypes
hidden_states = hidden_states.to(dtype)
@ -72,8 +87,13 @@ async def run_rpc_backward(
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
args_structure: Any = None,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
inputs, grad_outputs, prompts = flat_tensors
if args_structure is not None:
# TODO: kwargs currently is unused, it can be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
inputs, grad_outputs, prompts, *_ = flat_tensors
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
@ -127,9 +147,12 @@ async def iterate_rpc_inference(
active_adapter: Optional[str],
input_iterator: AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]],
cache_handles: Sequence[Sequence[Handle]],
*,
max_length: int,
prioritizer: TaskPrioritizerBase,
points: int,
quant_type: QuantType,
args_structure: Any = None,
) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
assert len(cache_handles) == len(requested_backends)
@ -137,7 +160,13 @@ async def iterate_rpc_inference(
point_per_piece = points / max_length if max_length > 0 else 0.0
async for request, step_metadata in input_iterator:
hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
if args_structure is not None:
# TODO: kwargs currently is unused, it can be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
hidden_states, prompts, hypo_ids, *_ = flat_tensors
batch_size, length_increment, _ = hidden_states.shape
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
@ -154,34 +183,40 @@ async def iterate_rpc_inference(
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
)
merge_max_tokens = MAX_NF4_SHORT_INFERENCE_TOKENS if quant_type == QuantType.NF4 else MAX_SHORT_INFERENCE_TOKENS
can_merge_pools = batch_size * length_increment <= merge_max_tokens
priority = prioritizer.prioritize(
hidden_states,
hypo_ids,
points=point_per_piece,
requested_uids=requested_uids,
type="inference",
type="short_inference" if can_merge_pools else "inference",
)
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
for uid, handles in zip(requested_uids, cache_handles)
)
if hidden_states.numel() == 0:
pass # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
else:
# A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache.
if hidden_states.numel() > 0:
assert hidden_states.ndim == 3, f"hidden states must be a single 3d tensor"
(hidden_states,) = await requested_backends[0].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
)
if can_merge_pools:
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
for uid, handles in zip(requested_uids, cache_handles)
)
(hidden_states,) = await requested_backends[0].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
)
else:
for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
(hidden_states,) = await backend.inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, prompt, priority=priority
)
# serialize and send last layer outputs
output_tensors = [

@ -29,11 +29,11 @@ from hivemind.utils.logging import get_logger
from hivemind.utils.streaming import split_for_streaming
import petals
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, Handle, ModuleUID
from petals.server.backend import TransformerBackend
from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
from petals.server.memory_cache import Handle
from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
from petals.utils.convert_block import QuantType
logger = get_logger(__name__)
@ -71,6 +71,7 @@ class TransformerConnectionHandler(ConnectionHandler):
session_timeout: float,
step_timeout: float,
task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
quant_type: QuantType,
):
super().__init__(dht, module_backends)
for module_backend in self.module_backends.values():
@ -88,6 +89,7 @@ class TransformerConnectionHandler(ConnectionHandler):
self.request_timeout = request_timeout
self.session_timeout, self.step_timeout = session_timeout, step_timeout
self._prioritizer = task_prioritizer
self.quant_type = quant_type
async def add_p2p_handlers(self, *args, **kwargs) -> None:
if self._listener_task is None:
@ -149,6 +151,7 @@ class TransformerConnectionHandler(ConnectionHandler):
points = metadata.get("points", 0)
session_id = metadata.get("session_id")
alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
args_structure = metadata.get("args_structure")
if not requested_uids:
raise ValueError("User must specify at least one block for inference, but got none")
assert isinstance(
@ -179,6 +182,8 @@ class TransformerConnectionHandler(ConnectionHandler):
max_length=max_length,
prioritizer=self._prioritizer,
points=points,
quant_type=self.quant_type,
args_structure=args_structure,
):
if can_push:
task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
@ -355,6 +360,7 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_forward should have number of points as number or None, got {points}"
@ -365,6 +371,7 @@ class TransformerConnectionHandler(ConnectionHandler):
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
return runtime_pb2.ExpertResponse(
tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
@ -382,6 +389,7 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_forward_stream should have number of points as number or None, got {points}"
@ -392,6 +400,7 @@ class TransformerConnectionHandler(ConnectionHandler):
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
# Split the serialized_output for streaming and respond to client
@ -433,6 +442,7 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_backward should have number of points as number or None, got {points}"
@ -443,6 +453,7 @@ class TransformerConnectionHandler(ConnectionHandler):
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
@ -458,6 +469,7 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_backward_stream should have number of points as number or None, got {points}"
@ -468,6 +480,7 @@ class TransformerConnectionHandler(ConnectionHandler):
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
# Split the serialized_grad_inputs for streaming and respond
for tensor in self._serialize_grads(grads, requested_backends, metadata):

@ -16,13 +16,12 @@ import async_timeout
import torch
from hivemind.utils import TensorDescriptor, anext, enter_asynchronously, get_logger
from petals.data_structures import Handle
from petals.utils.asyncio import shield_and_wait
from petals.utils.misc import get_size_in_bytes
logger = get_logger(__name__)
Handle = int
class MemoryCache:
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
@ -102,7 +101,7 @@ class MemoryCache:
alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))
try:
handles = await shield_and_wait(alloc_task)
logger.info(f"rpc_inference.alloc-done(size={max_alloc_size / gib:.2f} GiB)")
logger.info(f"rpc_inference.alloc_done(size={max_alloc_size / gib:.2f} GiB)")
yield handles
finally:
self._free(max_alloc_size, alloc_task)

@ -20,7 +20,6 @@ from transformers import PretrainedConfig
import petals
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
from petals.dht_utils import declare_active_modules, get_remote_module_infos
from petals.server import block_selection
from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
from petals.server.block_utils import get_block_size, resolve_block_dtype
@ -31,6 +30,7 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
from petals.server.throughput import get_dtype_name, get_server_throughput
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, check_device_balance, convert_block
from petals.utils.dht import declare_active_modules, get_remote_module_infos
from petals.utils.misc import get_size_in_bytes
from petals.utils.ping import PingAggregator
from petals.utils.random import sample_up_to
@ -560,6 +560,7 @@ class ModuleContainer(threading.Thread):
request_timeout=request_timeout,
session_timeout=session_timeout,
step_timeout=step_timeout,
quant_type=QuantType[server_info.quant_type.upper()],
)
for i in range(num_handlers)
]

@ -13,9 +13,10 @@ class TaskPrioritizerBase(ABC):
class DummyTaskPrioritizer(TaskPrioritizerBase):
"""Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
# Inference steps (especially short ones) go first since they are more latency-sensitive
if kwargs.get("type") == "short_inference":
return 1.0
if kwargs.get("type") == "inference":
return 1.0 # inference steps go first since they are more latency-sensitive
return 2.0 # forward, backward
return 2.0
return 3.0 # Forward, backward

@ -4,3 +4,4 @@ from petals.utils.auto_config import (
AutoDistributedModelForCausalLM,
AutoDistributedModelForSequenceClassification,
)
from petals.utils.dht import declare_active_modules, get_remote_module_infos

@ -0,0 +1,124 @@
"""
Utilities for declaring and retrieving active model layers using a shared DHT.
"""
from __future__ import annotations
import math
from functools import partial
from typing import Dict, List, Optional, Sequence, Union
from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
logger = get_logger(__name__)
def declare_active_modules(
dht: DHT,
uids: Sequence[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
"""
Declare that your node serves the specified modules; update timestamps if declared previously
:param uids: a list of module ids to declare
:param wait: if True, awaits for declaration to finish, otherwise runs in background
:param throughput: specify your performance in terms of compute throughput
:param expiration_time: declared modules will be visible for this many seconds
:returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
"""
if isinstance(uids, str):
uids = [uids]
if not isinstance(uids, list):
uids = list(uids)
for uid in uids:
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
return dht.run_coroutine(
partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
return_future=not wait,
)
async def _declare_active_modules(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
) -> Dict[ModuleUID, bool]:
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
return await node.store_many(
keys=uids,
subkeys=[dht.peer_id.to_base58()] * len(uids),
values=[server_info.to_tuple()] * len(uids),
expiration_time=expiration_time,
num_workers=num_workers,
)
def get_remote_module_infos(
dht: DHT,
uids: Sequence[ModuleUID],
expiration_time: Optional[DHTExpiration] = None,
active_adapter: Optional[str] = None,
*,
latest: bool = False,
return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
return dht.run_coroutine(
partial(
_get_remote_module_infos,
uids=uids,
active_adapter=active_adapter,
expiration_time=expiration_time,
latest=latest,
),
return_future=return_future,
)
async def _get_remote_module_infos(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
active_adapter: Optional[str],
expiration_time: Optional[DHTExpiration],
latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
if latest:
assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
expiration_time = math.inf
elif expiration_time is None:
expiration_time = get_dht_time()
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
for i, uid in enumerate(uids):
metadata = found[uid]
if metadata is None or not isinstance(metadata.value, dict):
if metadata is not None:
logger.warning(f"Incorrect metadata for {uid}: {metadata}")
continue
servers = {}
for peer_id, server_info in metadata.value.items():
try:
peer_id = PeerID.from_base58(peer_id)
server_info = ServerInfo.from_tuple(server_info.value)
if active_adapter and active_adapter not in server_info.adapters:
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue
servers[peer_id] = server_info
except (TypeError, ValueError) as e:
logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
if servers:
modules[i] = RemoteModuleInfo(uid, servers)
return modules

@ -2,6 +2,8 @@ import torch
DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter parameters
DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
def is_dummy(tensor: torch.Tensor):
return tensor.numel() == 0

@ -0,0 +1,49 @@
from typing import Any, Dict, List, Tuple
import torch
from hivemind import nested_flatten, nested_pack
# TODO: Move functions to hivemind
def _mark_masked_tensor(index: int) -> bytes:
return b"__T" + str(index).encode()
def _is_masked_tensor(item: Any) -> bool:
return isinstance(item, bytes) and item.startswith(b"__T")
def _get_tensor_index(item: bytes) -> int:
return int(item[3:])
def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
"""
Check the function's arguments and pack all tensors into different flattened lists.
:returns: a flattened list of tensors and args and kwargs, where tensors were masked
"""
masked_flat_values, flat_tensors, tensor_to_index = [], [], {}
for value in nested_flatten((args, kwargs)):
if isinstance(value, torch.Tensor):
tensor_index = tensor_to_index.setdefault(value, len(flat_tensors))
if tensor_index == len(flat_tensors):
flat_tensors.append(value)
masked_flat_values.append(_mark_masked_tensor(tensor_index))
else:
masked_flat_values.append(value)
return flat_tensors, nested_pack(masked_flat_values, (args, kwargs))
def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any):
"""
Restore arguments after `pack_args_kwargs` function.
:returns: list of args and dict of kwargs
"""
return nested_pack(
(
value if not _is_masked_tensor(value) else flat_tensors[_get_tensor_index(value)]
for value in nested_flatten(args_structure)
),
args_structure,
)

@ -3,10 +3,13 @@ import sys
import pytest
import torch
from hivemind import nested_compare, nested_flatten
from petals import AutoDistributedConfig
from petals.server.throughput import measure_compute_rps
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs
from test_utils import MODEL_NAME
@ -44,3 +47,29 @@ def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: boo
inference=inference,
)
assert isinstance(compute_rps, float) and compute_rps > 0
@pytest.mark.forked
def test_pack_inputs():
x = torch.ones(3)
y = torch.arange(5)
z = DUMMY
args = (x, z, None, (y, y), z)
kwargs = dict(foo=torch.zeros(1, 1), bar={"l": "i", "g": "h", "t": ("y", "e", "a", "r", torch.rand(1), x, y)})
flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)
assert len(flat_tensors) == 5
assert all(isinstance(t, torch.Tensor) for t in flat_tensors)
restored_args, restored_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
assert len(restored_args) == len(args)
assert torch.all(restored_args[0] == x).item() and restored_args[2] is None
assert nested_compare((args, kwargs), (restored_args, restored_kwargs))
for original, restored in zip(nested_flatten((args, kwargs)), nested_flatten((restored_args, restored_kwargs))):
if isinstance(original, torch.Tensor):
assert torch.all(original == restored)
else:
assert original == restored

@ -4,6 +4,7 @@ import pytest
import torch
from petals import AutoDistributedConfig, RemoteSequential
from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *
@ -13,26 +14,30 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
remote_sequential = RemoteSequential(config)
for block_index in random.sample(range(config.num_hidden_layers), 3):
remote_block = remote_sequential[block_index]
block_index = random.randint(0, config.num_hidden_layers - 1)
remote_block = remote_sequential[block_index]
inputs = torch.randn(1, 8, config.hidden_size)
outputs_forward = remote_block(inputs)
inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS + 8, config.hidden_size)
outputs_forward = remote_block(inputs)
outputs_inference = []
with torch.inference_mode():
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
for i in range(inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = []
with torch.inference_mode():
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
# Test long inference (unmerged inference pools)
outputs_inference.append(sess.step(inputs[:, : MAX_SHORT_INFERENCE_TOKENS + 1, :]))
# test that max length is respected
with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
sess.step(inputs[:, -1:, :])
assert "Maximum length exceeded" in repr(exc_info.value)
outputs_inference = torch.cat(outputs_inference, dim=1)
# Test short inference (merged inference pools)
for i in range(MAX_SHORT_INFERENCE_TOKENS + 1, inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
(outputs_local,) = ref_block(inputs)
# test that max length is respected
with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
sess.step(inputs[:, -1:, :])
assert "Maximum length exceeded" in repr(exc_info.value)
outputs_inference = torch.cat(outputs_inference, dim=1)
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
(outputs_local,) = ref_block(inputs)
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)

@ -40,10 +40,10 @@ def test_remote_sequential():
assert hidden.shape == test_inputs.shape
assert hidden.requires_grad
second_half_outputs = second_half(hidden)
assert torch.allclose(second_half_outputs, full_outputs, atol=3e-4)
assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3)
(second_half_outputs * grad_proj).sum().backward()
assert torch.allclose(test_inputs.grad, full_grad, atol=1e-2)
assert torch.allclose(test_inputs.grad, full_grad, atol=3e-2)
# test RemoteSequential with lossy compression
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]

Loading…
Cancel
Save