@@ -22,7 +22,7 @@ async def run_expert_forward(
     """
     Serializes input tensors and calls "expert_forward".
 
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
-    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
     # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
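For context on the docstring above: hivemind's `RemoteExpertWorker.run_coroutine()` hands a coroutine to a dedicated background event loop and blocks the calling thread until the result arrives. Since the helpers in this module already execute inside an event loop, blocking that loop's thread on work scheduled onto the same loop would never return, so the RPC coroutine is awaited directly instead. A minimal sketch of the failure mode (everything except the `RemoteExpertWorker` reference is illustrative):

```python
import asyncio


def run_coroutine_blocking(coro, loop: asyncio.AbstractEventLoop):
    # Roughly what RemoteExpertWorker.run_coroutine() does: submit the
    # coroutine to a (background) event loop and block on the result.
    return asyncio.run_coroutine_threadsafe(coro, loop).result()


async def rpc_call():
    return "ok"


async def safe_caller():
    # Already inside an event loop: await directly, as this module does.
    return await rpc_call()


async def deadlocking_caller():
    # Blocking the running loop's only thread on work scheduled onto that
    # same loop can never finish: the deadlock the docstring refers to.
    return run_coroutine_blocking(rpc_call(), asyncio.get_running_loop())
```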
@@ -39,6 +39,7 @@ async def run_expert_forward(
     forward_inputs = nested_flatten(forward_inputs)
     inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
 
+    # TODO: figure out whether we should use run_in_executor here
     serialized_tensors = (
         serialize_torch_tensor(tensor, proto.compression)
         for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
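The pairing above relies on hivemind's `nested_flatten`, which turns an arbitrarily nested structure of tuples, lists, and dicts into a flat sequence of leaves; the flattened inputs can then be zipped leaf-by-leaf with the flattened `forward_schema`, so each tensor is compressed according to its matching descriptor. A rough reimplementation of the idea (hivemind's own version covers more container types):

```python
def nested_flatten_sketch(obj):
    # Yield the leaves of nested tuples/lists/dicts in a deterministic order.
    if isinstance(obj, (list, tuple)):
        for item in obj:
            yield from nested_flatten_sketch(item)
    elif isinstance(obj, dict):
        for key in sorted(obj):  # sorted keys keep the order reproducible
            yield from nested_flatten_sketch(obj[key])
    else:
        yield obj


assert list(nested_flatten_sketch((1, (2, {"a": 3}), [4]))) == [1, 2, 3, 4]
```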
@@ -59,7 +60,7 @@ async def run_expert_backward(
     """
     Serializes grad outputs and calls "expert_backward".
 
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
-    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
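Unlike the forward path, the backward helper only calls `.cpu()` on each gradient, with no `.detach()`. That is consistent with how PyTorch hands out gradients: tensors returned by the autograd engine are already detached from the graph unless higher-order gradients were requested. A quick check of that assumption:

```python
import torch

x = torch.ones(2, 2, requires_grad=True)
loss = (x * 2).sum()
grad = torch.autograd.grad(loss, x)[0]
assert not grad.requires_grad  # already detached, so .cpu() alone suffices
```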
@@ -78,7 +79,7 @@ async def sequential_forward(
     inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
-    Constructs a routing path from <start_index> to <end_index>.
+    Constructs a routing path from <start_index> to <end_index>.
     Performs chained forward for each subsequence of blocks on the path.
     If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
     """
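The docstring compresses the fault-tolerance story into three lines; spelled out, the idea is to remember the last block whose output was received and, on any failure, build a fresh route covering only the remaining blocks. A condensed sketch of that control flow, where `make_route` and `forward_span` are hypothetical stand-ins for the sequence manager's routing and the per-span RPC:

```python
async def forward_with_recovery(inputs, start, end, make_route, forward_span):
    outputs, done = inputs, start
    while done < end:
        try:
            for span in make_route(done, end):   # route over remaining blocks only
                outputs = await forward_span(span, outputs)
                done = span.end                  # last block we hold outputs for
        except Exception:
            continue  # sketch retries forever; the real code logs and reroutes
    return outputs
```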
@@ -141,7 +142,9 @@ async def sequential_backward(
                 span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
 
-                grad_outputs = await run_expert_backward(span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs)
+                grad_outputs = await run_expert_backward(
+                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
+                )
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
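`span_uids` packs every block uid in the span into a single delimited string, so one RPC asks a peer to run a whole contiguous sub-chain rather than one block at a time. The shape of the resulting string (the delimiter value and uid format below are illustrative; the code uses hivemind's `CHAIN_DELIMITER` constant and the sequence manager's real uids):

```python
CHAIN_DELIMITER = " "  # assumed value for illustration
block_uids = [f"bloom.{i}" for i in range(10)]
start, end = 3, 6      # one RemoteSpanInfo's half-open block range
span_uids = CHAIN_DELIMITER.join(block_uids[start:end])
# -> "bloom.3 bloom.4 bloom.5"
```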
@@ -159,12 +162,12 @@ async def sequential_backward(
 
 
 async def _gather_forward(input_batches, sequence_manager):
-    """ Wrapper for asyncio.gather to perform parallel sequential forwards """
+    """Wrapper for asyncio.gather to perform parallel sequential forwards"""
     return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
 
 
 async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
-    """ Wrapper for asyncio.gather to perform parallel sequential backwards """
+    """Wrapper for asyncio.gather to perform parallel sequential backwards"""
     return await asyncio.gather(
         *[
             sequential_backward((grad_output,), input_batch, spans, sequence_manager)
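Both wrappers fan one coroutine out per batch with `asyncio.gather`, which runs them concurrently on the current event loop and returns results in the order the coroutines were passed, keeping batch order stable. The usage shape, with a trivial coroutine standing in for `sequential_forward`:

```python
import asyncio


async def process(batch):     # stand-in for sequential_forward(batch, manager)
    await asyncio.sleep(0)    # yield control so batches can interleave
    return batch * 2


async def main():
    results = await asyncio.gather(*[process(b) for b in (1, 2, 3)])
    assert results == [2, 4, 6]  # gather preserves argument order


asyncio.run(main())
```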