diff --git a/src/client/remote_forward_backward.py b/src/client/remote_forward_backward.py index b8713ff..37862a4 100644 --- a/src/client/remote_forward_backward.py +++ b/src/client/remote_forward_backward.py @@ -10,14 +10,60 @@ from hivemind.compression.serialization import deserialize_tensor_stream, deseri from hivemind.p2p import StubBase from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE from hivemind.proto import runtime_pb2 -from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter +from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter from hivemind.utils.streaming import split_for_streaming from src.data_structures import ModuleUID, RPCInfo +async def _forward_unary( + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs +) -> List[torch.Tensor]: + outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward( + runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs), + timeout=timeout, + ) + return [deserialize_torch_tensor(t) for t in outputs.tensors] + + +async def _backward_unary( + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs +) -> List[torch.Tensor]: + grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward( + runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs), + timeout=timeout, + ) + return [deserialize_torch_tensor(t) for t in grad_inputs.tensors] + + +async def _forward_stream( + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs +) -> List[torch.Tensor]: + parts = ( + runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs) + for tensor in serialized_tensors + for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) + ) + outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), timeout) + outputs = aiter_with_timeout(outputs, timeout) + return await deserialize_tensor_stream(msg.tensors async for msg in outputs) + + +async def _backward_stream( + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs +) -> List[torch.Tensor]: + parts = ( + runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs) + for tensor in serialized_tensors + for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) + ) + grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), timeout) + grad_inputs = aiter_with_timeout(grad_inputs, timeout) + return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs) + + async def run_remote_forward( - uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs + uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, timeout: float, **kwargs ) -> Tuple[torch.Tensor, ...]: """ Serializes input tensors and calls "rpc_forward" on a remote server. @@ -57,53 +103,13 @@ async def run_remote_forward( # call RPC on remote server size = sum(t.element_size() * t.nelement() for t in inputs) if size > MAX_UNARY_PAYLOAD_SIZE: - deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, **kwargs) + deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, timeout, **kwargs) else: - deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, **kwargs) + deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, timeout, **kwargs) return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"]) -async def _forward_stream( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs -) -> List[torch.Tensor]: - split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE)) - - outputs = await stub.rpc_forward_stream( - amap_in_executor( - lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs), - iter_as_aiter(split), - ), - ) - - tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs) - return await deserialize_tensor_stream(tensors_stream) - - -async def _forward_unary( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs -) -> List[torch.Tensor]: - outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward( - runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs) - ) - return [deserialize_torch_tensor(t) for t in outputs.tensors] - - -async def _backward_stream( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs -) -> List[torch.Tensor]: - split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)) - - grad_inputs = await stub.rpc_backward_stream( - amap_in_executor( - lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs), - iter_as_aiter(split), - ), - ) - tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs) - return await deserialize_tensor_stream(tensors_stream) - - async def run_remote_backward( uid: ModuleUID, stub: StubBase, @@ -111,6 +117,7 @@ async def run_remote_backward( inputs: torch.Tensor, grad_outputs: List[torch.Tensor], *extra_tensors: torch.Tensor, + timeout: float, **kwargs, ) -> Sequence[torch.Tensor]: """ @@ -140,17 +147,8 @@ async def run_remote_backward( size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs) if size > MAX_UNARY_PAYLOAD_SIZE: - deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, **kwargs) + deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, timeout, **kwargs) else: - deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, **kwargs) + deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, timeout, **kwargs) return deserialized_grad_inputs - - -async def _backward_unary( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs -) -> List[torch.Tensor]: - grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward( - runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs) - ) - return [deserialize_torch_tensor(t) for t in grad_inputs.tensors] diff --git a/src/client/sequence_manager.py b/src/client/sequence_manager.py index 0c15163..5cd704f 100644 --- a/src/client/sequence_manager.py +++ b/src/client/sequence_manager.py @@ -24,7 +24,15 @@ class RemoteSequenceManager: In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc. """ - def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3): + def __init__( + self, + dht: DHT, + block_uids: Sequence[ModuleUID], + p2p: P2P, + max_retries: int = 3, + timeout: float = 5, + min_backoff: float = 1, + ): assert len(block_uids) > 0, "Sequences must contain at least one block" self.dht, self.p2p = dht, p2p self.block_uids: List[ModuleUID] = list(block_uids) @@ -33,6 +41,7 @@ class RemoteSequenceManager: self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids))) self.last_update_time: DHTExpiration = -float("inf") self.max_retries = max_retries + self.timeout, self.min_backoff = timeout, min_backoff self._rpc_info = None self.lock_changes = threading.Lock() self.update_() diff --git a/src/client/sequential_autograd.py b/src/client/sequential_autograd.py index 408e622..bc70882 100644 --- a/src/client/sequential_autograd.py +++ b/src/client/sequential_autograd.py @@ -24,7 +24,6 @@ async def sequential_forward( sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None, - min_backoff: float = 1.0, ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]: """ Constructs a routing path from to . @@ -53,7 +52,9 @@ async def sequential_forward( stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id) inputs_and_prompts = [inputs, prompts[span.start : span.end]] - (outputs,) = await run_remote_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts) + (outputs,) = await run_remote_forward( + span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.timeout + ) assert isinstance(outputs, torch.Tensor) assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}" @@ -66,7 +67,7 @@ async def sequential_forward( break except Exception as e: logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True) - await asyncio.sleep(min_backoff * 2**attempt_no) + await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no) backup_sequences = sequence_manager.make_sequence(span.start) assert backup_sequences[0].start == span.start @@ -81,7 +82,6 @@ async def sequential_backward( prompts: torch.Tensor, forward_sequences: List[RemoteSpanInfo], sequence_manager: RemoteSequenceManager, - min_backoff: float = 1.0, ) -> Sequence[torch.Tensor]: """ Performs chained backward for each forward subsequence. @@ -98,14 +98,20 @@ async def sequential_backward( try: stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id) grad_outputs, *span_grad_prompts = await run_remote_backward( - span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end] + span_uids, + stub, + sequence_manager.rpc_info, + inputs, + grad_outputs, + prompts[span.start : span.end], + timeout=sequence_manager.timeout, ) grad_outputs = [grad_outputs] grad_prompts_reversed.extend(span_grad_prompts) break except Exception as e: logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True) - await asyncio.sleep(min_backoff * 2**attempt_no) + await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no) _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward( inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end