diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 5e14d8a..d14c4a2 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -75,7 +75,7 @@ class _ServerInferenceSession: inputs_queue = asyncio.Queue() outputs_stream = await asyncio.wait_for( stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)), - config.request_timeout, + config.connect_timeout, ) return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata) diff --git a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py index df97db1..a116822 100644 --- a/src/petals/client/remote_forward_backward.py +++ b/src/petals/client/remote_forward_backward.py @@ -13,52 +13,53 @@ from hivemind.proto import runtime_pb2 from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter from hivemind.utils.streaming import split_for_streaming +from petals.client.routing.sequence_manager import SequenceManagerConfig from petals.data_structures import ModuleUID, RPCInfo async def _forward_unary( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs ) -> List[torch.Tensor]: outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward( runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs), - timeout=timeout, + timeout=config.request_timeout, ) return [deserialize_torch_tensor(t) for t in outputs.tensors] async def _backward_unary( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs ) -> List[torch.Tensor]: grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward( runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs), - timeout=timeout, + timeout=config.request_timeout, ) return [deserialize_torch_tensor(t) for t in grad_inputs.tensors] async def _forward_stream( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs ) -> List[torch.Tensor]: parts = ( runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs) for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) ) - outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), timeout) - outputs = aiter_with_timeout(outputs, timeout) + outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), config.connect_timeout) + outputs = aiter_with_timeout(outputs, config.request_timeout) return await deserialize_tensor_stream(msg.tensors async for msg in outputs) async def _backward_stream( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs ) -> List[torch.Tensor]: parts = ( runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs) for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) ) - grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), timeout) - grad_inputs = aiter_with_timeout(grad_inputs, timeout) + grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), config.connect_timeout) + grad_inputs = aiter_with_timeout(grad_inputs, config.request_timeout) return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs) @@ -67,7 +68,7 @@ async def run_remote_forward( stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, - timeout: float, + config: SequenceManagerConfig, metadata: Optional[bytes] = None, **kwargs, ) -> Tuple[torch.Tensor, ...]: @@ -110,7 +111,7 @@ async def run_remote_forward( size = sum(t.element_size() * t.nelement() for t in inputs) forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs) + deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs) return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"]) @@ -121,7 +122,7 @@ async def run_remote_backward( inputs: torch.Tensor, grad_outputs: List[torch.Tensor], *extra_tensors: torch.Tensor, - timeout: float, + config: SequenceManagerConfig, metadata: Optional[bytes] = None, **kwargs, ) -> Sequence[torch.Tensor]: @@ -153,5 +154,5 @@ async def run_remote_backward( size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs) backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs) + deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs) return deserialized_grad_inputs diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index c980412..a7a0f1d 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -40,6 +40,7 @@ class SequenceManagerConfig: allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers use_server_to_server: bool = True # Use direct server-to-server communication + connect_timeout: float = 5 # timeout for opening a connection request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests update_period: float = 60 # refresh DHT information once in this many seconds diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 425fdb7..ebc56b4 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -76,7 +76,7 @@ async def sequential_forward( stub, sequence_manager.rpc_info, *inputs_and_prompts, - timeout=sequence_manager.config.request_timeout, + config=sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata), ) @@ -161,7 +161,7 @@ async def sequential_backward( inputs, grad_outputs, prompts[span.start : span.end], - timeout=sequence_manager.config.request_timeout, + config=sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata), ) grad_outputs = [grad_outputs]