|
|
|
@ -223,11 +223,6 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
|
|
|
|
|
grad_input_batches = RemoteExpertWorker.run_coroutine(
|
|
|
|
|
_gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
|
|
|
|
|
)
|
|
|
|
|
# grad_input_batches = [sequential_backward((grad_output,), input_batch, spans, ctx.sequence_manager)
|
|
|
|
|
# for grad_output, input_batch, spans in zip(
|
|
|
|
|
# grad_output_batches, intermediate_input_batches, forward_sequences
|
|
|
|
|
# )
|
|
|
|
|
# ]
|
|
|
|
|
grad_inputs = [grad_input_batch[0] for grad_input_batch in grad_input_batches]
|
|
|
|
|
grad_inputs = torch.cat(grad_inputs, dim=0)
|
|
|
|
|
return (grad_inputs, None)
|
|
|
|
|