From b945e388e57c69a8555ad3833c39e172638e0c00 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Tue, 5 Sep 2023 16:23:26 +0400 Subject: [PATCH] Remove smaller limit for legacy bfloat16 serialization --- src/petals/client/remote_forward_backward.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py index 44abe26..ea2dd56 100644 --- a/src/petals/client/remote_forward_backward.py +++ b/src/petals/client/remote_forward_backward.py @@ -104,8 +104,7 @@ async def run_remote_forward( # call RPC on remote server size = sum(t.element_size() * t.nelement() for t in inputs) - forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary - # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space + forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _forward_unary deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs) return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])