|
|
|
@ -144,8 +144,8 @@ async def run_remote_backward(
|
|
|
|
|
for tensor, compression in zip(flat_tensors, codecs)
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
for tensor, serialized_tensor in zip(flat_tensors, serialized_tensors):
|
|
|
|
|
serialized_tensor.requires_grad = tensor.requires_grad
|
|
|
|
|
for tensor, serialized in zip(flat_tensors, serialized_tensors):
|
|
|
|
|
serialized.requires_grad = tensor.requires_grad # see https://github.com/learning-at-home/hivemind/pull/591
|
|
|
|
|
|
|
|
|
|
size = sum(t.element_size() * t.nelement() for t in flat_tensors)
|
|
|
|
|
backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary
|
|
|
|
|