"""
A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner
"""
import asyncio
import itertools
import logging
from collections import deque
from typing import List, Optional, Sequence, Tuple

import torch
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.utils.logging import get_logger

from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
from petals.client.sequence_manager import RemoteSequenceManager
from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy

logger = get_logger(__file__)

MAX_TOKENS_IN_BATCH = 1024
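# For example, with inputs of sequence length 128, _RemoteSequentialAutogradFunction below splits
# the batch into chunks of max(MAX_TOKENS_IN_BATCH // 128, 1) = 8 sequences each and processes
# them in parallel.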


async def sequential_forward(
    inputs: torch.Tensor,
    prompts: torch.Tensor,
    sequence_manager: RemoteSequenceManager,
    start_index: int = 0,
    end_index: Optional[int] = None,
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
    """
    Constructs a routing path from <start_index> to <end_index>.
    Performs chained forward for each subsequence of blocks on the path.
    If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
    """

    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"

    inputs_device = inputs.device
    inputs_dtype = inputs.dtype
    inputs = inputs.cpu()
    prompts = prompts.cpu()

    end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
    assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
    assert is_dummy(prompts) or len(prompts) == len(
        sequence_manager.block_uids
    )  # should be n_layers - 1 but add extra prompts for convenience

    sequences = deque()
    intermediate_inputs = []
    done_sequences = []
    outputs = inputs

    block_idx = start_index
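    # Fault-tolerant forward pass: request a chain of servers covering [block_idx, end_index),
    # push the activations through one server span at a time, and on failure rebuild the
    # remaining route and retry from the last block that completed successfully.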
    while block_idx < end_index:
        for attempt_no in itertools.count():
            logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}")
            try:
                if attempt_no >= 1:
                    sequence_manager.update_()
                if not sequences or attempt_no >= 1:
                    sequences = deque(sequence_manager.make_sequence(block_idx, end_index))
                    # make_sequence() could return a longer sequence
                    sequences[-1].end = min(sequences[-1].end, end_index)
                    logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers")

                span = sequences.popleft()

                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                inputs_and_prompts = [inputs, prompts[span.start : span.end]]

                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                (outputs,) = await run_remote_forward(
                    span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.timeout
                )

                assert isinstance(outputs, torch.Tensor)
                assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"

                # Save intermediate inputs and subsequences if the forward is already done for them
                intermediate_inputs.append(inputs)
                done_sequences.append(span)

                inputs = outputs
                block_idx = span.end
                break
            except Exception as e:
                delay = sequence_manager.get_retry_delay(attempt_no)
                logger.warning(
                    f"Caught exception when running forward from block {block_idx} "
                    f"(retry in {delay:.0f} sec): {repr(e)}"
                )
                traceback_level = logging.DEBUG if str(e) else logging.WARNING
                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                await asyncio.sleep(delay)

    outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
    intermediate_inputs = [tensor.to(device=inputs_device, dtype=inputs_dtype) for tensor in intermediate_inputs]
    return outputs, intermediate_inputs, done_sequences


async def sequential_backward(
    grad_outputs: Sequence[torch.Tensor],
    intermediate_inputs: List[torch.Tensor],
    prompts: torch.Tensor,
    forward_sequences: List[RemoteSpanInfo],
    sequence_manager: RemoteSequenceManager,
) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
    """
    Performs chained backward for each forward subsequence.
    If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
    """
    assert len(intermediate_inputs) == len(forward_sequences)

    grad_outputs_device = grad_outputs[0].device if grad_outputs else None
    grad_outputs_dtype = grad_outputs[0].dtype if grad_outputs else None
    prompts_device = prompts.device
    prompts_dtype = prompts.dtype

    grad_outputs = [tensor.cpu() for tensor in grad_outputs]
    intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]
    prompts = prompts.cpu()

    grad_prompts_reversed = []
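    # Walk the recorded forward spans in reverse, propagating grad_outputs through each span.
    # If a server fails mid-backward, the missing activations for that span are recomputed via
    # sequential_forward() and the recovered sub-spans are pushed back onto the stacks before retrying.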
    while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
        inputs = intermediate_inputs.pop()
        span = forward_sequences.pop()
        for attempt_no in itertools.count():
            logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}")
            try:
                if attempt_no >= 1:
                    sequence_manager.update_()
                    _, backup_inputs, backup_sequences = await sequential_forward(
                        inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
                    )
                    assert len(backup_inputs) == len(backup_sequences)
                    assert backup_sequences[0].start == span.start
                    assert backup_sequences[-1].end == span.end

                    intermediate_inputs.extend(backup_inputs)
                    forward_sequences.extend(backup_sequences)
                    inputs = intermediate_inputs.pop()
                    span = forward_sequences.pop()

                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                grad_outputs, *span_grad_prompts = await run_remote_backward(
                    span_uids,
                    stub,
                    sequence_manager.rpc_info,
                    inputs,
                    grad_outputs,
                    prompts[span.start : span.end],
                    timeout=sequence_manager.timeout,
                )
                grad_outputs = [grad_outputs]
                grad_prompts_reversed.extend(span_grad_prompts)
                break
            except Exception as e:
                delay = sequence_manager.get_retry_delay(attempt_no)
                logger.warning(
                    f"Caught exception when running backward between blocks {span.start}-{span.end} "
                    f"(retry in {delay:.0f} sec): {repr(e)}"
                )
                traceback_level = logging.DEBUG if str(e) else logging.WARNING
                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                await asyncio.sleep(delay)

    # For now, we do not support mixed dummy and grad prompts
    # Concat in num_layer dimension
    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None

    if grad_outputs_dtype is not None:
        grad_outputs = [tensor.to(device=grad_outputs_device, dtype=grad_outputs_dtype) for tensor in grad_outputs]
    if grad_prompts is not None:
        grad_prompts = grad_prompts.to(device=prompts_device, dtype=prompts_dtype)
    return grad_outputs, grad_prompts


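# Within this module, the coroutines above are driven by _RemoteSequentialAutogradFunction through
# RemoteExpertWorker.run_coroutine() using the asyncio.gather wrappers below, e.g.
# RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager)).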
async def _gather_forward(input_batches, prompt_batches, sequence_manager):
    """Wrapper for asyncio.gather to perform parallel sequential forwards"""
    return await asyncio.gather(
        *[
            sequential_forward(input_batch, prompt_batch, sequence_manager)
            for input_batch, prompt_batch in zip(input_batches, prompt_batches)
        ]
    )


async def _gather_backward(
    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
):
    """Wrapper for asyncio.gather to perform parallel sequential backwards"""
    return await asyncio.gather(
        *[
            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
            for grad_output, input_batch, prompt_batch, spans in zip(
                grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
            )
        ]
    )


class _RemoteSequentialAutogradFunction(torch.autograd.Function):
    """
    PyTorch autograd function that provides forward and backward calls for the entire sequence of remote transformer blocks.
    This function splits input data into batches of up to <MAX_TOKENS_IN_BATCH> tokens and processes them in parallel.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
        batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
        input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
        if is_dummy(prompts):
            prompt_batches = [DUMMY] * len(input_batches)
        else:
            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)

        sequence_manager.rpc_info  # lazy init
        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
        assert len(outputs) == len(input_batches)

        output_batches = [output[0] for output in outputs]
        intermediate_input_batches = [output[1] for output in outputs]
        sequences_for_batches = [output[2] for output in outputs]

        ctx.prompt_batches = prompt_batches
        ctx.sequence_manager = sequence_manager
        ctx.intermediate_input_batches = intermediate_input_batches
        ctx.sequences_for_batches = sequences_for_batches
        return torch.cat(output_batches, dim=0)

    @staticmethod
    def backward(ctx, grad_outputs: torch.Tensor):
        intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intermediate_input_batches
        forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
        ctx.sequence_manager.rpc_info  # lazy init

        batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
        grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)

        outputs = RemoteExpertWorker.run_coroutine(
            _gather_backward(
                grad_output_batches,
                intermediate_input_batches,
                ctx.prompt_batches,
                forward_sequences,
                ctx.sequence_manager,
            )
        )
        grad_input_batches = [output[0][0] for output in outputs]
        grad_prompt_batches = [output[1] for output in outputs]

        grad_inputs = torch.cat(grad_input_batches, dim=0)
        dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
        grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
        return (grad_inputs, grad_prompts, None)
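

# A minimal usage sketch (hypothetical variable names; assumes `sequence_manager` is a
# RemoteSequenceManager built for the full chain of remote blocks and `hidden_states` is a
# [batch, seq_len, hidden_size] tensor on the client side):
#
#     hidden_states = hidden_states.detach().requires_grad_(True)
#     outputs = _RemoteSequentialAutogradFunction.apply(hidden_states, DUMMY, sequence_manager)
#     outputs.sum().backward()  # drives sequential_backward over the spans recorded during forward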