@ -3,6 +3,7 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
"""
import asyncio
import itertools
import logging
from collections import deque
from typing import List , Optional , Sequence , Tuple
@ -36,6 +37,11 @@ async def sequential_forward(
assert isinstance ( inputs , torch . Tensor ) and inputs . ndim == 3 , f " { type ( inputs ) } : { inputs . ndim } "
inputs_device = inputs . device
inputs_dtype = inputs . dtype
inputs = inputs . cpu ( )
prompts = prompts . cpu ( )
end_index = end_index if end_index is not None else len ( sequence_manager . block_uids )
assert start_index > = 0 and end_index < = len ( sequence_manager . block_uids )
assert is_dummy ( prompts ) or len ( prompts ) == len (
@ -86,9 +92,12 @@ async def sequential_forward(
f " Caught exception when running forward from block { block_idx } "
f " (retry in { delay : .0f } sec): { repr ( e ) } "
)
logger . debug ( " See detailed traceback below: " , exc_info = True )
traceback_level = logging . DEBUG if str ( e ) else logging . WARNING
logger . log ( traceback_level , " See detailed traceback below: " , exc_info = True )
await asyncio . sleep ( delay )
outputs = inputs . to ( device = inputs_device , dtype = inputs_dtype )
intermediate_inputs = [ tensor . to ( device = inputs_device , dtype = inputs_dtype ) for tensor in intermediate_inputs ]
return outputs , intermediate_inputs , done_sequences
@ -98,13 +107,22 @@ async def sequential_backward(
prompts : torch . Tensor ,
forward_sequences : List [ RemoteSpanInfo ] ,
sequence_manager : RemoteSequenceManager ,
) - > Sequence[ torch . Tensor ] :
) - > Tuple[ Sequence[ torch . Tensor ] , torch . Tensor ] :
"""
Performs chained backward for each forward subsequence .
If some subsequence fails , reconstructs the particular sub - path and recovers the backward .
"""
assert len ( intermediate_inputs ) == len ( forward_sequences )
grad_outputs_device = grad_outputs [ 0 ] . device if grad_outputs else None
grad_outputs_dtype = grad_outputs [ 0 ] . dtype if grad_outputs else None
prompts_device = prompts . device
prompts_dtype = prompts . dtype
grad_outputs = [ tensor . cpu ( ) for tensor in grad_outputs ]
intermediate_inputs = [ tensor . cpu ( ) for tensor in intermediate_inputs ]
prompts = prompts . cpu ( )
grad_prompts_reversed = [ ]
while len ( forward_sequences ) > 0 and len ( intermediate_inputs ) > 0 :
inputs = intermediate_inputs . pop ( )
@ -146,12 +164,18 @@ async def sequential_backward(
f " Caught exception when running backward between blocks { span . start } - { span . end } "
f " (retry in { delay : .0f } sec): { repr ( e ) } "
)
logger . debug ( " See detailed traceback below: " , exc_info = True )
traceback_level = logging . DEBUG if str ( e ) else logging . WARNING
logger . log ( traceback_level , " See detailed traceback below: " , exc_info = True )
await asyncio . sleep ( delay )
# For now, we do not support mixed dummy and grad prompts
# Concat in num_layer dimension
grad_prompts = torch . cat ( grad_prompts_reversed [ : : - 1 ] , dim = 0 ) if grad_prompts_reversed else None
if grad_outputs_dtype is not None :
grad_outputs = [ tensor . to ( device = grad_outputs_device , dtype = grad_outputs_dtype ) for tensor in grad_outputs ]
if grad_prompts is not None :
grad_prompts = grad_prompts . to ( device = prompts_device , dtype = prompts_dtype )
return grad_outputs , grad_prompts