From f0c73831812c3c5a93264e176ac8f90fd88a3c46 Mon Sep 17 00:00:00 2001
From: justheuristic
Date: Tue, 19 Jul 2022 04:28:04 +0300
Subject: [PATCH] Implement RemoteSequential slicing and extra repr, add tests (#30)

- finish renaming RemoteSequenceInfo -> RemoteSequenceManager (why: if it were an *Info, users would expect it to behave like a dataclass, whereas in actuality the class performs heavy network interactions on its own)
- implement RemoteSequenceManager.make_sequence (from https://pastebin.com/uXgy2U8B )
- make RemoteSequentialInferenceSession use RemoteSequenceManager.make_sequence
- make tests pass again
- make it possible to create an inference session without RemoteTransformerBlock
- make a standalone test for RemoteSequential
- rollback convert-model

Co-authored-by: Tim Dettmers
---
 .github/workflows/run-tests.yaml | 11 +-
 src/client/__init__.py | 3 +-
 src/client/inference_session.py | 173 +++++++++++++++++++++++++++++++
 src/client/remote_block.py | 104 +------------------
 src/client/remote_sequential.py | 94 +++-------------
 src/client/sequence_manager.py | 81 +++++++++++----
 src/data_structures.py | 5 +-
 src/server/handler.py | 1 -
 tests/conftest.py | 51 +++++++++
 tests/test_block_exact_match.py | 52 ++++------
 tests/test_chained_calls.py | 27 ++---
 tests/test_full_model.py | 57 +++++-----
 tests/test_remote_sequential.py | 43 ++++++++
 tests/test_utils.py | 13 +++
 14 files changed, 427 insertions(+), 288 deletions(-)
 create mode 100644 src/client/inference_session.py
 create mode 100644 tests/conftest.py
 create mode 100644 tests/test_remote_sequential.py
 create mode 100644 tests/test_utils.py

diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 0478e08..71f6c9c 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -66,6 +66,8 @@ jobs: run: | export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_BASE_REF') or os.environ.get('GITHUB_REF_NAME'))") export MODEL_NAME=bloom-testing/test-bloomd-350m-$HF_TAG + export REF_NAME=bigscience/bloom-350m + python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \ --torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 & SERVER1_PID=$! @@ -79,14 +81,7 @@ jobs: sleep 30 # wait for server to download layers - # test individual blocks - export PYTHONPATH=. - BLOCK_UID=$MODEL_NAME.0 REF_NAME=$MODEL_NAME REF_INDEX=0 pytest tests/test_block_exact_match.py - BLOCK_UID=$MODEL_NAME.19 REF_NAME=$MODEL_NAME REF_INDEX=19 pytest tests/test_block_exact_match.py - - REF_NAME=$MODEL_NAME pytest tests/test_chained_calls.py - - REF_NAME=bigscience/bloom-350m pytest tests/test_full_model.py + PYTHONPATH=. pytest tests kill -s SIGINT $SERVER1_PID $SERVER2_PID echo "Done!"
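Taken together, the changes above settle on the following client-side workflow. This is a condensed sketch that mirrors tests/test_remote_sequential.py from this patch; MODEL_NAME and INITIAL_PEERS are placeholders for a converted model name and the multiaddrs of at least one running server (see tests/test_utils.py):

    import torch
    from hivemind import DHT

    from src import RemoteSequential
    from src.client.remote_model import DistributedBloomConfig

    # MODEL_NAME and INITIAL_PEERS are placeholders, normally taken from the environment
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)

    sequential = RemoteSequential(config, dht)       # spans all config.n_layer remote blocks
    first_half = sequential[: config.n_layer // 2]   # slicing returns another RemoteSequential
    second_half = sequential[config.n_layer // 2 :]  # ...backed by a slice of the same sequence manager

    hidden = first_half(torch.randn(1, 5, config.hidden_size))
    outputs = second_half(hidden)                    # equivalent to running sequential(...) on the original inputs
    print(repr(sequential))                          # extra_repr shows the range of served block uids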
diff --git a/src/client/__init__.py b/src/client/__init__.py index 8ca8c8e..0335921 100644 --- a/src/client/__init__.py +++ b/src/client/__init__.py @@ -1,4 +1,5 @@ -from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession +from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession +from src.client.remote_block import RemoteTransformerBlock from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel from src.client.remote_sequential import RemoteSequential from src.client.sequence_manager import RemoteSequenceManager diff --git a/src/client/inference_session.py b/src/client/inference_session.py new file mode 100644 index 0000000..824a583 --- /dev/null +++ b/src/client/inference_session.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import asyncio +import contextlib +from typing import AsyncIterator, List, Optional + +import torch +from hivemind import ( + P2P, + anext, + deserialize_torch_tensor, + get_logger, + nested_flatten, + serialize_torch_tensor, + use_hivemind_log_handler, +) +from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker +from hivemind.p2p import StubBase +from hivemind.proto import runtime_pb2 + +from src.client.sequence_manager import RemoteSequenceManager +from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo +from src.server.handler import TransformerConnectionHandler + +use_hivemind_log_handler("in_root_logger") +logger = get_logger(__file__) + + +class RemoteTransformerBlockInferenceSession: + """ + An interface to a single multi-step *inference* session for a specific remote module on a specific server + + :note: this inference session is *not* fault-tolerant out of the box + """ + + def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator): + self.uid, self.rpc_info = uid, rpc_info + # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread; + # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep + self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue + self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter + self.stepped = False + self.closed = False + + @classmethod + async def _create( + cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None + ) -> RemoteTransformerBlockInferenceSession: + """Create a new session for a given remote module. 
This code is meant to be run inside RemoteExpertWorker""" + inputs_queue = asyncio.Queue() + outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout) + return cls(uid, rpc_info, inputs_queue, outputs_stream) + + @staticmethod + async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator: + while True: + next_input_message = await asyncio.wait_for(queue.get(), timeout) + yield next_input_message + if not next_input_message.uid and not next_input_message.tensors: + break # this message means "done sending" + + def step(self, new_hidden_states: torch.Tensor): + """Inference step: send a chunk of input tensors and receive a chunk of outputs""" + if self.closed: + raise Exception("Session is closed, cannot perform step") + # serialize inputs and put them into the queue + inputs = (new_hidden_states,) + outputs_serialized = RemoteExpertWorker.run_coroutine( + self._step( + runtime_pb2.ExpertRequest( + uid=self.uid, + tensors=[ + serialize_torch_tensor(tensor, proto.compression) + for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"])) + ], + ) + ) + ) + outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors)) + assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}" + return outputs[0] + + async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse: + """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker""" + await self._inputs_queue.put(inputs_serialized) + self.stepped = True + return await anext(self._outputs_stream) + + def close(self): + """Finish a given inference session, close the underlying connection""" + if self._outputs_stream is None: + return # already closed + RemoteExpertWorker.run_coroutine(self._aclose_stream()) + self._outputs_stream = self._inputs_queue = None + self.closed = True + + async def _aclose_stream(self): + """Close the inference session. 
This code is meant to be run inside RemoteExpertWorker""" + if self._outputs_stream is None: + return # already closed + if self.stepped: + await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session + try: + await anext(self._outputs_stream) + except StopAsyncIteration: + pass + + def __del__(self): + self.close() + + def __enter__(self): + assert not self.closed + return self + + def __exit__(self, *exc_details): + self.close() + + +class RemoteSequentialInferenceSession: + """ + An interface to a multi-step *inference* session for a sequence of remote transformer blocks + """ + + def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None): + self.sequence_manager = sequence_manager + self.p2p = p2p + self.closed = False + self.chosen_spans: List[RemoteSpanInfo] = [] + self.stack = contextlib.ExitStack() + self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = [] + self.timeout = timeout + + def __enter__(self): + assert not self.closed and not self.chosen_spans + self.stack.__enter__() + # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail + self.chosen_spans.extend(self.sequence_manager.make_sequence()) + + for chosen_span in self.chosen_spans: + stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id) + span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end]) + inference_session = RemoteExpertWorker.run_coroutine( + RemoteTransformerBlockInferenceSession._create( + stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout + ) + ) + self.inference_sessions.append(inference_session) + self.stack.enter_context(inference_session) + + return self + + def step(self, inputs: torch.Tensor): + assert not self.closed + if torch.is_grad_enabled(): + logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.") + for session in self.inference_sessions: + outputs = session.step(inputs) + assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}" + inputs = outputs + return inputs + + def close(self, *exc_details): + """Finish a given inference session, close the underlying connection""" + if not self.closed: + self.stack.__exit__(*exc_details or (None, None, None)) + self.inference_sessions.clear() + self.closed = True + + def __exit__(self, *exc_details): + self.close(*exc_details) + + def __del__(self): + self.close() diff --git a/src/client/remote_block.py b/src/client/remote_block.py index 40143a1..68cd004 100644 --- a/src/client/remote_block.py +++ b/src/client/remote_block.py @@ -1,20 +1,16 @@ # Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me. 
from __future__ import annotations -import asyncio import random -from typing import Any, AsyncIterator, Dict, Optional import torch -from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker from hivemind.moe.expert_uid import ExpertInfo from hivemind.p2p import P2P, StubBase -from hivemind.proto import runtime_pb2 -from hivemind.utils import anext, get_logger, nested_flatten, use_hivemind_log_handler +from hivemind.utils import get_logger, use_hivemind_log_handler +from src.client.inference_session import RemoteTransformerBlockInferenceSession from src.data_structures import RemoteModuleInfo -from src.dht_utils import ModuleUID from src.server.handler import TransformerConnectionHandler use_hivemind_log_handler("in_root_logger") @@ -39,100 +35,10 @@ class RemoteTransformerBlock(RemoteExpert): def inference_session(self) -> RemoteTransformerBlockInferenceSession: """Initialize a new inference session with the specified remote server""" - _ = self.info # create _info manually since the built-in property will not work inside RemoteExpertWorker - return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self)) + return RemoteExpertWorker.run_coroutine( + RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info) + ) def begin_inference_session(self): logger.warning("beging_inference_session was renamed to just inference_session") return self.inference_session() - - -class RemoteTransformerBlockInferenceSession: - """An interface to a single multi-step *inference* session for a specific remote module with a specific server""" - - def __init__(self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator): - self.uid, self.info = uid, info - # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread; - # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep - self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue - self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter - self.stepped = False - self.closed = False - - @classmethod - async def _create( - cls, - remote_module: RemoteTransformerBlock, - timeout: Optional[float] = None, - ) -> RemoteTransformerBlockInferenceSession: - """Create a new session for a given remote module. 
This code is meant to be run inside RemoteExpertWorker""" - inputs_queue = asyncio.Queue() - outputs_stream = await remote_module.stub.rpc_inference( - cls._read_inputs_from_queue(inputs_queue, timeout), - timeout=timeout, - ) - return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream) - - @staticmethod - async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator: - while True: - next_input_message = await asyncio.wait_for(queue.get(), timeout) - yield next_input_message - if not next_input_message.uid and not next_input_message.tensors: - break # this message means "done sending" - - def step(self, new_hidden_states: torch.Tensor): - """Inference step: send a chunk of input tensors and receive a chunk of outputs""" - if self.closed: - raise Exception("Session is closed, cannot perform step") - # serialize inputs and put them into the queue - inputs = (new_hidden_states,) - outputs_serialized = RemoteExpertWorker.run_coroutine( - self._step( - runtime_pb2.ExpertRequest( - uid=self.uid, - tensors=[ - serialize_torch_tensor(tensor, proto.compression) - for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"])) - ], - ) - ) - ) - outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors)) - assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}" - return outputs[0] - - async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse: - """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker""" - await self._inputs_queue.put(inputs_serialized) - self.stepped = True - return await anext(self._outputs_stream) - - def close(self): - """Finish a given inference session, close the underlying connection""" - if self._outputs_stream is None: - return # already closed - RemoteExpertWorker.run_coroutine(self._aclose_stream()) - self._outputs_stream = self._inputs_queue = None - self.closed = True - - async def _aclose_stream(self): - """Close the inference session. 
This code is meant to be run inside RemoteExpertWorker""" - if self._outputs_stream is None: - return # already closed - if self.stepped: - await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session - try: - await anext(self._outputs_stream) - except StopAsyncIteration: - pass - - def __del__(self): - self.close() - - def __enter__(self): - assert not self.closed - return self - - def __exit__(self, *exc_details): - self.close() diff --git a/src/client/remote_sequential.py b/src/client/remote_sequential.py index 7026e7c..4fb5728 100644 --- a/src/client/remote_sequential.py +++ b/src/client/remote_sequential.py @@ -1,17 +1,15 @@ from __future__ import annotations -import contextlib import logging -import random from typing import Optional, Union import torch from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker -from hivemind.moe.expert_uid import ExpertInfo from torch import nn import src +from src.client.inference_session import RemoteSequentialInferenceSession from src.client.remote_block import RemoteTransformerBlock from src.client.sequence_manager import RemoteSequenceManager from src.data_structures import UID_DELIMITER @@ -30,49 +28,41 @@ class RemoteSequential(nn.Module): self, config: src.DistributedBloomConfig, dht: DHT, - prefix: str, - max_retries: int = 3, + dht_prefix: Optional[str] = None, p2p: Optional[P2P] = None, sequence_manager: Optional[RemoteSequenceManager] = None, ): logger.warning(f"{self.__class__.__name__} is in active development; expect adventures") - if prefix.endswith(UID_DELIMITER): - logger.warning( - f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'." - f"This will cause {self.__class__.__name__} to look for modules under " - f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended." 
- ) - super().__init__() self.config = config self.dht = dht - self.prefix = prefix - self.max_retries = max_retries + self.dht_prefix = dht_prefix or config.dht_prefix self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p - block_uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)] + num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager) + block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)] if sequence_manager is None: logger.debug(f"Creating new sequence manager for block uids: {block_uids}") - self.sequence_manager = RemoteSequenceManager(dht, block_uids) + self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p) self.is_subsequence = False else: + logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules") + self.sequence_manager = sequence_manager assert isinstance(sequence_manager.block_uids, list) - logger.debug(f"Reusing sequence manager with {len(self.sequence_manager)}") self.is_subsequence = self.sequence_manager.block_uids == block_uids def forward(self, inputs: torch.Tensor): assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed - for block_index in range(self.config.n_layer): - for retry_index in range(self.max_retries): + for block in iter(self): + for retry_index in range(self.sequence_manager.max_retries): try: - block = self[block_index] (outputs,) = block(inputs) assert isinstance(outputs, torch.Tensor) assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}" inputs = outputs break except Exception as e: - if retry_index == self.max_retries - 1: + if retry_index == self.sequence_manager.max_retries - 1: raise e else: logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True) @@ -81,21 +71,20 @@ class RemoteSequential(nn.Module): def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]: assert isinstance(ix, (int, slice)) if isinstance(ix, int): - assert 0 <= ix < self.config.n_layer + assert 0 <= ix < len(self) (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p) return module else: return RemoteSequential( self.config, self.dht, - prefix=self.prefix, - max_retries=self.max_retries, + dht_prefix=self.dht_prefix, p2p=self.p2p, sequence_manager=self.sequence_manager[ix], ) def __iter__(self): - for block_index in range(self.config.n_layer): + for block_index in range(len(self)): yield self[block_index] def __len__(self): @@ -105,56 +94,5 @@ class RemoteSequential(nn.Module): self.sequence_manager.update_() return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p) - -class RemoteSequentialInferenceSession: - """An interface to a multi-step *inference* session for a sequence of remote transformer blocks""" - - def __init__(self, remote_sequence_info: RemoteSequenceManager, p2p: P2P): - self.remote_sequence_info = remote_sequence_info - self.p2p = p2p - self.closed = False - self.stack = contextlib.ExitStack() - self.active_sessions = [] - - def __enter__(self): - assert not self.closed - self.stack.__enter__() - # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail - current_block = 0 - while current_block != len(self.remote_sequence_info): - candidate_spans = self.remote_sequence_info.spans_containing_block[current_block] - chosen_span = 
random.choice(candidate_spans) # TODO this is a temporary code - assert chosen_span.start <= current_block < chosen_span.end - - # TODO begin throwaway prototype code - remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p) - _ = remote.info # TODO fix - span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end] - remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id) - self.active_sessions.append(remote.inference_session()) - self.stack.enter_context(self.active_sessions[-1]) - current_block = chosen_span.end - # TODO end throwaway prototype code - - return self - - def step(self, inputs: torch.Tensor): - assert not self.closed - for session in self.active_sessions: - outputs = session.step(inputs) - assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}" - inputs = outputs - return inputs - - def close(self, *exc_details): - """Finish a given inference session, close the underlying connection""" - if not self.closed: - self.stack.__exit__(*exc_details or (None, None, None)) - self.active_sessions.clear() - self.closed = True - - def __exit__(self, *exc_details): - self.close(*exc_details) - - def __del__(self): - self.close() + def extra_repr(self) -> str: + return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}" diff --git a/src/client/sequence_manager.py b/src/client/sequence_manager.py index 7a05bb2..c520096 100644 --- a/src/client/sequence_manager.py +++ b/src/client/sequence_manager.py @@ -1,36 +1,37 @@ from __future__ import annotations +import random import threading from typing import List, Optional, Sequence, Tuple, Union -from hivemind import DHT, DHTExpiration +from hivemind import DHT, P2P, DHTExpiration, MSGPackSerializer +from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker +from hivemind.proto import runtime_pb2 from hivemind.utils.logging import get_logger, use_hivemind_log_handler from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState from src.dht_utils import get_remote_module_infos +from src.server.handler import TransformerConnectionHandler use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) class RemoteSequenceManager: - """Keeps and updates the meta-information about which peers host which blocks""" - - dht: DHT - block_uids: List[ModuleUID] - block_infos: List[Optional[RemoteModuleInfo]] - spans_by_priority: List[RemoteSpanInfo] # sorted from best to worst - spans_containing_block: Tuple[List[RemoteSpanInfo], ...] - last_update_time: DHTExpiration - lock_changes: threading.Lock - - def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]): - self.dht = dht - self.block_uids = list(block_uids) - self.block_infos = [None] * len(self.block_uids) - self.spans_by_priority = [] - self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids))) - self.last_update_time = -float("inf") + """ + Keeps and updates the meta-information about which peers host which blocks. + In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc. 
+ """ + + def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3): + self.dht, self.p2p = dht, p2p + self.block_uids: List[ModuleUID] = list(block_uids) + self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids) + self.spans_by_priority: List[RemoteSpanInfo] = [] # sorted from best to worst + self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids))) + self.last_update_time: DHTExpiration = -float("inf") + self.max_retries = max_retries + self._rpc_info = None self.lock_changes = threading.Lock() self.update_() @@ -38,13 +39,33 @@ class RemoteSequenceManager: assert info is not None, f"Found no remote peers for block {uid}" assert self.spans_by_priority and self.spans_containing_block + def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> Sequence[RemoteSpanInfo]: + """ + Form a sequence of remote servers that collectively serve all consecutive layers + + :param start_index: optional index of the first module in a sequence, default = the first of block_uids + :param end_index: optional index of the last module (non-inclusive), default = after last of block uids + """ + end_index = end_index if end_index is not None else len(self.block_uids) + span_sequence = [] + current_index = start_index + while current_index < end_index: + candidate_spans = self.spans_containing_block[current_index] + chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing + + assert chosen_span.start <= current_index < chosen_span.end + span_sequence.append(chosen_span) + current_index = chosen_span.end + + return span_sequence + def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager: """Get a RemoteSequenceManager for a sub-sequence of blocks""" assert isinstance(ix, (int, slice)) if not isinstance(ix, slice): ix = slice(int(ix), int(ix) + 1, 1) with self.lock_changes: - subseq = RemoteSequenceManager(self.dht, self.block_uids[ix]) + subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p) subseq.block_infos = self.block_infos[ix] subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos) subseq.last_update_time = self.last_update_time @@ -102,3 +123,25 @@ class RemoteSequenceManager: def __len__(self): return len(self.block_uids) + + @property + def rpc_info(self): + """Return the rpc_info queried from one of the servers that hold the first block""" + if self._rpc_info is None: + retries = 0 + for i in range(self.max_retries): + try: + self.update_() + peer_id = random.choice(list(self.block_infos[0].servers.keys())) + stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id) + outputs = RemoteExpertWorker.run_coroutine( + stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0])) + ) + self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info) + except Exception as e: + retries += 1 + if retries >= self.max_retries: + raise e + else: + logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True) + return self._rpc_info diff --git a/src/data_structures.py b/src/data_structures.py index d0719a9..919c8c1 100644 --- a/src/data_structures.py +++ b/src/data_structures.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import Dict +from typing import Any, Dict from hivemind import PeerID @@ -36,3 +36,6 @@ class RemoteSpanInfo: start: int end: int peer_id: PeerID + + +RPCInfo = Dict[str, Any] diff 
--git a/src/server/handler.py b/src/server/handler.py index 8c0707b..d040acf 100644 --- a/src/server/handler.py +++ b/src/server/handler.py @@ -1,4 +1,3 @@ -# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me. import contextlib from typing import AsyncIterator, Dict, Sequence diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..57287c3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,51 @@ +import asyncio +import gc +from contextlib import suppress + +import psutil +import pytest +from hivemind.utils.crypto import RSAPrivateKey +from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from hivemind.utils.mpfuture import MPFuture + +use_hivemind_log_handler("in_root_logger") +logger = get_logger(__name__) + + +@pytest.fixture +def event_loop(): + """ + This overrides the ``event_loop`` fixture from pytest-asyncio + (e.g. to make it compatible with ``asyncio.subprocess``). + + This fixture is identical to the original one but does not call ``loop.close()`` in the end. + Indeed, at this point, the loop is already stopped (i.e. next tests are free to create new loops). + However, finalizers of objects created in the current test may reference the current loop and fail if it is closed. + For example, this happens while using ``asyncio.subprocess`` (the ``asyncio.subprocess.Process`` finalizer + fails if the loop is closed, but works if the loop is only stopped). + """ + + yield asyncio.get_event_loop() + + +@pytest.fixture(autouse=True, scope="session") +def cleanup_children(): + yield + + with RSAPrivateKey._process_wide_key_lock: + RSAPrivateKey._process_wide_key = None + + gc.collect() # Call .__del__() for removed objects + + children = psutil.Process().children(recursive=True) + if children: + logger.info(f"Cleaning up {len(children)} leftover child processes") + for child in children: + with suppress(psutil.NoSuchProcess): + child.terminate() + psutil.wait_procs(children, timeout=1) + for child in children: + with suppress(psutil.NoSuchProcess): + child.kill() + + MPFuture.reset_backend() diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py index 1a0caa6..caac346 100644 --- a/tests/test_block_exact_match.py +++ b/tests/test_block_exact_match.py @@ -1,47 +1,39 @@ -# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me. 
-import os +import random import hivemind +import pytest import torch import transformers +from test_utils import * from src.bloom.from_pretrained import load_pretrained_block from src.client.remote_block import RemoteTransformerBlock +from src.data_structures import UID_DELIMITER from src.dht_utils import get_remote_module -INITIAL_PEERS = os.environ.get("INITIAL_PEERS") -if not INITIAL_PEERS: - raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids") -INITIAL_PEERS = INITIAL_PEERS.split() - - -BLOCK_UID = os.environ.get("BLOCK_UID") -if not BLOCK_UID: - raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested") - -REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3") -REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID.split(".")[-1])) - +@pytest.mark.forked def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3): dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) + config = transformers.AutoConfig.from_pretrained(MODEL_NAME) - remote_block = get_remote_module(dht, BLOCK_UID) - assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT" - assert isinstance(remote_block, RemoteTransformerBlock) - ref_config = transformers.AutoConfig.from_pretrained(REF_NAME) + for block_index in random.sample(range(config.n_layer), 3): + block_uid = f"{MODEL_NAME}{UID_DELIMITER}{block_index}" + remote_block = get_remote_module(dht, block_uid) + assert remote_block is not None, f"Could not find {block_uid} in DHT" + assert isinstance(remote_block, RemoteTransformerBlock) - inputs = torch.randn(1, 8, ref_config.hidden_size) - (outputs_forward,) = remote_block(inputs) + inputs = torch.randn(1, 8, config.hidden_size) + (outputs_forward,) = remote_block(inputs) - outputs_inference = [] - with remote_block.inference_session() as sess: - for i in range(inputs.shape[1]): - outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) - outputs_inference = torch.cat(outputs_inference, dim=1) + outputs_inference = [] + with remote_block.inference_session() as sess: + for i in range(inputs.shape[1]): + outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) + outputs_inference = torch.cat(outputs_inference, dim=1) - ref_block = load_pretrained_block(REF_NAME, REF_INDEX, torch_dtype=torch.float32) - (outputs_local,) = ref_block(inputs) + ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32) + (outputs_local,) = ref_block(inputs) - assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward) - assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference) + assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward) + assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference) diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py index 140d58d..84c4232 100644 --- a/tests/test_chained_calls.py +++ b/tests/test_chained_calls.py @@ -3,30 +3,20 @@ # - if you want more stable tests, see test_block_exact_match # - if you want to figure out chained inference, ask yozh -import os import hivemind +import pytest import torch import transformers from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo +from test_utils import * from src.bloom.from_pretrained import load_pretrained_block from src.client.remote_block import RemoteTransformerBlock from src.dht_utils import get_remote_module -INITIAL_PEERS = os.environ.get("INITIAL_PEERS") -if not 
INITIAL_PEERS: - raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids") -INITIAL_PEERS = INITIAL_PEERS.split() - - -MODEL_NAME = os.environ.get("MODEL_NAME") -if not MODEL_NAME: - raise RuntimeError("Must specify MODEL_NAME as a name of a model to be tested") - -REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3") - +@pytest.mark.forked def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1): dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) config = transformers.AutoConfig.from_pretrained(MODEL_NAME) @@ -38,9 +28,9 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id) ref_blocks = [ - load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32), - load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32), - load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32), + load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32), + load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32), + load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32), ] inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True) outputs_rpc = remote_block.forward(inputs)[0] @@ -59,6 +49,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward) +@pytest.mark.forked def test_chained_inference_exact_match(atol_inference=1e-4): dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) config = transformers.AutoConfig.from_pretrained(MODEL_NAME) @@ -78,8 +69,8 @@ def test_chained_inference_exact_match(atol_inference=1e-4): outputs_inference = torch.cat(outputs_inference, dim=1) ref_blocks = [ - load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32), - load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32), + load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32), + load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32), ] outputs_ref = [] caches = [None, None] diff --git a/tests/test_full_model.py b/tests/test_full_model.py index 5a60365..98140b4 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -1,9 +1,8 @@ -# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me. 
-import os - +import pytest import torch import transformers from hivemind import get_logger, use_hivemind_log_handler +from test_utils import * from src.client.remote_model import DistributedBloomForCausalLM @@ -11,19 +10,7 @@ use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) -INITIAL_PEERS = os.environ.get("INITIAL_PEERS") -if not INITIAL_PEERS: - raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids") -INITIAL_PEERS = INITIAL_PEERS.split() - - -MODEL_NAME = os.environ.get("MODEL_NAME") -if not MODEL_NAME: - raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested") - -REF_NAME = os.environ.get("REF_NAME") - - +@pytest.mark.forked def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3): tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME) model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) @@ -31,23 +18,12 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3): assert len(model.transformer.h) == model.config.n_layer test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] - parallel_outputs = model.forward(test_inputs).logits - assert torch.all(torch.isfinite(parallel_outputs)) - logger.info("Forward outputs are finite") - if REF_NAME: - with torch.no_grad(): - ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME) - dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool) - # note: this creates a dummy mask to make the test compatible with older transformer versions - # prior to https://github.com/huggingface/transformers/pull/17837 - ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits - assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward) - del ref_model, ref_outputs - else: - logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set") + with torch.no_grad(): + parallel_outputs = model.forward(test_inputs).logits + assert torch.all(torch.isfinite(parallel_outputs)) + logger.info("Forward outputs are finite") - with torch.inference_mode(): embs = model.transformer.word_embeddings(test_inputs) embs = model.transformer.word_embeddings_layernorm(embs) recurrent_outputs = [] @@ -60,5 +36,20 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3): dictionary = model.transformer.word_embeddings.weight.t() recurrent_outputs = recurrent_outputs.to(dictionary.dtype) recurrent_outputs = (recurrent_outputs @ dictionary).float() - assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference) - logger.info("Inference is consistent with forward") + assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference) + logger.info("Inference is consistent with forward") + + del model, recurrent_outputs + + if REF_NAME: + ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME) + dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool) + # note: this creates a dummy mask to make the test compatible with older transformer versions + # prior to https://github.com/huggingface/transformers/pull/17837 + ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits + assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward) + logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward") + del ref_model, ref_outputs, dummy_mask + else: + 
logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set") + assert False diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py new file mode 100644 index 0000000..1ae4c38 --- /dev/null +++ b/tests/test_remote_sequential.py @@ -0,0 +1,43 @@ +import pytest +import torch +from hivemind import DHT, get_logger, use_hivemind_log_handler +from test_utils import * + +from src import RemoteSequential +from src.client.remote_model import DistributedBloomConfig + +use_hivemind_log_handler("in_root_logger") +logger = get_logger(__file__) + + +@pytest.mark.forked +def test_remote_sequential(): + config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS) + dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True) + test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True) + grad_proj = torch.randn(1, 5, config.hidden_size) + + sequential = RemoteSequential(config, dht) + + full_outputs = sequential(test_inputs) + (full_outputs * grad_proj).sum().backward() + assert test_inputs.grad is not None + full_grad = test_inputs.grad.clone() + test_inputs.grad.data.zero_() + + first_half = sequential[: config.n_layer // 2] + second_half = sequential[config.n_layer // 2 :] + assert len(first_half) + len(second_half) == len(sequential) + assert abs(len(first_half) - len(second_half)) == config.n_layer % 2 + for m in sequential, first_half, second_half: + assert isinstance(repr(m), str) + + hidden = first_half(test_inputs) + assert isinstance(hidden, torch.Tensor) + assert hidden.shape == test_inputs.shape + assert hidden.requires_grad + second_half_outputs = second_half(hidden) + assert torch.allclose(second_half_outputs, full_outputs) + + (second_half_outputs * grad_proj).sum().backward() + assert torch.allclose(test_inputs.grad, full_grad) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..ee440d6 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,13 @@ +import os + +INITIAL_PEERS = os.environ.get("INITIAL_PEERS") +if not INITIAL_PEERS: + raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids") +INITIAL_PEERS = INITIAL_PEERS.split() + + +MODEL_NAME = os.environ.get("MODEL_NAME") +if not MODEL_NAME: + raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested") + +REF_NAME = os.environ.get("REF_NAME")