Make inference, forward, and backward fully fault-tolerant (#91)

fix-joining-announce
Alexander Borzunov 2 years ago committed by GitHub
parent 695df826c2
commit 11d6ba683c

@@ -4,5 +4,5 @@ accelerate==0.10.0
huggingface-hub==0.7.0
transformers==4.21.3
protobuf>=3.12.2,<4.0.0
git+https://github.com/learning-at-home/hivemind@94c985d2dc7a79a091e46c755e9f2f4469b164c7
git+https://github.com/learning-at-home/hivemind@8f258b4b3688f671208bf323359cb967b25d640a
humanfriendly

@@ -1,4 +1,4 @@
from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
from src.client.inference_session import InferenceSession
from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
from src.client.sequence_manager import RemoteSequenceManager

@@ -1,7 +1,8 @@
from __future__ import annotations
import asyncio
import contextlib
import itertools
import time
from typing import AsyncIterator, List, Optional
import torch
@@ -13,26 +14,25 @@ from hivemind import (
get_logger,
nested_flatten,
serialize_torch_tensor,
use_hivemind_log_handler,
)
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import aiter_with_timeout
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from src.server.handler import TransformerConnectionHandler
from src.utils.misc import DUMMY, is_dummy
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
class RemoteTransformerBlockInferenceSession:
class _ServerInferenceSession:
"""
An interface to a single multi-step *inference* session for a specific remote module on a specific server
An interface to a single multi-step *inference* session for a set of blocks on a specific server.
:note: this inference session is *not* fault-tolerant out of the box
:note: This class is *not* fault-tolerant out of the box.
"""
def __init__(
@@ -42,32 +42,35 @@ class RemoteTransformerBlockInferenceSession:
inputs_queue: asyncio.Queue,
outputs_aiter: AsyncIterator,
*,
timeout: float,
max_length: int,
points: int = 0,
):
self.uid, self.rpc_info = uid, rpc_info
self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
# warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
# using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self.timeout = timeout
self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
self.stepped = False
self.closed = False
@classmethod
async def _create(
cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
) -> RemoteTransformerBlockInferenceSession:
async def create(
cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata
) -> _ServerInferenceSession:
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
inputs_queue = asyncio.Queue()
outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)
outputs_stream = await asyncio.wait_for(
stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
timeout,
)
return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)
@staticmethod
async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
while True:
next_input_message = await asyncio.wait_for(queue.get(), timeout)
next_input_message = await asyncio.wait_for(queue.get(), input_timeout)
yield next_input_message
if not next_input_message.uid and not next_input_message.tensors:
break # this message means "done sending"
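The hunk above drains an asyncio.Queue into the streaming rpc_inference call; in the new version the queue is read without a per-message timeout (input_timeout defaults to None), while create() bounds connection setup and _step() bounds each response with asyncio.wait_for. A minimal standalone sketch of this queue-to-stream pattern (illustrative names, not Petals code):

# A minimal sketch of the queue-to-stream pattern: requests are pushed into an asyncio.Queue
# and drained by an async generator that a streaming RPC would consume; an empty message
# acts as the "done sending" sentinel.
import asyncio
from typing import AsyncIterator, Optional

async def read_from_queue(queue: asyncio.Queue, timeout: Optional[float] = None) -> AsyncIterator[str]:
    while True:
        message = await asyncio.wait_for(queue.get(), timeout)
        if not message:  # unlike the code above, this sketch drops the sentinel instead of forwarding it
            break
        yield message

async def demo() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for message in ("step-1", "step-2", ""):
        queue.put_nowait(message)
    print([message async for message in read_from_queue(queue, timeout=1.0)])  # ['step-1', 'step-2']

asyncio.run(demo())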
@@ -77,7 +80,7 @@ class RemoteTransformerBlockInferenceSession:
new_hidden_states: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
):
) -> torch.Tensor:
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
:prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
@@ -122,7 +125,7 @@ class RemoteTransformerBlockInferenceSession:
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
await self._inputs_queue.put(inputs_serialized)
self.stepped = True
return await anext(self._outputs_stream)
return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)
def close(self):
"""Finish a given inference session, close the underlying connection"""
@@ -154,60 +157,163 @@ class RemoteTransformerBlockInferenceSession:
self.close()
class RemoteSequentialInferenceSession:
class InferenceSession:
"""
An interface to a multi-step *inference* session for a sequence of remote transformer blocks
"""
def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata):
self.sequence_manager = sequence_manager
self.p2p = p2p
self.closed = False
self.chosen_spans: List[RemoteSpanInfo] = []
self.stack = contextlib.ExitStack()
self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
self.metadata = metadata
self.timeout = timeout
def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int, **metadata):
self._sequence_manager = sequence_manager
self._p2p = p2p
self._closed = False
self._chosen_spans = []
self._server_sessions = []
self._server_inputs = [] # Used in case of server failures to regenerate attention caches on new servers
self._position = 0
self._max_length = max_length
self._metadata = metadata
def __enter__(self):
assert not self.closed and not self.chosen_spans
self.stack.__enter__()
# TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
self.chosen_spans.extend(self.sequence_manager.make_sequence())
for chosen_span in self.chosen_spans:
stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
inference_session = RemoteExpertWorker.run_coroutine(
RemoteTransformerBlockInferenceSession._create(
stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
server_sessions = []
try:
for span in chosen_spans:
stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
session = RemoteExpertWorker.run_coroutine(
_ServerInferenceSession.create(
stub,
span_uids,
rpc_info=self._sequence_manager.rpc_info,
timeout=self._sequence_manager.timeout,
max_length=self._max_length,
**self._metadata,
)
)
)
self.inference_sessions.append(inference_session)
self.stack.enter_context(inference_session)
server_sessions.append(session)
session.__enter__()
return server_sessions
except:
self._exit_server_sessions(server_sessions)
raise
def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession]) -> None:
for session in reversed(server_sessions):
try:
session.__exit__(None, None, None)
except Exception:
logger.debug("Caught exception while closing connection to server:", exc_info=True)
def __enter__(self) -> "InferenceSession":
assert not self._closed and not self._chosen_spans
return self
def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
assert not self.closed
def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
assert not self._closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
n_blocks = len(self._sequence_manager)
if prompts is None or is_dummy(prompts):
prompts = DUMMY
else:
assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
for session in self.inference_sessions:
outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
inputs = outputs
assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
n_input_tokens = inputs.shape[1]
if self._position + n_input_tokens > self._max_length:
raise ValueError(
f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}"
)
server_idx = 0
block_idx = 0
recovery_until = -1 # Recovery mode is disabled until a failure happens
while block_idx < n_blocks:
for attempt_no in itertools.count():
logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
try:
if attempt_no >= 1:
self._sequence_manager.update_()
if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
# If there is a failed server session, this code closes it
self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
n_prev_spans = len(self._chosen_spans)
update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks
if attempt_no >= 1 and update_end > recovery_until:
logger.info(
f"Due to a server failure, remote attention caches "
f"from block {block_idx} to {update_end} will be regenerated"
)
recovery_until = max(recovery_until, update_end)
updated_spans = self._sequence_manager.make_sequence(block_idx, update_end)
# make_sequence() could return a longer sequence
updated_spans[-1].end = min(updated_spans[-1].end, update_end)
updated_sessions = self._enter_server_sessions(updated_spans)
logger.debug(
f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers"
)
# If there is a failed span, this code replaces it, otherwise it just adds new ones
self._chosen_spans[server_idx : server_idx + 1] = updated_spans
self._server_sessions[server_idx : server_idx + 1] = updated_sessions
recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (
len(updated_spans) - 1
)
assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), (
f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, "
f"{len(self._server_inputs)} inputs"
)
session = self._server_sessions[server_idx]
span = self._chosen_spans[server_idx]
if self._server_inputs[server_idx] is None:
self._server_inputs[server_idx] = inputs
elif self._server_inputs[server_idx].shape[1] == self._position:
self._server_inputs[server_idx] = torch.cat(
[self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
)
assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, (
f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} "
f"position={self._position} n_input_tokens={n_input_tokens}"
)
if not session.stepped:
inputs = self._server_inputs[server_idx] # Pass full inputs including prefix
else:
inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further
outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
assert (
inputs.shape == outputs.shape
), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
inputs = outputs
server_idx += 1
block_idx = span.end
break
except Exception as e:
delay = self._sequence_manager.get_retry_delay(attempt_no)
logger.warning(
f"Caught exception when running inference from block {block_idx} "
f"(retry in {delay:.0f} sec): {repr(e)}"
)
logger.debug("See detailed traceback below:", exc_info=True)
time.sleep(delay)
self._position += n_input_tokens
return inputs
def close(self, *exc_details):
"""Finish a given inference session, close the underlying connection"""
if not self.closed:
self.stack.__exit__(*exc_details or (None, None, None))
self.inference_sessions.clear()
self.closed = True
if not self._closed:
self._server_inputs.clear()
self._exit_server_sessions(self._server_sessions)
self._server_sessions.clear()
self._chosen_spans.clear()
self._closed = True
def __exit__(self, *exc_details):
self.close(*exc_details)
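A simplified, self-contained sketch of the recovery strategy that the new step() follows (the names and helpers below are assumptions, not the Petals API): retry each server span with exponential backoff, re-plan the route from the first block that failed, and replay cached inputs so a replacement server can rebuild its attention cache from the start of its span.

import itertools
import time
from typing import Callable, Dict, List, Tuple

def run_with_recovery(
    run_span: Callable[[int, int, List[float]], List[float]],   # (start, end, inputs) -> outputs
    plan_route: Callable[[int, int], List[Tuple[int, int]]],    # (start, end) -> [(start, end), ...]
    inputs: List[float],
    n_blocks: int,
    min_backoff: float = 1.0,
) -> List[float]:
    cached_inputs: Dict[int, List[float]] = {0: inputs}  # inputs at each span boundary, kept for recovery
    block_idx = 0
    while block_idx < n_blocks:
        for attempt_no in itertools.count():
            try:
                start, end = plan_route(block_idx, n_blocks)[0]  # next span of the (re)planned route
                cached_inputs[end] = run_span(start, end, cached_inputs[start])
                block_idx = end
                break
            except Exception:
                time.sleep(0 if attempt_no == 0 else min_backoff * 2 ** (attempt_no - 1))
    return cached_inputs[n_blocks]

# Example: spans of two blocks each; the second span fails once, then succeeds on retry.
failures = {"remaining": 1}
def flaky_span(start: int, end: int, xs: List[float]) -> List[float]:
    if start == 2 and failures["remaining"] > 0:
        failures["remaining"] -= 1
        raise RuntimeError("server went offline")
    return [x + (end - start) for x in xs]

print(run_with_recovery(flaky_span, lambda s, e: [(s, min(s + 2, e))], [0.0], n_blocks=4))  # [4.0]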

@@ -8,7 +8,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from torch import nn
import src
from src.client.inference_session import RemoteSequentialInferenceSession
from src.client.inference_session import InferenceSession
from src.client.sequence_manager import RemoteSequenceManager
from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
from src.data_structures import UID_DELIMITER
@@ -80,9 +80,9 @@ class RemoteSequential(nn.Module):
def __len__(self):
return len(self.sequence_manager)
def inference_session(self, **kwargs) -> RemoteSequentialInferenceSession:
def inference_session(self, **kwargs) -> InferenceSession:
self.sequence_manager.update_()
return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p, **kwargs)
return InferenceSession(self.sequence_manager, self.p2p, **kwargs)
def extra_repr(self) -> str:
return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

@@ -160,3 +160,8 @@ class RemoteSequenceManager:
else:
logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
return self._rpc_info
def get_retry_delay(self, attempt_no: int) -> float:
if attempt_no == 0:
return 0
return self.min_backoff * 2 ** (attempt_no - 1)
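The new helper gives an immediate first retry and exponential backoff afterwards; for example, with min_backoff = 1.0 the delays for attempts 0..4 are 0, 1, 2, 4, and 8 seconds. A standalone illustration (the min_backoff value is assumed):

def get_retry_delay(attempt_no: int, min_backoff: float = 1.0) -> float:
    return 0 if attempt_no == 0 else min_backoff * 2 ** (attempt_no - 1)

print([get_retry_delay(n) for n in range(5)])  # [0, 1.0, 2.0, 4.0, 8.0]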

@@ -3,11 +3,12 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
"""
import asyncio
import itertools
import logging
from collections import deque
from typing import List, Optional, Sequence, Tuple
import torch
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.utils.logging import get_logger
from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
from src.client.sequence_manager import RemoteSequenceManager
@@ -15,6 +16,8 @@ from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
from src.server.handler import TransformerConnectionHandler
from src.utils.misc import DUMMY, is_dummy
logger = get_logger(__file__)
MAX_TOKENS_IN_BATCH = 1024
@@ -39,19 +42,30 @@ async def sequential_forward(
sequence_manager.block_uids
) # should be n_layers - 1 but add extra prompts for convenience
sequences = sequence_manager.make_sequence(start_index, end_index)
sequences = deque()
intermediate_inputs = []
done_sequences = []
outputs = inputs
while len(sequences) > 0:
block_idx = start_index
while block_idx < end_index:
for attempt_no in itertools.count():
span = sequences.pop(0)
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}")
try:
if attempt_no >= 1:
sequence_manager.update_()
if not sequences or attempt_no >= 1:
sequences = deque(sequence_manager.make_sequence(block_idx, end_index))
# make_sequence() could return a longer sequence
sequences[-1].end = min(sequences[-1].end, end_index)
logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers")
span = sequences.popleft()
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
inputs_and_prompts = [inputs, prompts[span.start : span.end]]
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
(outputs,) = await run_remote_forward(
span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.timeout
)
@@ -64,14 +78,16 @@ async def sequential_forward(
done_sequences.append(span)
inputs = outputs
block_idx = span.end
break
except Exception as e:
logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
backup_sequences = sequence_manager.make_sequence(span.start)
assert backup_sequences[0].start == span.start
sequences = backup_sequences
delay = sequence_manager.get_retry_delay(attempt_no)
logger.warning(
f"Caught exception when running forward from block {block_idx} "
f"(retry in {delay:.0f} sec): {repr(e)}"
)
logger.debug("See detailed traceback below:", exc_info=True)
await asyncio.sleep(delay)
return outputs, intermediate_inputs, done_sequences
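Both step() above and sequential_forward re-plan the remaining route whenever a span fails, clamping the last span because make_sequence() may return a chain that extends past the requested end. A small self-contained sketch of that clamping (the Span dataclass and the make_sequence stand-in are assumptions, not Petals types):

from collections import deque
from dataclasses import dataclass

@dataclass
class Span:
    start: int
    end: int

def replan(make_sequence, block_idx: int, end_index: int) -> deque:
    sequences = deque(make_sequence(block_idx, end_index))
    sequences[-1].end = min(sequences[-1].end, end_index)  # make_sequence() could return a longer sequence
    return sequences

route = replan(lambda s, e: [Span(s, s + 3), Span(s + 3, s + 6)], block_idx=2, end_index=6)
print([(span.start, span.end) for span in route])  # [(2, 5), (5, 6)]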
@@ -91,11 +107,26 @@ async def sequential_backward(
grad_prompts_reversed = []
while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
inputs = intermediate_inputs.pop()
span = forward_sequences.pop()
for attempt_no in itertools.count():
inputs = intermediate_inputs.pop(-1)
span = forward_sequences.pop(-1)
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}")
try:
if attempt_no >= 1:
sequence_manager.update_()
_, backup_inputs, backup_sequences = await sequential_forward(
inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
)
assert len(backup_inputs) == len(backup_sequences)
assert backup_sequences[0].start == span.start
assert backup_sequences[-1].end == span.end
intermediate_inputs.extend(backup_inputs)
forward_sequences.extend(backup_sequences)
inputs = intermediate_inputs.pop()
span = forward_sequences.pop()
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
grad_outputs, *span_grad_prompts = await run_remote_backward(
span_uids,
@@ -110,18 +141,13 @@ async def sequential_backward(
grad_prompts_reversed.extend(span_grad_prompts)
break
except Exception as e:
logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
_, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
delay = sequence_manager.get_retry_delay(attempt_no)
logger.warning(
f"Caught exception when running backward between blocks {span.start}-{span.end} "
f"(retry in {delay:.0f} sec): {repr(e)}"
)
assert len(intermediate_inputs) == len(forward_sequences)
assert backup_forward_sequences[0].start == span.start
assert backup_forward_sequences[-1].end == span.end
forward_sequences.extend(backup_forward_sequences)
intermediate_inputs.extend(backup_intermediate_inputs)
logger.debug("See detailed traceback below:", exc_info=True)
await asyncio.sleep(delay)
# For now, we do not support mixed dummy and grad prompts
# Concat in num_layer dimension

@@ -33,7 +33,7 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
# test that max length is respected
with pytest.raises(P2PHandlerError) as exc_info:
with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
sess.step(inputs[:, -1:, :])
assert "Maximum length exceeded" in repr(exc_info.value)
