@@ -10,14 +10,60 @@ from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
 from hivemind.p2p import StubBase
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import ModuleUID, RPCInfo
 
 
+async def _forward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
+        timeout=timeout,
+    )
+    return [deserialize_torch_tensor(t) for t in outputs.tensors]
+
+
+async def _backward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
+        timeout=timeout,
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+
+
+async def _forward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    parts = (
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+        for tensor in serialized_tensors
+        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+    )
+    outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), timeout)
+    outputs = aiter_with_timeout(outputs, timeout)
+    return await deserialize_tensor_stream(msg.tensors async for msg in outputs)
+
+
+async def _backward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    parts = (
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+        for tensor in serialized_tensors
+        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+    )
+    grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), timeout)
+    grad_inputs = aiter_with_timeout(grad_inputs, timeout)
+    return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs)
+
+
 async def run_remote_forward(
-    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs
+    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, timeout: float, **kwargs
 ) -> Tuple[torch.Tensor, ...]:
     """
     Serializes input tensors and calls "rpc_forward" on a remote server.
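
The four helpers added above apply the new timeout at two layers on the streaming paths: `asyncio.wait_for` bounds how long opening the stream may take, and `aiter_with_timeout` then bounds the wait for each subsequent message, so a stalled server fails fast instead of hanging the client. A minimal self-contained sketch of the same pattern (`open_stream` below is a toy stand-in for `stub.rpc_forward_stream`, not the real hivemind API):

```python
import asyncio


async def open_stream():
    """Toy stand-in for `stub.rpc_forward_stream(...)`: an awaitable that
    resolves to an async iterator of response messages."""
    await asyncio.sleep(0.05)  # pretend connection setup

    async def messages():
        for i in range(3):
            await asyncio.sleep(0.1)  # pretend delay between messages
            yield i

    return messages()


async def consume(timeout: float):
    # layer 1: bound how long it may take to open the stream
    stream = await asyncio.wait_for(open_stream(), timeout)
    # layer 2: bound the wait for *each* message (the aiter_with_timeout idea)
    iterator = stream.__aiter__()
    while True:
        try:
            msg = await asyncio.wait_for(iterator.__anext__(), timeout)
        except StopAsyncIteration:
            break
        print("received", msg)


asyncio.run(consume(timeout=1.0))
```
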
@@ -57,53 +103,13 @@ async def run_remote_forward(
     # call RPC on remote server
     size = sum(t.element_size() * t.nelement() for t in inputs)
     if size > MAX_UNARY_PAYLOAD_SIZE:
-        deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, **kwargs)
+        deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
     else:
-        deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, **kwargs)
+        deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
 
     return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
 
 
-async def _forward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
-
-    outputs = await stub.rpc_forward_stream(
-        amap_in_executor(
-            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
-            iter_as_aiter(split),
-        ),
-    )
-
-    tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
-    return await deserialize_tensor_stream(tensors_stream)
-
-
-async def _forward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
-    )
-    return [deserialize_torch_tensor(t) for t in outputs.tensors]
-
-
-async def _backward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
-
-    grad_inputs = await stub.rpc_backward_stream(
-        amap_in_executor(
-            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
-            iter_as_aiter(split),
-        ),
-    )
-    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
-    return await deserialize_tensor_stream(tensors_stream)
-
-
 async def run_remote_backward(
     uid: ModuleUID,
     stub: StubBase,
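
For context on the dispatch preserved above: the unary path packs every tensor into a single request, while payloads whose total size exceeds `MAX_UNARY_PAYLOAD_SIZE` go through the streaming RPC, each serialized tensor split into parts no larger than `DEFAULT_MAX_MSG_SIZE`. A toy illustration of that size check and chunking (the cap values here are made-up for the example, not hivemind's actual constants):

```python
import torch

# Made-up caps for illustration; the real limits are MAX_UNARY_PAYLOAD_SIZE
# and DEFAULT_MAX_MSG_SIZE from hivemind.p2p.p2p_daemon_bindings.control.
UNARY_CAP = 16 * 2**20  # 16 MiB
PART_CAP = 2 * 2**20    # 2 MiB


def chunk(blob: bytes, max_size: int) -> list:
    """Split one serialized payload into parts of at most max_size bytes."""
    return [blob[i : i + max_size] for i in range(0, len(blob), max_size)]


inputs = [torch.randn(2048, 4096), torch.randn(4096)]
size = sum(t.element_size() * t.nelement() for t in inputs)  # same check as the diff

if size > UNARY_CAP:
    # streaming path: each tensor's payload is split; one request per part
    parts = [p for t in inputs for p in chunk(bytes(t.element_size() * t.nelement()), PART_CAP)]
    print(f"streaming: {size} bytes in {len(parts)} parts")
else:
    print(f"unary: {size} bytes in one request")
```
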
@@ -111,6 +117,7 @@ async def run_remote_backward(
     inputs: torch.Tensor,
     grad_outputs: List[torch.Tensor],
     *extra_tensors: torch.Tensor,
+    timeout: float,
     **kwargs,
 ) -> Sequence[torch.Tensor]:
     """
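
A detail worth noting in the signature change above: because `timeout` is declared after `*extra_tensors` (and after `*inputs` in `run_remote_forward`), it is keyword-only, so a stray positional tensor can never be silently consumed as a timeout. A tiny demonstration of that Python rule:

```python
def f(*tensors, timeout: float, **kwargs):
    return len(tensors), timeout


print(f(1, 2, 3, timeout=30.0))  # -> (3, 30.0)
# f(1, 2, 3, 30.0) raises TypeError: 30.0 is captured by *tensors, leaving
# 'timeout' a missing required keyword-only argument.
```
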
@@ -140,17 +147,8 @@ async def run_remote_backward(
 
     size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
     if size > MAX_UNARY_PAYLOAD_SIZE:
-        deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, **kwargs)
+        deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
     else:
-        deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, **kwargs)
+        deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
 
     return deserialized_grad_inputs
-
-
-async def _backward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
-    )
-    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
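
The unary helpers, unlike the streaming ones, delegate the deadline to the stub's own `timeout=` argument. In spirit (a toy sketch, not hivemind's `StubBase` internals), a unary timeout bounds the entire request/response round trip:

```python
import asyncio


async def rpc_forward_toy(request: bytes, *, timeout: float) -> bytes:
    """Toy stand-in for `stub.rpc_forward(request, timeout=...)`."""

    async def round_trip() -> bytes:
        await asyncio.sleep(0.05)  # pretend network latency + remote compute
        return request[::-1]

    # the whole call must finish within `timeout`, else asyncio.TimeoutError
    return await asyncio.wait_for(round_trip(), timeout)


print(asyncio.run(rpc_forward_toy(b"tensors", timeout=1.0)))
```
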