Implement exponential backoff for forward & backward (#85)

add-sst2-example
Alexander Borzunov 2 years ago committed by GitHub
parent ee4e69c254
commit 57e8d2e721
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,5 +4,5 @@ accelerate==0.10.0
huggingface-hub==0.7.0
transformers==4.21.3
protobuf>=3.12.2,<4.0.0
hivemind==1.1.2
git+https://github.com/learning-at-home/hivemind@94c985d2dc7a79a091e46c755e9f2f4469b164c7
humanfriendly

@ -2,6 +2,7 @@
A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner
"""
import asyncio
import itertools
import logging
from typing import List, Optional, Sequence, Tuple
@ -23,6 +24,7 @@ async def sequential_forward(
sequence_manager: RemoteSequenceManager,
start_index: int = 0,
end_index: Optional[int] = None,
min_backoff: float = 1.0,
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
"""
Constructs a routing path from <start_index> to <end_index>.
@ -44,7 +46,7 @@ async def sequential_forward(
outputs = inputs
while len(sequences) > 0:
while True:
for attempt_no in itertools.count():
span = sequences.pop(0)
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
try:
@ -64,6 +66,8 @@ async def sequential_forward(
break
except Exception as e:
logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
await asyncio.sleep(min_backoff * 2**attempt_no)
backup_sequences = sequence_manager.make_sequence(span.start)
assert backup_sequences[0].start == span.start
sequences = backup_sequences
@ -77,6 +81,7 @@ async def sequential_backward(
prompts: torch.Tensor,
forward_sequences: List[RemoteSpanInfo],
sequence_manager: RemoteSequenceManager,
min_backoff: float = 1.0,
) -> Sequence[torch.Tensor]:
"""
Performs chained backward for each forward subsequence.
@ -86,7 +91,7 @@ async def sequential_backward(
grad_prompts_reversed = []
while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
while True:
for attempt_no in itertools.count():
inputs = intermediate_inputs.pop(-1)
span = forward_sequences.pop(-1)
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
@ -100,6 +105,8 @@ async def sequential_backward(
break
except Exception as e:
logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
await asyncio.sleep(min_backoff * 2**attempt_no)
_, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
)

Loading…
Cancel
Save