Refactor _{forward,backward}_stream()

forward-backward-timeouts
Aleksandr Borzunov 2 years ago
parent 4518d65fdd
commit cd829fde92

@ -10,12 +10,54 @@ from hivemind.compression.serialization import deserialize_tensor_stream, deseri
from hivemind.p2p import StubBase
from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
from hivemind.utils.asyncio import iter_as_aiter
from hivemind.utils.streaming import split_for_streaming
from src.data_structures import ModuleUID, RPCInfo
async def _forward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
)
return [deserialize_torch_tensor(t) for t in outputs.tensors]
async def _backward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
)
return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
async def _forward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
for tensor in serialized_tensors
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
)
outputs = await stub.rpc_forward_stream(iter_as_aiter(parts))
return await deserialize_tensor_stream(msg.tensors async for msg in outputs)
async def _backward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
for tensor in serialized_tensors
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
)
grad_inputs = await stub.rpc_backward_stream(iter_as_aiter(parts))
return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs)
async def run_remote_forward(
uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs
) -> Tuple[torch.Tensor, ...]:
@ -64,46 +106,6 @@ async def run_remote_forward(
return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
async def _forward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
outputs = await stub.rpc_forward_stream(
amap_in_executor(
lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
iter_as_aiter(split),
),
)
tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
return await deserialize_tensor_stream(tensors_stream)
async def _forward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
)
return [deserialize_torch_tensor(t) for t in outputs.tensors]
async def _backward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
grad_inputs = await stub.rpc_backward_stream(
amap_in_executor(
lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
iter_as_aiter(split),
),
)
tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
return await deserialize_tensor_stream(tensors_stream)
async def run_remote_backward(
uid: ModuleUID,
stub: StubBase,
@ -145,12 +147,3 @@ async def run_remote_backward(
deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, **kwargs)
return deserialized_grad_inputs
async def _backward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
)
return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]

Loading…
Cancel
Save