Remove smaller limit for legacy bfloat16 serialization

pull/505/head
Alexander Borzunov 8 months ago committed by GitHub
parent 1ebd88ae7b
commit b945e388e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -104,8 +104,7 @@ async def run_remote_forward(
# call RPC on remote server
size = sum(t.element_size() * t.nelement() for t in inputs)
forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
# Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _forward_unary
deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])

Loading…
Cancel
Save