diff --git a/requirements.txt b/requirements.txt index 68d63e1..7ecd566 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,5 @@ accelerate==0.10.0 huggingface-hub==0.7.0 transformers==4.21.3 protobuf>=3.12.2,<4.0.0 -hivemind==1.1.2 +git+https://github.com/learning-at-home/hivemind@94c985d2dc7a79a091e46c755e9f2f4469b164c7 humanfriendly diff --git a/src/client/sequential_autograd.py b/src/client/sequential_autograd.py index 71dc77a..408e622 100644 --- a/src/client/sequential_autograd.py +++ b/src/client/sequential_autograd.py @@ -2,6 +2,7 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner """ import asyncio +import itertools import logging from typing import List, Optional, Sequence, Tuple @@ -23,6 +24,7 @@ async def sequential_forward( sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None, + min_backoff: float = 1.0, ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]: """ Constructs a routing path from to . @@ -44,7 +46,7 @@ async def sequential_forward( outputs = inputs while len(sequences) > 0: - while True: + for attempt_no in itertools.count(): span = sequences.pop(0) span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) try: @@ -64,6 +66,8 @@ async def sequential_forward( break except Exception as e: logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True) + await asyncio.sleep(min_backoff * 2**attempt_no) + backup_sequences = sequence_manager.make_sequence(span.start) assert backup_sequences[0].start == span.start sequences = backup_sequences @@ -77,6 +81,7 @@ async def sequential_backward( prompts: torch.Tensor, forward_sequences: List[RemoteSpanInfo], sequence_manager: RemoteSequenceManager, + min_backoff: float = 1.0, ) -> Sequence[torch.Tensor]: """ Performs chained backward for each forward subsequence. @@ -86,7 +91,7 @@ async def sequential_backward( grad_prompts_reversed = [] while len(forward_sequences) > 0 and len(intermediate_inputs) > 0: - while True: + for attempt_no in itertools.count(): inputs = intermediate_inputs.pop(-1) span = forward_sequences.pop(-1) span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) @@ -100,6 +105,8 @@ async def sequential_backward( break except Exception as e: logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True) + await asyncio.sleep(min_backoff * 2**attempt_no) + _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward( inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end )