diff --git a/src/client/remote_forward_backward.py b/src/client/remote_forward_backward.py index 34fa3b5..b8713ff 100644 --- a/src/client/remote_forward_backward.py +++ b/src/client/remote_forward_backward.py @@ -111,7 +111,7 @@ async def run_remote_backward( inputs: torch.Tensor, grad_outputs: List[torch.Tensor], *extra_tensors: torch.Tensor, - metadata: bytes = b"", + **kwargs, ) -> Sequence[torch.Tensor]: """ Serializes grad outputs and calls "rpc_backward" on a remote server.