Add connect_timeout (#423)

pull/428/head
Alexander Borzunov 10 months ago committed by GitHub
parent cdc0f70653
commit 44fefa5e54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -75,7 +75,7 @@ class _ServerInferenceSession:
inputs_queue = asyncio.Queue() inputs_queue = asyncio.Queue()
outputs_stream = await asyncio.wait_for( outputs_stream = await asyncio.wait_for(
stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)), stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
config.request_timeout, config.connect_timeout,
) )
return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata) return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)

@ -13,52 +13,53 @@ from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
from hivemind.utils.streaming import split_for_streaming from hivemind.utils.streaming import split_for_streaming
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.data_structures import ModuleUID, RPCInfo from petals.data_structures import ModuleUID, RPCInfo
async def _forward_unary( async def _forward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward( outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs), runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
timeout=timeout, timeout=config.request_timeout,
) )
return [deserialize_torch_tensor(t) for t in outputs.tensors] return [deserialize_torch_tensor(t) for t in outputs.tensors]
async def _backward_unary( async def _backward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward( grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs), runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
timeout=timeout, timeout=config.request_timeout,
) )
return [deserialize_torch_tensor(t) for t in grad_inputs.tensors] return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
async def _forward_stream( async def _forward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
parts = ( parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs) runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
for tensor in serialized_tensors for tensor in serialized_tensors
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
) )
outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), timeout) outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), config.connect_timeout)
outputs = aiter_with_timeout(outputs, timeout) outputs = aiter_with_timeout(outputs, config.request_timeout)
return await deserialize_tensor_stream(msg.tensors async for msg in outputs) return await deserialize_tensor_stream(msg.tensors async for msg in outputs)
async def _backward_stream( async def _backward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
parts = ( parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs) runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
for tensor in serialized_tensors for tensor in serialized_tensors
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
) )
grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), timeout) grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), config.connect_timeout)
grad_inputs = aiter_with_timeout(grad_inputs, timeout) grad_inputs = aiter_with_timeout(grad_inputs, config.request_timeout)
return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs) return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs)
@ -67,7 +68,7 @@ async def run_remote_forward(
stub: StubBase, stub: StubBase,
rpc_info: RPCInfo, rpc_info: RPCInfo,
*inputs: torch.Tensor, *inputs: torch.Tensor,
timeout: float, config: SequenceManagerConfig,
metadata: Optional[bytes] = None, metadata: Optional[bytes] = None,
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, ...]: ) -> Tuple[torch.Tensor, ...]:
@ -110,7 +111,7 @@ async def run_remote_forward(
size = sum(t.element_size() * t.nelement() for t in inputs) size = sum(t.element_size() * t.nelement() for t in inputs)
forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs) deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"]) return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
@ -121,7 +122,7 @@ async def run_remote_backward(
inputs: torch.Tensor, inputs: torch.Tensor,
grad_outputs: List[torch.Tensor], grad_outputs: List[torch.Tensor],
*extra_tensors: torch.Tensor, *extra_tensors: torch.Tensor,
timeout: float, config: SequenceManagerConfig,
metadata: Optional[bytes] = None, metadata: Optional[bytes] = None,
**kwargs, **kwargs,
) -> Sequence[torch.Tensor]: ) -> Sequence[torch.Tensor]:
@ -153,5 +154,5 @@ async def run_remote_backward(
size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs) size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs) deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
return deserialized_grad_inputs return deserialized_grad_inputs

@ -40,6 +40,7 @@ class SequenceManagerConfig:
allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers
use_server_to_server: bool = True # Use direct server-to-server communication use_server_to_server: bool = True # Use direct server-to-server communication
connect_timeout: float = 5 # timeout for opening a connection
request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests
update_period: float = 60 # refresh DHT information once in this many seconds update_period: float = 60 # refresh DHT information once in this many seconds

@ -76,7 +76,7 @@ async def sequential_forward(
stub, stub,
sequence_manager.rpc_info, sequence_manager.rpc_info,
*inputs_and_prompts, *inputs_and_prompts,
timeout=sequence_manager.config.request_timeout, config=sequence_manager.config,
metadata=MSGPackSerializer.dumps(metadata), metadata=MSGPackSerializer.dumps(metadata),
) )
@ -161,7 +161,7 @@ async def sequential_backward(
inputs, inputs,
grad_outputs, grad_outputs,
prompts[span.start : span.end], prompts[span.start : span.end],
timeout=sequence_manager.config.request_timeout, config=sequence_manager.config,
metadata=MSGPackSerializer.dumps(metadata), metadata=MSGPackSerializer.dumps(metadata),
) )
grad_outputs = [grad_outputs] grad_outputs = [grad_outputs]

Loading…
Cancel
Save