From 11d6ba683cc3339c4e6d54cebefa7c96b06960d0 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 27 Nov 2022 04:11:54 +0400 Subject: [PATCH] Make inference, forward, and backward fully fault-tolerant (#91) --- requirements.txt | 2 +- src/client/__init__.py | 2 +- src/client/inference_session.py | 212 ++++++++++++++++++++++-------- src/client/remote_sequential.py | 6 +- src/client/sequence_manager.py | 5 + src/client/sequential_autograd.py | 76 +++++++---- tests/test_block_exact_match.py | 2 +- 7 files changed, 221 insertions(+), 84 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7ecd566..1afedb4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,5 @@ accelerate==0.10.0 huggingface-hub==0.7.0 transformers==4.21.3 protobuf>=3.12.2,<4.0.0 -git+https://github.com/learning-at-home/hivemind@94c985d2dc7a79a091e46c755e9f2f4469b164c7 +git+https://github.com/learning-at-home/hivemind@8f258b4b3688f671208bf323359cb967b25d640a humanfriendly diff --git a/src/client/__init__.py b/src/client/__init__.py index e9217b1..3cda475 100644 --- a/src/client/__init__.py +++ b/src/client/__init__.py @@ -1,4 +1,4 @@ -from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession +from src.client.inference_session import InferenceSession from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock from src.client.sequence_manager import RemoteSequenceManager diff --git a/src/client/inference_session.py b/src/client/inference_session.py index 812e953..da45fb7 100644 --- a/src/client/inference_session.py +++ b/src/client/inference_session.py @@ -1,7 +1,8 @@ from __future__ import annotations import asyncio -import contextlib +import itertools +import time from typing import AsyncIterator, List, Optional import torch @@ -13,26 +14,25 @@ from hivemind import ( get_logger, nested_flatten, serialize_torch_tensor, -    use_hivemind_log_handler, ) from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.p2p import StubBase from hivemind.proto import runtime_pb2 +from hivemind.utils.asyncio import aiter_with_timeout from src.client.sequence_manager import RemoteSequenceManager from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo from src.server.handler import TransformerConnectionHandler from src.utils.misc import DUMMY, is_dummy -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) -class RemoteTransformerBlockInferenceSession: +class _ServerInferenceSession: """ -    An interface to a single multi-step *inference* session for a specific remote module on a specific server +    An interface to a single multi-step *inference* session for a set of blocks on a specific server. -    :note: this inference session is *not* fault-tolerant out of the box +    :note: This class is *not* fault-tolerant out of the box.
""" def __init__( @@ -42,32 +42,35 @@ class RemoteTransformerBlockInferenceSession: inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator, *, + timeout: float, max_length: int, points: int = 0, ): self.uid, self.rpc_info = uid, rpc_info self.num_blocks = uid.count(CHAIN_DELIMITER) + 1 - # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread; - # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter + self.timeout = timeout self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points)) self.stepped = False self.closed = False @classmethod - async def _create( - cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata - ) -> RemoteTransformerBlockInferenceSession: + async def create( + cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata + ) -> _ServerInferenceSession: """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker""" inputs_queue = asyncio.Queue() - outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout) - return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata) + outputs_stream = await asyncio.wait_for( + stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)), + timeout, + ) + return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata) @staticmethod - async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator: + async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator: while True: - next_input_message = await asyncio.wait_for(queue.get(), timeout) + next_input_message = await asyncio.wait_for(queue.get(), input_timeout) yield next_input_message if not next_input_message.uid and not next_input_message.tensors: break # this message means "done sending" @@ -77,7 +80,7 @@ class RemoteTransformerBlockInferenceSession: new_hidden_states: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None, - ): + ) -> torch.Tensor: """ Inference step: send a chunk of input tesors and receive a chunk of outputs :prompts: optional DEEP prompts, added to a prefix of each layer's outputs, @@ -122,7 +125,7 @@ class RemoteTransformerBlockInferenceSession: """Inference step on serialized data. 
This code is meant to be run inside RemoteExpertWorker""" await self._inputs_queue.put(inputs_serialized) self.stepped = True - return await anext(self._outputs_stream) + return await asyncio.wait_for(anext(self._outputs_stream), self.timeout) def close(self): """Finish a given inference session, close the underlying connection""" @@ -154,60 +157,163 @@ class RemoteTransformerBlockInferenceSession: self.close() -class RemoteSequentialInferenceSession: +class InferenceSession: """ An interface to a multi-step *inference* session for a sequence of remote transformer blocks """ - def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata): - self.sequence_manager = sequence_manager - self.p2p = p2p - self.closed = False - self.chosen_spans: List[RemoteSpanInfo] = [] - self.stack = contextlib.ExitStack() - self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = [] - self.metadata = metadata - self.timeout = timeout + def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int, **metadata): + self._sequence_manager = sequence_manager + self._p2p = p2p + self._closed = False + self._chosen_spans = [] + self._server_sessions = [] + self._server_inputs = [] # Used in case of server failures to regenerate attention caches on new servers + self._position = 0 + self._max_length = max_length + self._metadata = metadata - def __enter__(self): - assert not self.closed and not self.chosen_spans - self.stack.__enter__() - # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail - self.chosen_spans.extend(self.sequence_manager.make_sequence()) - - for chosen_span in self.chosen_spans: - stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id) - span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end]) - inference_session = RemoteExpertWorker.run_coroutine( - RemoteTransformerBlockInferenceSession._create( - stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata + def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]: + server_sessions = [] + try: + for span in chosen_spans: + stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id) + span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end]) + session = RemoteExpertWorker.run_coroutine( + _ServerInferenceSession.create( + stub, + span_uids, + rpc_info=self._sequence_manager.rpc_info, + timeout=self._sequence_manager.timeout, + max_length=self._max_length, + **self._metadata, + ) ) - ) - self.inference_sessions.append(inference_session) - self.stack.enter_context(inference_session) + server_sessions.append(session) + session.__enter__() + return server_sessions + except: + self._exit_server_sessions(server_sessions) + raise + + def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession]) -> None: + for session in reversed(server_sessions): + try: + session.__exit__(None, None, None) + except Exception: + logger.debug("Caught exception while closing connection to server:", exc_info=True) + def __enter__(self) -> "InferenceSession": + assert not self._closed and not self._chosen_spans return self - def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs): - assert not self.closed + def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> 
torch.Tensor: + assert not self._closed if torch.is_grad_enabled(): logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.") + + n_blocks = len(self._sequence_manager) if prompts is None or is_dummy(prompts): prompts = DUMMY else: - assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager) - for session in self.inference_sessions: - outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs) - assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}" - inputs = outputs + assert prompts.ndim == 4 and prompts.shape[0] == n_blocks + + n_input_tokens = inputs.shape[1] + if self._position + n_input_tokens > self._max_length: + raise ValueError( + f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}" + ) + + server_idx = 0 + block_idx = 0 + recovery_until = -1 # Recovery mode is disabled until a failure happens + while block_idx < n_blocks: + for attempt_no in itertools.count(): + logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}") + try: + if attempt_no >= 1: + self._sequence_manager.update_() + if not self._chosen_spans or not self._server_sessions or attempt_no >= 1: + # If there is a failed server session, this code closes it + self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1]) + + n_prev_spans = len(self._chosen_spans) + update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks + if attempt_no >= 1 and update_end > recovery_until: + logger.info( + f"Due to a server failure, remote attention caches " + f"from block {block_idx} to {update_end} will be regenerated" + ) + recovery_until = max(recovery_until, update_end) + + updated_spans = self._sequence_manager.make_sequence(block_idx, update_end) + # make_sequence() could return a longer sequence + updated_spans[-1].end = min(updated_spans[-1].end, update_end) + updated_sessions = self._enter_server_sessions(updated_spans) + logger.debug( + f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers" + ) + + # If there is a failed span, this code replaces it, otherwise it just adds new ones + self._chosen_spans[server_idx : server_idx + 1] = updated_spans + self._server_sessions[server_idx : server_idx + 1] = updated_sessions + recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None + self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * ( + len(updated_spans) - 1 + ) + assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), ( + f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, " + f"{len(self._server_inputs)} inputs" + ) + + session = self._server_sessions[server_idx] + span = self._chosen_spans[server_idx] + + if self._server_inputs[server_idx] is None: + self._server_inputs[server_idx] = inputs + elif self._server_inputs[server_idx].shape[1] == self._position: + self._server_inputs[server_idx] = torch.cat( + [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1 + ) + assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, ( + f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} " + f"position={self._position} n_input_tokens={n_input_tokens}" + ) + + if not session.stepped: + inputs = self._server_inputs[server_idx] # 
Pass full inputs including prefix + else: + inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further + + outputs = session.step(inputs, prompts[span.start : span.end], **kwargs) + assert ( + inputs.shape == outputs.shape + ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})" + + inputs = outputs + server_idx += 1 + block_idx = span.end + break + except Exception as e: + delay = self._sequence_manager.get_retry_delay(attempt_no) + logger.warning( + f"Caught exception when running inference from block {block_idx} " + f"(retry in {delay:.0f} sec): {repr(e)}" + ) + logger.debug("See detailed traceback below:", exc_info=True) + time.sleep(delay) + + self._position += n_input_tokens return inputs def close(self, *exc_details): """Finish a given inference session, close the underlying connection""" - if not self.closed: - self.stack.__exit__(*exc_details or (None, None, None)) - self.inference_sessions.clear() - self.closed = True + if not self._closed: + self._server_inputs.clear() + self._exit_server_sessions(self._server_sessions) + self._server_sessions.clear() + self._chosen_spans.clear() + self._closed = True def __exit__(self, *exc_details): self.close(*exc_details) diff --git a/src/client/remote_sequential.py b/src/client/remote_sequential.py index d9e63b2..fb62249 100644 --- a/src/client/remote_sequential.py +++ b/src/client/remote_sequential.py @@ -8,7 +8,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from torch import nn import src -from src.client.inference_session import RemoteSequentialInferenceSession +from src.client.inference_session import InferenceSession from src.client.sequence_manager import RemoteSequenceManager from src.client.sequential_autograd import _RemoteSequentialAutogradFunction from src.data_structures import UID_DELIMITER @@ -80,9 +80,9 @@ class RemoteSequential(nn.Module): def __len__(self): return len(self.sequence_manager) - def inference_session(self, **kwargs) -> RemoteSequentialInferenceSession: + def inference_session(self, **kwargs) -> InferenceSession: self.sequence_manager.update_() - return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p, **kwargs) + return InferenceSession(self.sequence_manager, self.p2p, **kwargs) def extra_repr(self) -> str: return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}" diff --git a/src/client/sequence_manager.py b/src/client/sequence_manager.py index 5cd704f..de66d84 100644 --- a/src/client/sequence_manager.py +++ b/src/client/sequence_manager.py @@ -160,3 +160,8 @@ class RemoteSequenceManager: else: logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True) return self._rpc_info + + def get_retry_delay(self, attempt_no: int) -> float: + if attempt_no == 0: + return 0 + return self.min_backoff * 2 ** (attempt_no - 1) diff --git a/src/client/sequential_autograd.py b/src/client/sequential_autograd.py index bc70882..364a6b5 100644 --- a/src/client/sequential_autograd.py +++ b/src/client/sequential_autograd.py @@ -3,11 +3,12 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s """ import asyncio import itertools -import logging +from collections import deque from typing import List, Optional, Sequence, Tuple import torch from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker +from hivemind.utils.logging import get_logger from src.client.remote_forward_backward import run_remote_backward, run_remote_forward from 
src.client.sequence_manager import RemoteSequenceManager @@ -15,6 +16,8 @@ from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo from src.server.handler import TransformerConnectionHandler from src.utils.misc import DUMMY, is_dummy +logger = get_logger(__file__) + MAX_TOKENS_IN_BATCH = 1024 @@ -39,19 +42,30 @@ async def sequential_forward( sequence_manager.block_uids ) # should be n_layers - 1 but add extra prompts for convenience - sequences = sequence_manager.make_sequence(start_index, end_index) + sequences = deque() intermediate_inputs = [] done_sequences = [] outputs = inputs - while len(sequences) > 0: + block_idx = start_index + while block_idx < end_index: for attempt_no in itertools.count(): - span = sequences.pop(0) - span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) + logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}") try: + if attempt_no >= 1: + sequence_manager.update_() + if not sequences or attempt_no >= 1: + sequences = deque(sequence_manager.make_sequence(block_idx, end_index)) + # make_sequence() could return a longer sequence + sequences[-1].end = min(sequences[-1].end, end_index) + logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers") + + span = sequences.popleft() + stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id) inputs_and_prompts = [inputs, prompts[span.start : span.end]] + span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) (outputs,) = await run_remote_forward( span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.timeout ) @@ -64,14 +78,16 @@ async def sequential_forward( done_sequences.append(span) inputs = outputs + block_idx = span.end break except Exception as e: - logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True) - await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no) - - backup_sequences = sequence_manager.make_sequence(span.start) - assert backup_sequences[0].start == span.start - sequences = backup_sequences + delay = sequence_manager.get_retry_delay(attempt_no) + logger.warning( + f"Caught exception when running forward from block {block_idx} " + f"(retry in {delay:.0f} sec): {repr(e)}" + ) + logger.debug("See detailed traceback below:", exc_info=True) + await asyncio.sleep(delay) return outputs, intermediate_inputs, done_sequences @@ -91,11 +107,26 @@ async def sequential_backward( grad_prompts_reversed = [] while len(forward_sequences) > 0 and len(intermediate_inputs) > 0: + inputs = intermediate_inputs.pop() + span = forward_sequences.pop() for attempt_no in itertools.count(): - inputs = intermediate_inputs.pop(-1) - span = forward_sequences.pop(-1) - span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) + logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}") try: + if attempt_no >= 1: + sequence_manager.update_() + _, backup_inputs, backup_sequences = await sequential_forward( + inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end + ) + assert len(backup_inputs) == len(backup_sequences) + assert backup_sequences[0].start == span.start + assert backup_sequences[-1].end == span.end + + intermediate_inputs.extend(backup_inputs) + forward_sequences.extend(backup_sequences) + inputs = intermediate_inputs.pop() + span = forward_sequences.pop() + + span_uids = 
CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id) grad_outputs, *span_grad_prompts = await run_remote_backward( span_uids, @@ -110,18 +141,13 @@ async def sequential_backward( grad_prompts_reversed.extend(span_grad_prompts) break except Exception as e: - logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True) - await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no) - - _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward( - inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end + delay = sequence_manager.get_retry_delay(attempt_no) + logger.warning( + f"Caught exception when running backward between blocks {span.start}-{span.end} " + f"(retry in {delay:.0f} sec): {repr(e)}" ) - assert len(intermediate_inputs) == len(forward_sequences) - assert backup_forward_sequences[0].start == span.start - assert backup_forward_sequences[-1].end == span.end - - forward_sequences.extend(backup_forward_sequences) - intermediate_inputs.extend(backup_intermediate_inputs) + logger.debug("See detailed traceback below:", exc_info=True) + await asyncio.sleep(delay) # For now, we do not support mixed dummy and grad prompts # Concat in num_layer dimension diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py index fad84ae..abe374f 100644 --- a/tests/test_block_exact_match.py +++ b/tests/test_block_exact_match.py @@ -33,7 +33,7 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3): outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) # test that max length is respected - with pytest.raises(P2PHandlerError) as exc_info: + with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info: sess.step(inputs[:, -1:, :]) assert "Maximum length exceeded" in repr(exc_info.value)
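
Below is a minimal usage sketch of the fault-tolerant InferenceSession that this patch introduces; the sketch itself is not part of the diff. It follows the updated test above: "blocks" stands for an already-constructed RemoteSequential and "inputs" for a [batch, seq_len, hidden_size] tensor of hidden states, both placeholder names. max_length pre-allocates the servers' attention caches; each step() may transparently retry on a freshly built chain of servers, with delays of 0, min_backoff, 2*min_backoff, ... taken from RemoteSequenceManager.get_retry_delay(); and exceeding max_length now fails on the client with ValueError("Maximum length exceeded") instead of a server-side P2PHandlerError.

    import torch

    # "blocks" (RemoteSequential) and "inputs" are assumed to be set up by the caller.
    seq_len = inputs.shape[1]
    with blocks.inference_session(max_length=seq_len) as sess:
        outputs = []
        for i in range(seq_len):
            # On a server failure, the session rebuilds the failed part of the chain
            # and replays the inputs it has accumulated so far, so the replacement
            # servers can regenerate their attention caches before this step returns.
            outputs.append(sess.step(inputs[:, i : i + 1, :]))
        outputs = torch.cat(outputs, dim=1)
    # One more step past max_length would raise ValueError("Maximum length exceeded").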