add todo & black

efficient-forward-backward
dbaranchuk 2 years ago
parent a2d020afcd
commit d94008311d

@ -22,7 +22,7 @@ async def run_expert_forward(
"""
Serializes input tensors and calls "expert_forward".
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""
# Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
@ -39,6 +39,7 @@ async def run_expert_forward(
forward_inputs = nested_flatten(forward_inputs)
inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
# TODO: figure out whether we should use run_in_executor here
serialized_tensors = (
serialize_torch_tensor(tensor, proto.compression)
for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
@ -59,7 +60,7 @@ async def run_expert_backward(
"""
Serializes grad outputs and calls "expert_backward".
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""
grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
@ -78,7 +79,7 @@ async def sequential_forward(
inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
"""
Constructs a routing path from <start_index> to <end_index>.
Constructs a routing path from <start_index> to <end_index>.
Performs chained forward for each subsequence of blocks on the path.
If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
"""
@ -141,7 +142,9 @@ async def sequential_backward(
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
grad_outputs = await run_expert_backward(span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs)
grad_outputs = await run_expert_backward(
span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
)
break
except Exception as e:
logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
@ -159,12 +162,12 @@ async def sequential_backward(
async def _gather_forward(input_batches, sequence_manager):
""" Wrapper for asyncio.gather to perform parallel sequential forwards """
"""Wrapper for asyncio.gather to perform parallel sequential forwards"""
return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
""" Wrapper for asyncio.gather to perform parallel sequential backwards """
"""Wrapper for asyncio.gather to perform parallel sequential backwards"""
return await asyncio.gather(
*[
sequential_backward((grad_output,), input_batch, spans, sequence_manager)

Loading…
Cancel
Save