Pull request #467 (merge): commit 875dc81c4c by justheuristic, committed via GitHub 6 months ago

@ -24,6 +24,9 @@ if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
assert (
version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
assert version.parse("1.1.10") <= version.parse(
hivemind.__version__
), "Please install a proper hivemind version: pip install hivemind>=1.1.10"
def _override_bfloat16_mode_default():
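Both dependency checks above are skipped when the PETALS_IGNORE_DEPENDENCY_VERSION environment variable is set. A minimal sketch of opting out (use at your own risk, e.g. to test an unsupported transformers build):

    import os
    os.environ["PETALS_IGNORE_DEPENDENCY_VERSION"] = "1"  # must be set before the first petals import
    import petals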

@ -4,18 +4,16 @@ import asyncio
import itertools
import time
import uuid
from typing import AsyncIterator, List, Optional, Tuple
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple
import torch
from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P
from hivemind.proto import runtime_pb2
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from hivemind.utils import MSGPackSerializer, anext, get_logger, nested_flatten
from petals.client.config import ClientConfig
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
from petals.utils.packaging import pack_args_kwargs
@ -32,23 +30,21 @@ class _ServerInferenceSession:
def __init__(
self,
config: ClientConfig,
sequence_manager: RemoteSequenceManager,
span: RemoteSpanInfo,
uid: ModuleUID,
rpc_info: RPCInfo,
span_uids: Sequence[ModuleUID],
inputs_queue: asyncio.Queue,
outputs_aiter: AsyncIterator,
*,
*block_kwargs,
max_length: int,
**metadata,
):
self.config = config
self.span, self.uid, self.rpc_info = span, uid, rpc_info
self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
self.sequence_manager = sequence_manager
self.span, self.span_uids = span, span_uids
self.num_blocks = len(span_uids)
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self.session_id = str(uuid.uuid4())
self.session_metadata = dict(max_length=max_length, **metadata)
self.max_length = max_length
self.stepped = False
self.closed = False
@ -56,24 +52,26 @@ class _ServerInferenceSession:
self.history = None # Used in case of server failures to regenerate attention caches on new servers
self.next_session = None
self.block_kwargs = block_kwargs
assert len(self.block_kwargs) in (0, self.num_blocks)
@classmethod
async def create(
cls,
config: ClientConfig,
p2p: P2P,
sequence_manager: RemoteSequenceManager,
span: RemoteSpanInfo,
uid: ModuleUID,
rpc_info: RPCInfo,
**metadata,
span_uids: Sequence[ModuleUID],
*block_kwargs: Dict[str, Any],
**kwargs,
) -> _ServerInferenceSession:
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
inputs_queue = asyncio.Queue()
outputs_stream = await asyncio.wait_for(
stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
config.connect_timeout,
sequence_manager.config.connect_timeout,
)
return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)
return cls(sequence_manager, span, span_uids, inputs_queue, outputs_stream, *block_kwargs, **kwargs)
@staticmethod
async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@ -84,11 +82,16 @@ class _ServerInferenceSession:
break # this message means "done sending"
def step(
self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
self,
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
*,
hypo_ids: Optional[torch.Tensor] = None,
step_id: str,
) -> torch.Tensor:
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
:prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
:param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
"""
if self.closed:
@ -106,41 +109,70 @@ class _ServerInferenceSession:
if not self.stepped:
inputs = self.history # Pass full inputs including prefix
block_kwargs = self.block_kwargs
else:
inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further
block_kwargs = []
# serialize inputs and put them into the queue
input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
assert prompts.shape[0] == self.num_blocks
assert prompts.shape[1] in (inputs.shape[0], 1)
assert prompts.shape[2] <= inputs.shape[1]
assert prompts.shape[3] == inputs.shape[2]
request_metadata = dict(session_id=self.session_id, step_id=step_id)
if not self.stepped:
request_metadata.update(self.session_metadata)
elif self.config.use_server_to_server:
if hypo_ids is None or is_dummy(hypo_ids):
hypo_ids = DUMMY_INT64
else:
assert len(hypo_ids) == len(inputs)
assert hypo_ids.dtype == torch.int64
metadata = dict(session_id=self.session_id, step_id=step_id, max_length=self.max_length)
metadata.update(
self.sequence_manager.get_request_metadata(
self.span.peer_id,
"rpc_inference",
self.span_uids,
inputs,
prompts,
*block_kwargs,
max_length=self.max_length,
session_id=self.session_id,
step_id=step_id,
)
)
if self.stepped and self.sequence_manager.config.use_server_to_server:
next_servers = self._collect_next_servers()
if next_servers:
request_metadata["next_servers"] = next_servers
metadata["next_servers"] = next_servers
request_metadata["args_structure"] = args_structure
codecs = self.sequence_manager.get_compression_codecs(
self.span.peer_id, "rpc_inference", self.span_uids, inputs, prompts, *block_kwargs
)
# TODO: make possible to use different compression method for different tensors
server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
compression = server_side_inference_schema[0].compression
inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)
# serialize inputs and put them into the queue
input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
args_structure = metadata.setdefault("args_structure", args_structure)
# TODO: create more explicit way to check servers schema and client's structure
assert len(input_tensors) >= len(
server_side_inference_schema
), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step"
if codecs is None:
codecs = [runtime_pb2.CompressionType.NONE] * len(input_tensors)
else:
codecs = list(nested_flatten(codecs))
assert len(codecs) == len(
input_tensors
), f"got {len(input_tensors)} tensors but {len(codecs)} compression codecs"
outputs_serialized = RemoteExpertWorker.run_coroutine(
self._step(
runtime_pb2.ExpertRequest(
uid=self.uid,
uid=CHAIN_DELIMITER.join(self.span_uids),
tensors=[
serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(input_tensors, inference_schema)
serialize_torch_tensor(tensor, compression)
for tensor, compression in zip(input_tensors, codecs)
],
metadata=MSGPackSerializer.dumps(request_metadata),
metadata=MSGPackSerializer.dumps(metadata),
)
)
)
@ -167,7 +199,7 @@ class _ServerInferenceSession:
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
await self._inputs_queue.put(inputs_serialized)
self.stepped = True
return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout)
return await asyncio.wait_for(anext(self._outputs_stream), self.sequence_manager.config.request_timeout)
def close(self):
"""Finish a given inference session, close the underlying connection"""
@ -204,7 +236,7 @@ class InferenceSession:
An interface to a multi-step *inference* session for a sequence of remote transformer blocks
"""
def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int, *block_kwargs: Dict[str, Any]):
self._sequence_manager = sequence_manager
self._closed = False
self._server_sessions = []
@ -212,6 +244,12 @@ class InferenceSession:
self._max_length = max_length
self.output_ids = None
num_blocks = len(self._sequence_manager)
if len(block_kwargs) == 1:
block_kwargs = block_kwargs * num_blocks
assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
self.block_kwargs = block_kwargs
@property
def num_blocks(self) -> int:
return len(self._sequence_manager)
@ -224,17 +262,13 @@ class InferenceSession:
server_sessions = []
try:
for span in chosen_spans:
span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
session = RemoteExpertWorker.run_coroutine(
_ServerInferenceSession.create(
self._sequence_manager.config,
self._sequence_manager.state.p2p,
self._sequence_manager,
span,
span_uids,
rpc_info=self._sequence_manager.rpc_info,
self._sequence_manager.block_uids[span.start : span.end],
*self.block_kwargs[span.start : span.end],
max_length=self._max_length,
**metadata,
)
)
server_sessions.append(session)
@ -256,8 +290,12 @@ class InferenceSession:
return self
def step(
self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
self,
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert not self._closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@ -302,7 +340,10 @@ class InferenceSession:
server_session = self._server_sessions[server_idx]
inputs = server_session.step(
inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
inputs,
prompts[server_session.span.start : server_session.span.end],
hypo_ids=hypo_ids,
step_id=step_id,
)
server_idx += 1
@ -328,7 +369,7 @@ class InferenceSession:
outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
return outputs
def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:
def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int):
# If there is a failed server session, this code closes it
self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
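With this change, a client can supply keyword arguments for the remote blocks when opening an inference session. A non-running usage sketch, assuming an already initialized RemoteSequenceManager (sm is a placeholder) and an illustrative hidden size of 4096; per the constructor above, a single kwargs dict is broadcast to every block:

    import torch
    from petals.client.inference_session import InferenceSession

    sm = ...  # an initialized RemoteSequenceManager for the target model (placeholder)
    shared_kwargs = {}  # one dict is broadcast to all blocks; pass num_blocks dicts for per-block values
    with InferenceSession(sm, 128, shared_kwargs) as session:  # max_length=128
        hidden = torch.randn(1, 1, 4096)  # [batch_size, new_tokens, hidden_size], model-dependent
        hidden = session.step(hidden)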

@ -2,20 +2,22 @@
Utility functions that call RPC forward or backward on a single remote server
"""
import asyncio
from typing import Iterable, List, Optional, Sequence, Tuple
from typing import Iterable, List, Sequence, Tuple
import torch
from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
from hivemind import PeerID, nested_flatten, serialize_torch_tensor
from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
from hivemind.p2p import StubBase
from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
from hivemind.utils.serializer import MSGPackSerializer
from hivemind.utils.streaming import split_for_streaming
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from petals.client.config import ClientConfig
from petals.data_structures import ModuleUID, RPCInfo
from petals.client.routing import RemoteSequenceManager
from petals.data_structures import CHAIN_DELIMITER, ModuleUID
from petals.server.handler import TransformerConnectionHandler
from petals.utils.packaging import pack_args_kwargs
async def _forward_unary(
@ -65,85 +67,93 @@ async def _backward_stream(
async def run_remote_forward(
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
*inputs: torch.Tensor,
config: ClientConfig,
metadata: Optional[bytes] = None,
**kwargs,
sequence_manager: RemoteSequenceManager,
peer_id: PeerID,
span_uids: Sequence[ModuleUID],
*args: torch.Tensor,
**kwargs: torch.Tensor,
) -> Tuple[torch.Tensor, ...]:
"""
Serializes input tensors and calls "rpc_forward" on a remote server.
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""
# Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
# detach to avoid pickling the computation graph
assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
forward_inputs = tuple(nested_flatten((inputs, kwargs)))
args_schema, kwargs_schema = rpc_info["forward_schema"]
compression = args_schema[0].compression
forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs)
inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
# TODO: create more explicit way to check servers schema and client's structure
assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step"
merged_uid = CHAIN_DELIMITER.join(span_uids)
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)
metadata = sequence_manager.get_request_metadata(peer_id, "rpc_forward", span_uids, *args, **kwargs)
codecs = sequence_manager.get_compression_codecs(peer_id, "rpc_forward", span_uids, *args, **kwargs)
flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)
flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors)
args_structure = metadata.setdefault("args_structure", args_structure)
if codecs is None:
codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)
else:
codecs = list(nested_flatten(codecs))
assert len(codecs) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs"
# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, forward_schema)
loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)
for tensor, compression in zip(flat_tensors, codecs)
)
)
# call RPC on remote server
size = sum(t.element_size() * t.nelement() for t in inputs)
size = sum(t.element_size() * t.nelement() for t in flat_tensors)
forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR
output_tensors = await forward_fn(
merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata)
)
# backward compatibility: ensure requires_grad; remove after https://github.com/learning-at-home/hivemind/pull/591
requires_grad = any(tensor.requires_grad for tensor in flat_tensors)
output_tensors = [tensor.requires_grad_(requires_grad) for tensor in output_tensors]
return output_tensors
async def run_remote_backward(
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
*inputs_and_grad_outputs: torch.Tensor,
config: ClientConfig,
metadata: Optional[bytes] = None,
**kwargs,
sequence_manager: RemoteSequenceManager,
peer_id: PeerID,
span_uids: Sequence[ModuleUID],
grad_outputs: Sequence[torch.Tensor],
*args: torch.Tensor,
**kwargs: torch.Tensor,
) -> Sequence[torch.Tensor]:
"""
Serializes grad outputs and calls "rpc_backward" on a remote server.
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""
args_schema, kwargs_schema = rpc_info["forward_schema"]
outputs_schema = rpc_info["outputs_schema"]
compression = args_schema[0].compression
backward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in inputs_and_grad_outputs)
# TODO: create more explicit way to check servers schema and client's structure
assert (
len(inputs_and_grad_outputs) >= len(args_schema) + len(outputs_schema) + 1
), "Inputs, grad_outputs and prompt tensors are necessary for a backward step"
merged_uid = CHAIN_DELIMITER.join(span_uids)
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)
metadata = sequence_manager.get_request_metadata(peer_id, "rpc_backward", span_uids, grad_outputs, *args, **kwargs)
codecs = sequence_manager.get_compression_codecs(peer_id, "rpc_backward", span_uids, grad_outputs, *args, **kwargs)
flat_tensors, args_structure = pack_args_kwargs(grad_outputs, *args, **kwargs)
flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors)
args_structure = metadata.setdefault("args_structure", args_structure)
if codecs is None:
codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)
else:
codecs = list(nested_flatten(codecs))
assert len(codecs) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs"
# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)
for tensor, compression in zip(flat_tensors, codecs)
)
)
for tensor, serialized in zip(flat_tensors, serialized_tensors):
serialized.requires_grad = tensor.requires_grad # see https://github.com/learning-at-home/hivemind/pull/591
size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
size = sum(t.element_size() * t.nelement() for t in flat_tensors)
backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
return deserialized_grad_inputs
return await backward_fn(
merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata)
)

@ -49,13 +49,13 @@ class RemoteSequential(nn.Module):
self._active_session = ContextVar("active_session", default=None)
def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, *args, **kwargs) -> torch.Tensor:
assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
if self.active_session is None:
assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
return _RemoteSequentialAutogradFunction.apply(self.sequence_manager, inputs, prompts, *args)
else:
return self.active_session.step(inputs, prompts, **kwargs)
return self.active_session.step(inputs, prompts, *args, **kwargs)
@property
def active_session(self) -> Optional[InferenceSession]:

@ -471,21 +471,33 @@ class RemoteSequenceManager:
return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
def get_request_metadata(
self, protocol: str, args_structure: Any = None, *args, **kwargs
self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
) -> Optional[Dict[str, Any]]:
"""
:param peer_id: remote server's PeerID
:param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
:param args_structure: the structure of flattened tensors from pack_args_kwargs in petals.utils.packaging
:param args: request-specific inputs, typically block uids and input tensors
:param kwargs: additional request context, such as remote peer ID
:returns: msgpack-serialized metadata dict that will be passed alongside a given request
:param args: request-specific input tensors
:param kwargs: additional request keyword arguments
:returns: metadata dict that will be passed alongside a given request
"""
return dict(
points=self.policy.get_points(protocol, *args, **kwargs),
active_adapter=self.config.active_adapter,
args_structure=args_structure,
)
def get_compression_codecs(
self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]:
"""
Return a sequence of compression codecs for client-side compression (applied to tensors sent to the remote server)
:param peer_id: remote server's PeerID
:param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
:param args: request-specific input tensors
:param kwargs: additional request keyword arguments
:returns: compressions for each input tensor; contains as many elements as there are tensors in (args, kwargs)
"""
return None
def shutdown(self):
self._thread.shutdown()
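For illustration, a minimal client-side subclass (not part of this diff) that uses the hook above to request FLOAT16 compression for every tensor it sends, in the spirit of the DummyCustomSequenceManager test at the bottom of this page; it assumes each positional and keyword argument is a single tensor:

    from hivemind.proto import runtime_pb2
    from petals.client.routing import RemoteSequenceManager

    class Float16SequenceManager(RemoteSequenceManager):  # hypothetical subclass
        def get_compression_codecs(self, peer_id, protocol, uids, *args, **kwargs):
            # one codec per tensor, in the same order as (args, kwargs)
            return [runtime_pb2.CompressionType.FLOAT16] * (len(args) + len(kwargs))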

@ -4,19 +4,16 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
import asyncio
import itertools
from collections import deque
from typing import List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple
import torch
from hivemind import MSGPackSerializer
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.utils.logging import get_logger
from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.data_structures import RemoteSpanInfo
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import pack_args_kwargs
logger = get_logger(__name__)
@ -24,19 +21,26 @@ MAX_TOKENS_IN_BATCH = 1024
async def sequential_forward(
sequence_manager: RemoteSequenceManager,
inputs: torch.Tensor,
prompts: torch.Tensor,
sequence_manager: RemoteSequenceManager,
start_index: int = 0,
end_index: Optional[int] = None,
*block_kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
"""
Constructs a routing path from <start_index> to <end_index>.
Performs chained forward for each subsequence of blocks on the path.
If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
"""
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
:param inputs: initial hidden states of shape [batch_size, sequence length, hidden_size]
:param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
:param sequence_manager: a running SequenceManager used to select remote servers and handle failures
:param start_index: run remote blocks starting from this index
:param end_index: run remote blocks up to (but not including) this index
:param block_kwargs: optional per-block keyword arguments. Must be a sequence with one dictionary for each block
"""
inputs_device = inputs.device
inputs_dtype = inputs.dtype
@ -45,6 +49,12 @@ async def sequential_forward(
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
if len(block_kwargs) == 1:
block_kwargs = block_kwargs * (end_index - start_index)
assert (
not block_kwargs or len(block_kwargs) == end_index - start_index
), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
assert is_dummy(prompts) or len(prompts) == len(
sequence_manager.block_uids
) # should be n_layers - 1 but add extra prompts for convenience
@ -67,20 +77,13 @@ async def sequential_forward(
span = sequences.popleft()
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end])
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
metadata = sequence_manager.get_request_metadata(
"rpc_forward", args_structure, span_uids, *flat_tensors
)
(outputs,) = await run_remote_forward(
span_uids,
stub,
sequence_manager.rpc_info,
*flat_tensors,
config=sequence_manager.config,
metadata=MSGPackSerializer.dumps(metadata),
sequence_manager,
span.peer_id,
sequence_manager.block_uids[span.start : span.end],
inputs,
prompts[span.start : span.end],
*block_kwargs[span.start : span.end],
)
assert isinstance(outputs, torch.Tensor)
@ -111,11 +114,12 @@ async def sequential_forward(
async def sequential_backward(
sequence_manager: RemoteSequenceManager,
forward_sequences: List[RemoteSpanInfo],
grad_outputs: Sequence[torch.Tensor],
intermediate_inputs: List[torch.Tensor],
prompts: torch.Tensor,
forward_sequences: List[RemoteSpanInfo],
sequence_manager: RemoteSequenceManager,
*block_kwargs: Dict[str, Any],
) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
"""
Performs chained backward for each forward subsequence.
@ -141,7 +145,7 @@ async def sequential_backward(
try:
if attempt_no >= 1:
_, backup_inputs, backup_sequences = await sequential_forward(
inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
sequence_manager, inputs, prompts, start_index=span.start, end_index=span.end
)
assert len(backup_inputs) == len(backup_sequences)
assert backup_sequences[0].start == span.start
@ -152,23 +156,14 @@ async def sequential_backward(
inputs = intermediate_inputs.pop()
span = forward_sequences.pop()
grad_outputs_cpu = [grad.cpu() for grad in grad_outputs]
flat_tensors, args_structure = pack_args_kwargs(
inputs, *grad_outputs_cpu, prompts[span.start : span.end]
)
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
metadata = sequence_manager.get_request_metadata(
"rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id
)
grad_outputs, *span_grad_prompts = await run_remote_backward(
span_uids,
stub,
sequence_manager.rpc_info,
*flat_tensors,
config=sequence_manager.config,
metadata=MSGPackSerializer.dumps(metadata),
sequence_manager,
span.peer_id,
sequence_manager.block_uids[span.start : span.end],
grad_outputs,
inputs,
prompts[span.start : span.end],
*block_kwargs[span.start : span.end],
)
grad_outputs = [grad_outputs]
grad_prompts_reversed.extend(span_grad_prompts)
@ -200,7 +195,7 @@ async def _gather_forward(input_batches, prompt_batches, sequence_manager):
"""Wrapper for asyncio.gather to perform parallel sequential forwards"""
return await asyncio.gather(
*[
sequential_forward(input_batch, prompt_batch, sequence_manager)
sequential_forward(sequence_manager, input_batch, prompt_batch)
for input_batch, prompt_batch in zip(input_batches, prompt_batches)
]
)
@ -212,7 +207,7 @@ async def _gather_backward(
"""Wrapper for asyncio.gather to perform parallel sequential backwards"""
return await asyncio.gather(
*[
sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
sequential_backward(sequence_manager, spans, (grad_output,), input_batch, prompt_batch)
for grad_output, input_batch, prompt_batch, spans in zip(
grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
)
@ -227,15 +222,17 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
def forward(ctx, sequence_manager: RemoteSequenceManager, inputs: torch.Tensor, prompts: torch.Tensor):
# TODO add kwargs here; figure out a way to split kwargs across servers
batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
input_batches = tuple(batch.requires_grad_(inputs.requires_grad) for batch in input_batches)
if prompts is None or is_dummy(prompts):
prompt_batches = [DUMMY] * len(input_batches)
else:
prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
prompt_batches = tuple(batch.requires_grad_(prompts.requires_grad) for batch in prompt_batches)
sequence_manager.rpc_info # lazy init
outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
assert len(outputs) == len(input_batches)
@ -274,4 +271,5 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
grad_inputs = torch.cat(grad_input_batches, dim=0)
dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
return (grad_inputs, grad_prompts, None)
# TODO return grads w.r.t. kwargs here
return (None, grad_inputs, grad_prompts)

@ -5,7 +5,7 @@ from itertools import chain
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import torch
from hivemind import BatchTensorDescriptor, TensorDescriptor
from hivemind import BatchTensorDescriptor, TensorDescriptor, nested_flatten, nested_map
from hivemind.moe.expert_uid import ExpertUID
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
@ -96,22 +96,29 @@ class TransformerBackend(ModuleBackend):
cache_tensors.extend((keys, values))
return cache_tensors
def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
*inputs, active_adapter = inputs
with self._peft_module.using_adapter(active_adapter):
return super().forward(*inputs)
def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
*inputs, active_adapter = inputs
with self._peft_module.using_adapter(active_adapter):
return super().backward(*inputs)
def forward(self, active_adapter: Optional[str], *args: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, ...]:
with self._peft_module.using_adapter(active_adapter), torch.no_grad():
return self.module(*args, **kwargs)
def backward(
self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs
) -> Tuple[Union[torch.Tensor, Any], ...]:
with self._peft_module.using_adapter(active_adapter), torch.enable_grad():
(outputs,) = self.module(*args, **kwargs)
assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape
torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False)
return nested_map(self._get_grad_if_required, (*args, kwargs))
@staticmethod
def _get_grad_if_required(input: Any) -> Optional[torch.Tensor]:
"""Get grad w.r.t. input if input is a tensor that requires grad; otherwise return None"""
if isinstance(input, torch.Tensor) and input.requires_grad:
return input.grad if input.grad is not None else torch.zeros_like(input)
return None
@torch.inference_mode()
def inference_step(
self,
hidden_states: torch.Tensor,
hypo_ids: torch.LongTensor,
inference_info: InferenceMetadata,
self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, inference_info: InferenceMetadata, **kwargs
) -> Tuple[torch.Tensor, ...]:
assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
seq_len = hidden_states.shape[1]
@ -129,8 +136,9 @@ class TransformerBackend(ModuleBackend):
layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
for offset in range(0, seq_len, max_chunk_length):
hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]
kwargs_chunk = self._select_kwargs_chunk(kwargs, seq_len, offset, max_chunk_length)
output_hidden_states_chunk, new_kvs = self.module.forward(
hidden_states_chunk, layer_past=layer_past, use_cache=True
hidden_states_chunk, layer_past=layer_past, use_cache=True, **kwargs_chunk
)
if seq_len > max_chunk_length:
output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk
@ -178,6 +186,17 @@ class TransformerBackend(ModuleBackend):
new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim)
cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
@staticmethod
def _select_kwargs_chunk(kwargs: Dict[str, Any], seq_len: int, offset: int, max_chunk_length: int):
if offset == 0 and max_chunk_length >= seq_len:
return kwargs
kwargs_chunk = {}
for key, value in kwargs.items():
if isinstance(value, torch.Tensor) and value.ndim >= 2 and value.shape[-2] == seq_len:
value = value[:, offset : offset + max_chunk_length]
kwargs_chunk[key] = value
return kwargs_chunk
def get_pools(self) -> Sequence[PrioritizedTaskPool]:
return self.forward_pool, self.backward_pool, self.inference_pool
@ -200,8 +219,9 @@ def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend])
"""Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""
assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
first_pool = next(iter(backends.values())).inference_pool
merged_inference_func = _MergedInferenceStep(backends)
merged_pool = PrioritizedTaskPool(
_MergedInferenceStep(backends),
merged_inference_func,
max_batch_size=first_pool.max_batch_size,
device=first_pool.device,
name=f"merged_inference",
@ -222,12 +242,15 @@ class _MergedInferenceStep:
hypo_ids: torch.LongTensor,
inference_infos: Sequence[InferenceMetadata],
*optional_prompts: Optional[torch.Tensor],
block_kwargs: Sequence[Dict[str, torch.Tensor]],
) -> Tuple[torch.Tensor, ...]:
assert len(inference_infos) == len(
optional_prompts
), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
assert (
len(inference_infos) == len(optional_prompts) == len(block_kwargs)
), f"mismatch: got {len(inference_infos)} infos, {len(optional_prompts)} prompts, {len(block_kwargs)} kwargs"
for inference_info, optional_prompt, kwargs in zip(inference_infos, optional_prompts, block_kwargs):
if optional_prompt is not None:
hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
(hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
(hidden_states,) = self.backends[inference_info.uid].inference_step(
hidden_states, hypo_ids, inference_info, **kwargs
)
return (hidden_states,)
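A quick illustration of the chunking rule in _select_kwargs_chunk above: tensors whose second-to-last dimension matches the current sequence length are sliced along the token dimension together with the hidden states, while everything else passes through unchanged (the kwarg names below are made up):

    import torch
    from petals.server.backend import TransformerBackend

    kwargs = dict(attention_mask=torch.ones(2, 10, 10), temperature=0.8)  # hypothetical per-block kwargs
    chunk = TransformerBackend._select_kwargs_chunk(kwargs, seq_len=10, offset=4, max_chunk_length=4)
    assert chunk["attention_mask"].shape == (2, 4, 10)  # sliced to the current chunk of tokens
    assert chunk["temperature"] == 0.8  # non-tensor values are left untouched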

@ -18,7 +18,7 @@ from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import TaskPrioritizerBase
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import unpack_args_kwargs
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs
# We prioritize short inference requests and make them use a *merged* inference pool,
# so they are processed without interruptions and extra overheads
@ -31,36 +31,40 @@ logger = get_logger(__name__)
async def run_rpc_forward(
*flat_tensors: torch.Tensor,
args_structure: Any,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
args_structure: Any = None,
) -> torch.Tensor:
"""
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
:param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
:note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
:param args_structure: a schema that defines which of flat_tensors corresponds to which arg / kwarg
:note: see pack_args_kwargs function for the definition of args_structure
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:param active_adapter: the name of LoRA adapter to use; defaults to no adapter
:param prioritizer: assigns priorities to each sub-request based on the number of points
:param points: client-specified number of points, used to assign priorities
:param args_structure:
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
if args_structure is not None:
# TODO: kwargs currently is unused, it can be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
hidden_states, prompts, *_ = flat_tensors
requires_grad = any(tensor.requires_grad for tensor in flat_tensors)
flat_tensors = tuple(tensor.detach() for tensor in flat_tensors)
(hidden_states, prompts), block_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
dtype = requested_backends[0].dtype
# check and parse input tensors and cast dtypes
hidden_states = hidden_states.to(dtype)
assert hidden_states.ndim == 3
num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
# Run a chain of requested backends
for backend, prompt in zip(requested_backends, prompts):
for backend, prompt, kwargs in zip(requested_backends, prompts, block_kwargs):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
@ -69,16 +73,18 @@ async def run_rpc_forward(
hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
)
(hidden_states,) = await backend.forward_pool.submit_task(
hidden_states,
active_adapter,
hidden_states,
**kwargs,
priority=priority,
size=num_tokens,
)
assert isinstance(hidden_states, torch.Tensor)
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
return hidden_states
return hidden_states.requires_grad_(requires_grad)
async def run_rpc_backward(
@ -87,58 +93,70 @@ async def run_rpc_backward(
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
args_structure: Any = None,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
if args_structure is not None:
# TODO: kwargs currently is unused, it can be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
inputs, grad_outputs, prompts, *_ = flat_tensors
args_structure: Any,
) -> Tuple[Sequence[torch.Tensor], Any]:
"""A custom backward pass used by the server to service rpc_backward and rpc_backward_stream requests"""
assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
((grad_outputs,), hidden_states, prompts), block_kwargs = _check_inputs(
requested_backends, flat_tensors, args_structure
)
input_requires_grad, prompts_requires_grad = hidden_states.requires_grad, prompts.requires_grad
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
hidden_states = hidden_states.detach().to(requested_backends[0].dtype)
grad_outputs = grad_outputs.detach().to(requested_backends[-1].dtype)
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
prompts = [p.squeeze(0).detach() for p in prompts.detach().to(requested_backends[0].dtype).split(1, dim=0)]
# Run a forward chain to collect intermediate inputs
# Note that we do not forward for the last module since we do not need its output
inter_inputs = []
for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
for backend, prompt, kwargs in zip(requested_backends[:-1], prompts[:-1], block_kwargs):
assert hidden_states.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
if not is_dummy(prompt):
inputs[:, : prompt.shape[1]] += prompt
inter_inputs.append(inputs)
hidden_states[:, : prompt.shape[1]] += prompt
inter_inputs.append(hidden_states)
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
priority = prioritizer.prioritize(
inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
hidden_states, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
)
(inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
assert isinstance(inputs, torch.Tensor)
(hidden_states,) = await backend.forward_pool.submit_task(
active_adapter, hidden_states, **kwargs, priority=priority, size=num_tokens
)
assert isinstance(hidden_states, torch.Tensor), "intermediate hidden states is not a tensor"
if not is_dummy(prompts[-1]):
inputs[:, : prompts[-1].shape[1]] += prompts[-1]
inter_inputs.append(inputs)
hidden_states[:, : prompts[-1].shape[1]] += prompts[-1]
inter_inputs.append(hidden_states)
assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
grad_prompts_reversed = []
grad_block_kwargs_reversed = []
# Run a chain of requested backends
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
for hidden_states, prompt, backend, kwargs in reversed(
list(zip(inter_inputs, prompts, requested_backends, block_kwargs))
):
assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
hidden_states = hidden_states.detach().requires_grad_(True)
priority = prioritizer.prioritize(
inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
hidden_states, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
)
(grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task(
active_adapter, grad_outputs, hidden_states, **kwargs, priority=priority, size=num_tokens
)
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
assert isinstance(grad_outputs, torch.Tensor)
if not is_dummy(prompt):
if not is_dummy(prompt) and prompts_requires_grad:
grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
grad_block_kwargs_reversed.append(grad_kwargs)
grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape
grad_args = [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]
return pack_args_kwargs((grad_args, list(reversed(grad_block_kwargs_reversed))))
async def iterate_rpc_inference(
@ -161,12 +179,11 @@ async def iterate_rpc_inference(
async for request, step_metadata in input_iterator:
flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
if args_structure is not None:
# TODO: kwargs currently is unused, it can be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
hidden_states, prompts, hypo_ids, *_ = flat_tensors
(hidden_states, prompts, hypo_ids), block_kwargs = _check_inputs(
requested_backends, flat_tensors, args_structure
)
batch_size, length_increment, _ = hidden_states.shape
num_tokens = batch_size * length_increment
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
@ -209,13 +226,27 @@ async def iterate_rpc_inference(
for uid, handles in zip(requested_uids, cache_handles)
)
(hidden_states,) = await requested_backends[0].inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
hidden_states,
hypo_ids,
inference_infos,
*prompts,
block_kwargs=block_kwargs,
priority=priority,
size=num_tokens,
)
else:
for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
for backend, uid, handles, prompt, kwargs in zip(
requested_backends, requested_uids, cache_handles, prompts, block_kwargs
):
inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
(hidden_states,) = await backend.inference_pool.submit_task(
hidden_states, hypo_ids, inference_infos, prompt, priority=priority
hidden_states,
hypo_ids,
inference_infos,
prompt,
block_kwargs=(kwargs,),
priority=priority,
size=num_tokens,
)
# serialize and send last layer outputs
@ -228,3 +259,29 @@ async def iterate_rpc_inference(
# prepare for next step
prefix_length += length_increment
def _check_inputs(
requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], args_structure: Any
):
if len(flat_tensors) == 3: # backward compatibility for rpc_backward, remove after 2.3
if flat_tensors[0].requires_grad and not flat_tensors[1].requires_grad:
hidden_states, grad_outputs, prompts = flat_tensors
flat_tensors = grad_outputs, hidden_states, prompts
if args_structure is not None:
args, *block_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
else:
args, *block_kwargs = flat_tensors, {} # backward compatibility for grad structure, remove at 2.2
if len(block_kwargs) not in (1, len(requested_backends)):
raise RuntimeError(
f"Server expected either one dict of keyword arguments or {len(requested_backends)} dicts "
f"(one for each block). Found {len(block_kwargs)} instead."
)
if len(block_kwargs) == 1:
block_kwargs = block_kwargs * len(requested_backends)
assert len(block_kwargs) == len(requested_backends)
for i, kwargs in enumerate(block_kwargs):
if not isinstance(kwargs, dict):
raise RuntimeError(f"Expected kwargs for block {i} to be a dictionary, got {type(kwargs)}")
return args, block_kwargs

@ -361,18 +361,19 @@ class TransformerConnectionHandler(ConnectionHandler):
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_forward should have number of points as number or None, got {points}"
hidden_states = await run_rpc_forward(
*flat_inputs,
args_structure=args_structure,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
prioritizer=self._prioritizer,
points=points,
args_structure=args_structure,
)
return runtime_pb2.ExpertResponse(
tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
)
@ -396,11 +397,11 @@ class TransformerConnectionHandler(ConnectionHandler):
hidden_states = await run_rpc_forward(
*flat_inputs,
args_structure=args_structure,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
prioritizer=self._prioritizer,
points=points,
args_structure=args_structure,
)
# Split the serialized_output for streaming and respond to client
@ -447,16 +448,18 @@ class TransformerConnectionHandler(ConnectionHandler):
points, (float, int)
), f"rpc_backward should have number of points as number or None, got {points}"
grads = await run_rpc_backward(
flat_grads, grads_structure = await run_rpc_backward(
*flat_tensors,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
prioritizer=self._prioritizer,
points=points,
args_structure=args_structure,
)
return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
serialized_flat_grads = self._serialize_grads(flat_grads, flat_tensors, metadata)
serialized_output_metadata = MSGPackSerializer.dumps(dict(structure=grads_structure))
return runtime_pb2.ExpertResponse(tensors=serialized_flat_grads, metadata=serialized_output_metadata)
async def rpc_backward_stream(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@ -474,18 +477,20 @@ class TransformerConnectionHandler(ConnectionHandler):
points, (float, int)
), f"rpc_backward_stream should have number of points as number or None, got {points}"
grads = await run_rpc_backward(
flat_grads, grad_structure = await run_rpc_backward(
*flat_tensors,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
prioritizer=self._prioritizer,
points=points,
args_structure=args_structure,
)
# Split the serialized_grad_inputs for streaming and respond
for tensor in self._serialize_grads(grads, requested_backends, metadata):
serialized_output_metadata = MSGPackSerializer.dumps(dict(structure=grad_structure))
for tensor in self._serialize_grads(flat_grads, requested_backends, metadata):
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
yield runtime_pb2.ExpertResponse(tensors=[part])
yield runtime_pb2.ExpertResponse(tensors=[part], metadata=serialized_output_metadata)
serialized_output_metadata = None # attach metadata to the first response only
def _get_active_adapter(self, metadata: dict) -> str:
active_adapter = metadata.get("active_adapter", "")
@ -495,28 +500,31 @@ class TransformerConnectionHandler(ConnectionHandler):
def _serialize_grads(
self,
grads: Sequence[torch.Tensor],
requested_backends: Sequence[TransformerBackend],
metadata: Dict[str, Any],
flat_grads: Sequence[torch.Tensor],
flat_inputs: Sequence[runtime_pb2.Tensor],
input_metadata: Dict[str, Any],
) -> Sequence[runtime_pb2.Tensor]:
"""Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema"""
inputs_with_grad = tuple(input for input in flat_inputs if input.requires_grad)
assert len(flat_grads) == len(inputs_with_grad), (
f"user provides {len(inputs_with_grad)} inputs with grad, "
f"but backward produced {len(flat_grads)} gradients"
)
# Modify grad_inputs_schema to support grad_prompts
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
flat_grads_schema = tuple(
nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema))
) # TODO generalize
if metadata.get("output_compression") is not None:
assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
output_compression = tuple(metadata["output_compression"])
if input_metadata.get("output_compression") is not None:
output_compression = input_metadata["output_compression"]
assert isinstance(output_compression, (list, tuple)), "output_compression must be a tuple/list"
assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
assert len(output_compression) == len(grads), f"output_compression should have {len(grads)} elements"
assert len(output_compression) == len(flat_grads), (
f"output_compression should have {len(flat_grads)} "
f"elements, one for every tensor thar requires grad"
)
else:
output_compression = tuple(tensor.compression for tensor in flat_grads_schema)
output_compression = tuple(runtime_pb2.NONE for _ in flat_grads)
output_compression = tuple(output_compression)
return [
serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
for result, proto, compression in zip(grads, flat_grads_schema, output_compression)
serialize_torch_tensor(result.to(input.dtype), compression, allow_inplace=True)
for result, input, compression in zip(flat_grads, inputs_with_grad, output_compression)
]
def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]:

@ -8,7 +8,7 @@ import random
import sys
import threading
import time
from typing import Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import hivemind
import psutil
@ -17,6 +17,7 @@ import torch.mps
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.moe.server.task_pool import TaskPoolBase
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger
from transformers import PretrainedConfig
@ -773,3 +774,15 @@ class RuntimeWithDeduplicatedPools(Runtime):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pools = tuple(set(self.pools))
def process_batch(
self, pool: TaskPoolBase, batch_index: int, args: Sequence[Any], kwargs: Dict[str, Any]
) -> Tuple[Any, int]:
"""process one batch of tasks from a given pool, return a batch of results and total batch size"""
outputs = pool.process_func(*args, **kwargs)
batch_size = 1
for arg in args:
if isinstance(arg, torch.Tensor) and arg.ndim > 2:
batch_size = arg.shape[0] * arg.shape[1]
break
return outputs, batch_size

@ -4,13 +4,17 @@ import threading
import time
from concurrent.futures._base import PENDING
from dataclasses import dataclass, field
from functools import partial
from queue import PriorityQueue
from typing import Any, List, Optional, Sequence, Tuple, Union
import torch
from hivemind import get_logger
from hivemind import get_logger, nested_map
from hivemind.moe.server.task_pool import TaskPoolBase
from hivemind.utils.mpfuture import ALL_STATES, MPFuture
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs
logger = get_logger(__name__)
@ -18,8 +22,10 @@ logger = get_logger(__name__)
class Task:
priority: float
time_submitted: float
size: int
future: MPFuture = field(compare=False)
args: Sequence[torch.Tensor] = field(compare=False)
flat_tensors: Sequence[torch.Tensor] = field(compare=False)
structure: Any
@property
def uid(self) -> int:
@ -92,15 +98,14 @@ class PrioritizedTaskPool(threading.Thread):
def shutdown(self):
self.submitted_tasks.put(None) # Shuts down self.run()
def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1, **kwargs: Any) -> MPFuture:
"""Add task to this pool's queue, return Future for its output"""
future = MPFuture()
# Remove shmem from MPFuture. This disables the .cancel() feature but
# saves the server from "could not unlink the shared memory file" crashes during rebalancing
future._shared_state_code = torch.tensor([ALL_STATES.index(PENDING)], dtype=torch.uint8)
task = Task(priority, time.monotonic(), future, args)
if self.get_task_size(task) > self.max_batch_size:
task = Task(priority, time.monotonic(), size, future, *pack_args_kwargs(*args, **kwargs))
if task.size > self.max_batch_size:
exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
task.future.set_exception(exc)
else:
@ -110,33 +115,27 @@ class PrioritizedTaskPool(threading.Thread):
self.priority = (task.priority, task.time_submitted)
return task.future
def get_task_size(self, task: Task) -> int:
"""compute task processing complexity; defaults to the total number of tokens"""
if task.args and task.args[0].ndim >= 2:
return task.args[0].shape[0] * task.args[0].shape[1]
return 1
def load_batch_to_runtime(
self, timeout: Optional[float] = None, device: Optional[torch.device] = None
) -> Tuple[Any, List[torch.Tensor]]:
) -> Tuple[int, Any]:
"""receive next batch of arrays"""
device = device if device is not None else self.device
task = self._ordered_tasks.get(block=True, timeout=timeout)
batch_inputs = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.args]
device_flat_tensors = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.flat_tensors]
self._dispatched_tasks[task.uid] = task
self.batch_receiver.recv() # reduce the number of active batches
if not self._ordered_tasks.empty():
first_remaining_task: Task = self._ordered_tasks.queue[0]
self.priority = (first_remaining_task.priority, first_remaining_task.time_submitted)
return task.uid, batch_inputs
return task.uid, unpack_args_kwargs(device_flat_tensors, task.structure)
def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
"""send results for a processed batch, previously loaded through load_batch_to_runtime"""
batch_outputs = [_move_to_device_if_tensor(output, device="cpu", share_memory=True) for output in batch_outputs]
batch_outputs = nested_map(partial(_move_to_device_if_tensor, device="cpu", share_memory=True), batch_outputs)
task = self._dispatched_tasks.pop(uid, None)
if task is None:
logger.error(
f"Internal error: task task with index {uid} is missing from the dictionary; " f"Could not set result"
f"Internal error: task task with index {uid} is missing from the dictionary; Could not set result"
)
else:
task.future.set_result(batch_outputs)

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple
from typing import Any, Sequence, Tuple
import torch
from hivemind import nested_flatten, nested_pack
@ -18,7 +18,7 @@ def _get_tensor_index(item: bytes) -> int:
return int(item[3:])
def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
def pack_args_kwargs(*args, **kwargs) -> Tuple[Sequence[torch.Tensor], Any]:
"""
Check the function's arguments and pack all tensors into different flattened lists.
:returns: a flattened list of tensors and args and kwargs, where tensors were masked
@ -35,7 +35,7 @@ def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
return flat_tensors, nested_pack(masked_flat_values, (args, kwargs))
def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any):
def unpack_args_kwargs(flat_tensors: Sequence[torch.Tensor], args_structure: Any):
"""
Restore arguments after `pack_args_kwargs` function.
:returns: list of args and dict of kwargs
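A small round-trip example for these two helpers, which most of the changes above build on (argument names are illustrative):

    import torch
    from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

    hidden_states = torch.randn(1, 3, 8)
    prompts = torch.zeros(1, 2, 8)
    flat_tensors, structure = pack_args_kwargs(hidden_states, prompts, use_cache=True)
    # flat_tensors holds only the tensors; non-tensor values travel inside the structure
    args, kwargs = unpack_args_kwargs(flat_tensors, structure)
    assert torch.equal(args[0], hidden_states) and torch.equal(args[1], prompts)
    assert kwargs == dict(use_cache=True)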

@ -4,8 +4,8 @@ import time
import pytest
import torch
from hivemind.moe.server.runtime import Runtime
from petals.server.server import RuntimeWithDeduplicatedPools
from petals.server.task_pool import PrioritizedTaskPool
@ -57,7 +57,9 @@ def test_priority_pools():
proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid))
proc.start()
runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
runtime = RuntimeWithDeduplicatedPools(
{str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0
)
runtime.ready = runtime_ready
runtime.start()

@ -73,8 +73,8 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
rpc_info["forward_schema"] = (compressed_input_schema,), dict() # (args, kwargs)
return rpc_info
def get_request_metadata(self, protocol: str, *args, **kwargs):
metadata = super().get_request_metadata(protocol, *args, **kwargs)
def get_request_metadata(self, peer_id, protocol, block_uids, *args, **kwargs):
metadata = super().get_request_metadata(peer_id, protocol, block_uids, *args, **kwargs)
if protocol == "rpc_forward":
metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,)
elif protocol == "rpc_backward":
