|
|
|
@ -58,6 +58,7 @@ async def run_remote_forward(
|
|
|
|
|
size = sum(t.element_size() * t.nelement() for t in inputs)
|
|
|
|
|
if size > MAX_UNARY_PAYLOAD_SIZE:
|
|
|
|
|
deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, **kwargs)
|
|
|
|
|
raise 123
|
|
|
|
|
else:
|
|
|
|
|
deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, **kwargs)
|
|
|
|
|
|
|
|
|
@ -83,9 +84,11 @@ async def _forward_stream(
|
|
|
|
|
async def _forward_unary(
    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
    """Call ``rpc_forward`` as a single unary RPC and deserialize the response.

    :param uid: expert identifier, forwarded verbatim to the server
    :param serialized_tensors: input tensors already serialized as ``runtime_pb2.Tensor``
    :param stub: gRPC stub exposing an awaitable ``rpc_forward`` method
    :param kwargs: extra fields packed into the ``ExpertRequest`` message
    :returns: list of output tensors deserialized from the server response
    """
    # NOTE: removed leftover debug `print(..., flush=True)` statements that
    # wrote progress markers to stdout on every call.
    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
    )
    return [deserialize_torch_tensor(t) for t in outputs.tensors]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|