black-isort

distributed-deep-ptune
justheuristic 2 years ago
parent 0791f854f8
commit cdde27af83

@ -168,11 +168,11 @@ async def sequential_backward(
while True:
inputs = intermediate_inputs.pop(-1)
span = forward_sequences.pop(-1)
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start: span.end])
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
try:
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
grad_outputs, *span_grad_prompts = await run_expert_backward(
span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start: span.end]
span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
)
grad_outputs = [grad_outputs]
grad_prompts_reversed.extend(span_grad_prompts)
@ -180,7 +180,7 @@ async def sequential_backward(
except Exception as e:
logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
_, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
inputs, prompts[span.start: span.end], sequence_manager, start_index=span.start, end_index=span.end
inputs, prompts[span.start : span.end], sequence_manager, start_index=span.start, end_index=span.end
)
assert len(intermediate_inputs) == len(forward_sequences)
assert backup_forward_sequences[0].start == span.start

Loading…
Cancel
Save