|
|
|
@ -2,6 +2,7 @@
|
|
|
|
|
A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner
|
|
|
|
|
"""
|
|
|
|
|
import asyncio
|
|
|
|
|
import itertools
|
|
|
|
|
import logging
|
|
|
|
|
from typing import List, Optional, Sequence, Tuple
|
|
|
|
|
|
|
|
|
@ -23,6 +24,7 @@ async def sequential_forward(
|
|
|
|
|
sequence_manager: RemoteSequenceManager,
|
|
|
|
|
start_index: int = 0,
|
|
|
|
|
end_index: Optional[int] = None,
|
|
|
|
|
min_backoff: float = 1.0,
|
|
|
|
|
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
|
|
|
|
|
"""
|
|
|
|
|
Constructs a routing path from <start_index> to <end_index>.
|
|
|
|
@ -44,7 +46,7 @@ async def sequential_forward(
|
|
|
|
|
outputs = inputs
|
|
|
|
|
|
|
|
|
|
while len(sequences) > 0:
|
|
|
|
|
while True:
|
|
|
|
|
for attempt_no in itertools.count():
|
|
|
|
|
span = sequences.pop(0)
|
|
|
|
|
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
|
|
|
|
|
try:
|
|
|
|
@ -64,6 +66,8 @@ async def sequential_forward(
|
|
|
|
|
break
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
|
|
|
|
|
await asyncio.sleep(min_backoff * 2**attempt_no)
|
|
|
|
|
|
|
|
|
|
backup_sequences = sequence_manager.make_sequence(span.start)
|
|
|
|
|
assert backup_sequences[0].start == span.start
|
|
|
|
|
sequences = backup_sequences
|
|
|
|
@ -77,6 +81,7 @@ async def sequential_backward(
|
|
|
|
|
prompts: torch.Tensor,
|
|
|
|
|
forward_sequences: List[RemoteSpanInfo],
|
|
|
|
|
sequence_manager: RemoteSequenceManager,
|
|
|
|
|
min_backoff: float = 1.0,
|
|
|
|
|
) -> Sequence[torch.Tensor]:
|
|
|
|
|
"""
|
|
|
|
|
Performs chained backward for each forward subsequence.
|
|
|
|
@ -86,7 +91,7 @@ async def sequential_backward(
|
|
|
|
|
|
|
|
|
|
grad_prompts_reversed = []
|
|
|
|
|
while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
|
|
|
|
|
while True:
|
|
|
|
|
for attempt_no in itertools.count():
|
|
|
|
|
inputs = intermediate_inputs.pop(-1)
|
|
|
|
|
span = forward_sequences.pop(-1)
|
|
|
|
|
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
|
|
|
|
@ -100,6 +105,8 @@ async def sequential_backward(
|
|
|
|
|
break
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
|
|
|
|
|
await asyncio.sleep(min_backoff * 2**attempt_no)
|
|
|
|
|
|
|
|
|
|
_, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
|
|
|
|
|
inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
|
|
|
|
|
)
|
|
|
|
|