From 0b0277ed6f5497a6ad33e4e97aad360150734e53 Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Run 100B+ language models at home, BitTorrent-style.
Fine-tuning and inference up to 10x faster than offloading
From d4c687daca8c68a631cc42db59dca152bf3a4d98 Mon Sep 17 00:00:00 2001
From: Artem Chumachenko
+ 📚 See FAQ
+ 📜 Read paper
-## FAQ
-
-1. **What's the motivation for people to host model layers in the public swarm?**
-
-    People who run inference and fine-tuning themselves get a certain speedup if they host a part of the model locally. Some may be also motivated to "give back" to the community helping them to run the model (similarly to how [BitTorrent](https://en.wikipedia.org/wiki/BitTorrent) users help others by sharing data they have already downloaded).
-
-    Since it may be not enough for everyone, we are also working on introducing explicit __incentives__ ("bloom points") for people donating their GPU time to the public swarm. Once this system is ready, people who earned these points will be able to spend them on inference/fine-tuning with higher priority or increased security guarantees, or (maybe) exchange them for other rewards.
-
-2. **Why is the platform named "Petals"?**
-
-    "Petals" is a metaphor for people serving different parts of the model. Together, they host the entire language model — [BLOOM](https://huggingface.co/bigscience/bloom).
-
-    While our platform focuses on BLOOM now, we aim to support more [foundation models](https://arxiv.org/abs/2108.07258) in future.
-
 ## Installation

 Here's how to install Petals with conda:

From 38b071135bc09d937679dc04d0cbd5b24cd17bf2 Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Run 100B+ language models at home, BitTorrent-style.
- Fine-tuning and inference up to 10x faster than offloading
+ Fine-tuning and inference up to 10x faster than offloading
@@ -98,61 +98,106 @@ Learning more:

 ## Installation

-Here's how to install Petals with conda:
+Here's how to install Petals with [Anaconda](https://www.anaconda.com/products/distribution) on Linux:

 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
 pip install -U petals
 ```

-This script uses Anaconda to install CUDA-enabled PyTorch.
-If you don't have anaconda, you can get it from [here](https://www.anaconda.com/products/distribution).
-If you don't want anaconda, you can install PyTorch [any other way](https://pytorch.org/get-started/locally/).
-If you want to run models with 8-bit weights, please install **PyTorch with CUDA 11** or newer for compatility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).
-
-__System requirements:__ Petals only supports Linux for now. If you don't have a Linux machine, consider running Petals in Docker (see our [image](https://hub.docker.com/r/learningathome/petals)) or, in case of Windows, in WSL2 ([read more](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)). CPU is enough to run a client, but you probably need a GPU to run a server efficiently.
-
-## 🛠️ Development
-
-Petals uses pytest with a few plugins. To install them, run:
-
-```bash
-conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
-git clone https://github.com/bigscience-workshop/petals.git && cd petals
-pip install -e .[dev]
-```
-
-To run minimalistic tests, you need to make a local swarm with a small model and some servers. You may find more information about how local swarms work and how to run them in [this tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm).
-
-```bash
-export MODEL_NAME=bloom-testing/test-bloomd-560m-main
-
-python -m petals.cli.run_server $MODEL_NAME --block_indices 0:12 \
-  --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --new_swarm &> server1.log &
-sleep 5  # wait for the first server to initialize DHT
-
-python -m petals.cli.run_server $MODEL_NAME --block_indices 12:24 \
-  --initial_peers SEE_THE_OUTPUT_OF_THE_1ST_PEER &> server2.log &
-
-tail -f server1.log server2.log  # view logs for both servers
-```
-
-Then launch pytest:
-
-```bash
-export MODEL_NAME=bloom-testing/test-bloomd-560m-main REF_NAME=bigscience/bloom-560m
-export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
-PYTHONPATH=. pytest tests --durations=0 --durations-min=1.0 -v
-```
-
-After you're done, you can terminate the servers and ensure that no zombie processes are left with `pkill -f petals.cli.run_server && pkill -f p2p`.
-
-The automated tests use a more complex server configuration that can be found [here](https://github.com/bigscience-workshop/petals/blob/main/.github/workflows/run-tests.yaml).
-
-### Code style
-
-We use [black](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html) and [isort](https://pycqa.github.io/isort/) for all pull requests.
-Before committing your code, simply run `black . && isort .` and you will be fine.
+If you don't use Anaconda, you can install PyTorch in [any other way](https://pytorch.org/get-started/locally/). If you want to run models with 8-bit weights, please install PyTorch with CUDA 11.x or newer for compatibility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).
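As a quick sanity check after installing (a minimal sketch, not part of the patch; it only assumes the `petals` and `torch` packages from the commands above):

```python
# Minimal post-install check: verifies that petals imports cleanly and that
# PyTorch sees a CUDA device (required for 8-bit weights; a CPU-only client also works).
import torch
import petals  # noqa: F401  -- fails here if the installation is broken

print(f"torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
```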
+
+See the instructions for macOS and Windows, the full requirements, and troubleshooting advice in our [FAQ](https://github.com/bigscience-workshop/petals/wiki/FAQ:-Frequently-asked-questions#running-a-client).
+
+## ⏱️ Benchmarks
+
+| Network bandwidth | Round-trip latency | Single-batch inference, seq. len. 128 (steps/s) | Single-batch inference, seq. len. 2048 (steps/s) | Parallel forward, batch 1 (tokens/s) | Parallel forward, batch 64 (tokens/s) |
+|---|---|---|---|---|---|
+| **Offloading, max. possible speed on 1x A100** ¹ | | | | | |
+| 256 Gbit/s | | 0.18 | 0.18 | 2.7 | 170.3 |
+| 128 Gbit/s | | 0.09 | 0.09 | 2.4 | 152.8 |
+| **Petals on 14 heterogeneous servers across Europe and North America** ² | | | | | |
+| Real world | | 0.83 | 0.79 | 32.6 | 179.4 |
+| **Petals on 3 servers, with one A100 each** ³ | | | | | |
+| 1 Gbit/s | < 5 ms | 1.71 | 1.54 | 70.0 | 253.6 |
+| 100 Mbit/s | < 5 ms | 1.66 | 1.49 | 56.4 | 182.0 |
+| 100 Mbit/s | 100 ms | 1.23 | 1.11 | 19.7 | 112.2 |
diff --git a/src/petals/client/__init__.py b/src/petals/client/__init__.py
index b728962..5ff26bc 100644
--- a/src/petals/client/__init__.py
+++ b/src/petals/client/__init__.py
@@ -5,6 +5,6 @@ from petals.client.remote_model import (
     DistributedBloomForSequenceClassification,
     DistributedBloomModel,
 )
-from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
+from petals.client.remote_sequential import RemoteSequential
 from petals.client.routing.sequence_manager import RemoteSequenceManager
 from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py
index 24a188a..93700f9 100644
--- a/src/petals/client/inference_session.py
+++ b/src/petals/client/inference_session.py
@@ -8,7 +8,6 @@ from typing import AsyncIterator, List, Optional
 import torch
 from hivemind import (
-    P2P,
     MSGPackSerializer,
     anext,
     deserialize_torch_tensor,
@@ -162,9 +161,8 @@ class InferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """

-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int):
+    def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
         self._sequence_manager = sequence_manager
-        self._p2p = p2p
         self._closed = False
         self._chosen_spans = []
         self._server_sessions = []
@@ -181,7 +179,7 @@ class InferenceSession:
         server_sessions = []
         try:
             for span in chosen_spans:
-                stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
+                stub = TransformerConnectionHandler.get_stub(self._sequence_manager.state.p2p, span.peer_id)
                 span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
                 metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
                 session = RemoteExpertWorker.run_coroutine(
@@ -189,7 +187,7 @@ class InferenceSession:
                     stub,
                     span_uids,
                     rpc_info=self._sequence_manager.rpc_info,
-                    timeout=self._sequence_manager.request_timeout,
+                    timeout=self._sequence_manager.config.request_timeout,
                     max_length=self._max_length,
                     **metadata,
                 )
@@ -305,9 +303,8 @@ class InferenceSession:
                 self._sequence_manager.on_request_success(span.peer_id)
                 break
             except Exception as e:
-                if span is not None:
-                    self._sequence_manager.on_request_failure(span.peer_id)
-                if attempt_no + 1 == self._sequence_manager.max_retries:
+                self._sequence_manager.on_request_failure(span.peer_id if span is not None else None)
+                if attempt_no + 1 == self._sequence_manager.config.max_retries:
                     raise
                 delay = self._sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(
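With this change, `InferenceSession` no longer takes an explicit `p2p` handle; it reaches the P2P replica through `sequence_manager.state`. A sketch of the resulting client-side call pattern, using the test model and the local-swarm peer address that appear elsewhere in this patch series (both are placeholders for your own swarm):

```python
import torch
from petals.client import DistributedBloomConfig, RemoteSequential

MODEL_NAME = "bloom-testing/test-bloomd-560m-main"  # test model used in this patch's test suite
INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g"]

config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
blocks = RemoteSequential(config)

# No p2p argument anymore: the session pulls it from sequence_manager.state.p2p.
with blocks.inference_session(max_length=16) as sess:
    for _ in range(3):
        hidden = sess.step(torch.randn(1, 1, config.hidden_size))
```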
diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py
index d67d4bf..0d218d1 100644
--- a/src/petals/client/remote_model.py
+++ b/src/petals/client/remote_model.py
@@ -18,13 +18,14 @@ from transformers.models.bloom import (
 from petals.bloom.modeling_utils import LMHead
 from petals.client.remote_generation import RemoteGenerationMixin
 from petals.client.remote_sequential import RemoteSequential
+from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.utils.misc import DUMMY

 logger = get_logger(__name__)


-class DistributedBloomConfig(BloomConfig):
+class DistributedBloomConfig(BloomConfig, SequenceManagerConfig):
     """
     A bloom config that contains information about DHT peers.
     To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
@@ -33,15 +34,9 @@ class DistributedBloomConfig(BloomConfig):
     initial_peers: List[str] = PUBLIC_INITIAL_PEERS  # a list of initial peers for hivemind DHT
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
-    dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
-    request_timeout: int = 3 * 60  # a number of seconds for waiting result from each node
-    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
-    allowed_servers: Optional[
-        Collection[Union[str, hivemind.PeerID]]
-    ] = None  # if defined, send requests only to these servers

     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
-    tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
+    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]

     # This settings matter for running the client with dtype bfloat16 on CPU.
     # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
@@ -106,30 +101,16 @@ class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel):

     config_class = DistributedBloomConfig

-    def __init__(self, config: DistributedBloomConfig):
+    def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
         assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
-        assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
+        assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"

         n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
         super().__init__(config)
         assert len(self.h) == 0
         config.n_layer = n_layer

-        dht = config.dht
-        if dht is None:
-            dht = hivemind.DHT(
-                initial_peers=config.initial_peers,
-                client_mode=True,
-                num_workers=n_layer,
-                startup_timeout=config.daemon_startup_timeout,
-                start=True,
-            )
-        assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
-        self.h = RemoteSequential(
-            config,
-            dht,
-            config.dht_prefix,
-        )
+        self.h = RemoteSequential(config, dht=dht)

         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)
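Since `DistributedBloomConfig` now mixes in `SequenceManagerConfig` (defined later in this patch), the routing knobs that used to be ad-hoc attributes (`request_timeout`, `max_retries`, `allowed_servers`) become regular config fields. A sketch of tuning them, reusing the `MODEL_NAME`/`INITIAL_PEERS` placeholders from the sketch above; setting the fields as plain attributes after `from_pretrained()` is an assumption based on the dataclass definition:

```python
from petals.client import DistributedBloomConfig

config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)

# Assumed usage: SequenceManagerConfig fields are ordinary attributes on the config.
config.request_timeout = 60    # seconds to wait for each forward/backward/inference RPC
config.max_retries = 3         # give up after this many attempts (default None = retry forever)
config.allowed_servers = None  # or a collection of peer IDs to whitelist
```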
diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py
index 788805d..8bc60ff 100644
--- a/src/petals/client/remote_sequential.py
+++ b/src/petals/client/remote_sequential.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 from typing import Optional, Union

 import torch
-from hivemind import DHT, P2P, get_logger
+from hivemind import DHT, get_logger
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from torch import nn

@@ -25,39 +25,26 @@ class RemoteSequential(nn.Module):
     def __init__(
         self,
         config: petals.client.DistributedBloomConfig,
-        dht: DHT,
-        dht_prefix: Optional[str] = None,
-        p2p: Optional[P2P] = None,
+        *,
         sequence_manager: Optional[RemoteSequenceManager] = None,
-        **kwargs,
+        dht: Optional[DHT] = None,
+        start_block: Optional[int] = None,
+        end_block: Optional[int] = None,
     ):
         super().__init__()
         self.config = config
-        self.dht = dht
-        self.dht_prefix = dht_prefix or config.dht_prefix
-        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p

-        num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
-        block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks))
+        assert sequence_manager is None or (
+            dht is None and start_block is None and end_block is None
+        ), "`dht`, `start_block`, and `end_block` have no effect when you provide a custom `sequence_manager`"
         if sequence_manager is None:
-            logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
-            self.sequence_manager = RemoteSequenceManager(
-                dht,
-                block_uids,
-                self.p2p,
-                request_timeout=config.request_timeout,
-                max_retries=config.max_retries,
-                allowed_servers=config.allowed_servers,
-                **kwargs,
-            )
-            self.is_subsequence = False
-        else:
-            logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
-            if kwargs:
-                logger.warning(f"Parameters {kwargs} are ignored because sequence_manager is explicitly provided")
-            self.sequence_manager = sequence_manager
-            assert isinstance(sequence_manager.sequence_info.block_uids, tuple)
-            self.is_subsequence = self.sequence_manager.sequence_info.block_uids != block_uids
+            if start_block is None:
+                start_block = 0
+            if end_block is None:
+                end_block = self.config.n_layer
+            block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block))
+            sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht)
+        self.sequence_manager = sequence_manager

     def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
         assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
@@ -66,23 +53,10 @@ class RemoteSequential(nn.Module):
         return outputs

     def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
-        assert isinstance(ix, (int, slice))
-        if isinstance(ix, int):
-            return RemoteTransformerBlock(
-                self.config,
-                self.dht,
-                dht_prefix=self.dht_prefix,
-                p2p=self.p2p,
-                sequence_manager=self.sequence_manager[ix],
-            )
-        else:
-            return RemoteSequential(
-                self.config,
-                self.dht,
-                dht_prefix=self.dht_prefix,
-                p2p=self.p2p,
-                sequence_manager=self.sequence_manager[ix],
-            )
+        return RemoteSequential(
+            self.config,
+            sequence_manager=self.sequence_manager[ix],
+        )

     def __iter__(self):
         for block_index in range(len(self)):
@@ -92,22 +66,7 @@ class RemoteSequential(nn.Module):
         return len(self.sequence_manager)

     def inference_session(self, **kwargs) -> InferenceSession:
-        return InferenceSession(self.sequence_manager, self.p2p, **kwargs)
+        return InferenceSession(self.sequence_manager, **kwargs)

     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
-
-
-class RemoteTransformerBlock(RemoteSequential):
-    """Single transformer block hosted by swarm
-
-    This class is deprecated and kept for backward compatibility.
-    It will be removed soon in favor of using ``RemoteSequential`` directly.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        assert len(self) == 1, "Remote Block is a sequence size 1"
-
-    def extra_repr(self):
-        return f"{self.sequence_manager.block_uids[0]}"
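The constructor rewrite above makes `RemoteSequential` the single entry point for a whole model, a contiguous slice, or one block; indexing now returns another `RemoteSequential` instead of the removed `RemoteTransformerBlock`. A sketch, reusing the placeholders from the sketches above (the `start_block`/`end_block` calls match the updated tests later in this patch):

```python
from petals.client import DistributedBloomConfig, RemoteSequential

config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)

whole_model = RemoteSequential(config)                         # blocks 0 .. n_layer
middle = RemoteSequential(config, start_block=3, end_block=6)  # blocks 3, 4, 5
one_block = whole_model[4]                        # a RemoteSequential of length 1
first_half = whole_model[: config.n_layer // 2]   # slicing also returns RemoteSequential
```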
diff --git a/src/petals/client/routing/sequence_info.py b/src/petals/client/routing/sequence_info.py
index de7eb37..8dafb6e 100644
--- a/src/petals/client/routing/sequence_info.py
+++ b/src/petals/client/routing/sequence_info.py
@@ -27,14 +27,14 @@ class RemoteSequenceInfo:
     block_infos: Tuple[RemoteModuleInfo, ...]  # note: the contents of RemoteModuleInfo can and will be updated
     spans_by_priority: List[RemoteSpanInfo]
     spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
-    last_updated_time: float
+    last_updated_time: Optional[float]

     @classmethod
     def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
         block_uids = tuple(block_uids)
         empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
         empty_spans = tuple([] for _ in range(len(block_uids)))
-        return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=-float("inf"))
+        return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=None)

     def __getitem__(self, ix: slice):
         assert isinstance(ix, slice)
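A small behavioral note on the change above: a freshly built `RemoteSequenceInfo` now reports `last_updated_time=None` instead of `-inf`, which the new `RemoteSequenceManager.__init__` (next diff) uses to decide whether to pre-fetch module infos. A sketch with illustrative block uids:

```python
from petals.client.routing.sequence_info import RemoteSequenceInfo

# The uids below are illustrative; real uids are f"{config.dht_prefix}{UID_DELIMITER}{i}".
info = RemoteSequenceInfo.make_empty([f"demo-model.{i}" for i in range(3)])
assert info.last_updated_time is None  # "never updated" sentinel instead of -float("inf")
```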
-    last_updated_time: float
+    last_updated_time: Optional[float]

     @classmethod
     def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
         block_uids = tuple(block_uids)
         empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
         empty_spans = tuple([] for _ in range(len(block_uids)))
-        return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=-float("inf"))
+        return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=None)

     def __getitem__(self, ix: slice):
         assert isinstance(ix, slice)
diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py
index 25c68ef..5f387c4 100644
--- a/src/petals/client/routing/sequence_manager.py
+++ b/src/petals/client/routing/sequence_manager.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import asyncio
+import dataclasses
 import itertools
 import logging
 import random
@@ -13,7 +14,6 @@
 import numpy as np
 from hivemind import DHT, P2P, MSGPackSerializer, PeerID, get_dht_time
 from hivemind.dht.node import Blacklist
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2PHandlerError
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger

@@ -26,6 +26,33 @@ from petals.server.handler import TransformerConnectionHandler

 logger = get_logger(__name__)


+@dataclasses.dataclass
+class SequenceManagerConfig:
+    allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
+
+    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
+    update_period: float = 60  # refresh DHT information once in this many seconds
+
+    max_retries: Optional[int] = None  # max number of retries before the client raises an exception (default: inf)
+    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
+    max_backoff: float = 60  # limit maximal sleep time between retries to this value
+    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
+
+
+@dataclasses.dataclass
+class SequenceManagerState:
+    p2p: P2P = None
+    sequence_info: Optional[RemoteSequenceInfo] = None
+    rpc_info: Optional[dict] = None
+    banned_peers: Optional[Blacklist] = None
+
+    def __getitem__(self, ix: Union[int, slice]) -> SequenceManagerState:
+        return dataclasses.replace(self, sequence_info=self.sequence_info[ix])
+
+    def __len__(self) -> int:
+        return len(self.sequence_info)
+
+
 class RemoteSequenceManager:
     """
     Sequence manager is a thread that keeps track of remote servers that hold the specified sequence of blocks.
@@ -34,67 +61,56 @@ class RemoteSequenceManager:
     Using this information, sequence manager can form sequences of servers that collectively have the full sequence.
     To form such a sequence, call .make_sequence with the appropriate optimization policy (see make_sequence docstr).

-    :param dht: a running hivemind.DHT instance, connected to peers that serve the corresponding blocks
-    :param block_uids: a sequence of DHT keys (strings) corresponding to remote layers
-    :param p2p: an optional P2P replica (if not specified, create one via dht.replicate_p2p())
-    :param update_period: by default, refresh DHT information once in this many seconds
-    :param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests
-    :param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1)
-    :param max_backoff: limit maximal sleep time between retries to this value
-    :param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds
-    :param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht
-    :param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time
-    :param allowed_servers: if defined, send requests only to these servers
-    :param start: start the background thread (see the note below). If false, you will need to start it manually.

     :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to
       avoid running redundant sequence managers for the same set of layers.
-
     """

     def __init__(
         self,
-        dht: DHT,
+        config: SequenceManagerConfig,
         block_uids: Sequence[ModuleUID],
-        p2p: P2P,
-        update_period: float = 30,
-        request_timeout: float = 30,
-        max_retries: Optional[int] = None,
-        min_backoff: float = 1,
-        max_backoff: float = 15 * 60,
-        ban_timeout: float = 15,
-        sequence_info: Optional[RemoteSequenceInfo] = None,
-        rpc_info: Optional[dict] = None,
-        allowed_servers: Optional[Collection[Union[str, hivemind.PeerID]]] = None,
-        banned_peers: Optional[Blacklist] = None,
-        # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
+        *,
+        dht: Optional[DHT] = None,
+        state: Optional[SequenceManagerState] = None,
     ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"
-        self.dht, self.p2p = dht, p2p
-        self.request_timeout, self.max_retries = request_timeout, max_retries
-        self.ban_timeout, self.min_backoff, self.max_backoff = ban_timeout, min_backoff, max_backoff
+
+        self.config = config
+        if state is None:
+            state = SequenceManagerState()
+        self.state = state
+
+        if dht is None:
+            dht = DHT(
+                initial_peers=config.initial_peers,
+                client_mode=True,
+                num_workers=config.n_layer,
+                startup_timeout=config.daemon_startup_timeout,
+                start=True,
+            )
+        assert isinstance(dht, DHT) and dht.is_alive(), "`dht` must be a running hivemind.DHT instance"
+        self.dht = dht
+
+        if state.p2p is None:
+            state.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+
         self.lock_changes = threading.Lock()
-        self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
+        self._thread = _SequenceManagerUpdateThread(config.update_period, WeakMethod(self._update))
         self._thread_start_lock = threading.Lock()
         self.policy = NoSpendingPolicy()
-        self._rpc_info = rpc_info

-        if allowed_servers is not None:
-            allowed_servers = {
-                PeerID.from_base58(peer_id) if isinstance(peer_id, str) else peer_id for peer_id in allowed_servers
-            }
-        self.allowed_servers = allowed_servers
-        self.banned_peers = Blacklist(base_time=ban_timeout, backoff_rate=2.0) if banned_peers is None else banned_peers

-        if sequence_info is None:
-            self.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
+        if state.banned_peers is None:
+            state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0)
+        if state.sequence_info is None:
+            state.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
+
+        if state.sequence_info.last_updated_time is None:
             # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records
             # in the first _update() instead of the latest ones. This makes the first .update() faster.
             petals.dht_utils.get_remote_module_infos(self.dht, self.block_uids, latest=True, return_future=True)
             self._need_latest_infos = False
         else:
-            self.sequence_info = sequence_info
-            assert block_uids == sequence_info.block_uids
+            assert block_uids == state.sequence_info.block_uids
             self._thread.ready.set()  # no need to await the first dht fetch
             self._need_latest_infos = True
@@ -118,7 +134,7 @@ class RemoteSequenceManager:
         span_sequence = []
         current_index = start_index
         while current_index < end_index:
-            candidate_spans = self.sequence_info.spans_containing_block[current_index]
+            candidate_spans = self.state.sequence_info.spans_containing_block[current_index]
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
             if mode == "random":
@@ -143,86 +159,62 @@ class RemoteSequenceManager:
         assert isinstance(ix, (int, slice))
         if not isinstance(ix, slice):
             ix = slice(int(ix), int(ix) + 1, 1)
-        return type(self)(
-            self.dht,
-            self.block_uids[ix],
-            self.p2p,
-            update_period=self._thread.update_period,
-            request_timeout=self.request_timeout,
-            ban_timeout=self.ban_timeout,
-            min_backoff=self.min_backoff,
-            max_backoff=self.max_backoff,
-            sequence_info=self.sequence_info[ix],
-            rpc_info=self._rpc_info,
-            allowed_servers=self.allowed_servers,
-            banned_peers=self.banned_peers,
-        )
+        return type(self)(self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix])

     def update(self, *, wait: bool):
         """Run an asynchronous update in background as soon as possible"""
-        self.ready.clear()  # TODO this should be a separate event
+        self.ready.clear()
         self._thread.trigger.set()
         if wait:
             self.ready.wait()

     def _update(self):
         """Perform an immediate and synchronous refresh, may take time"""
-        for attempt_no in itertools.count():
-            try:
-                new_block_infos = petals.dht_utils.get_remote_module_infos(
-                    self.dht, self.block_uids, latest=self._need_latest_infos
-                )
-                self._need_latest_infos = True  # All future _update() should use latest infos
-
-                for block_info in new_block_infos:
-                    if not block_info:
-                        continue
-
-                    # Apply whitelist, if defined
-                    if self.allowed_servers is not None:
-                        block_info.servers = {
-                            peer_id: server_info
-                            for peer_id, server_info in block_info.servers.items()
-                            if peer_id in self.allowed_servers
-                        }
-
-                    # Remove temporarily banned peers, unless there are no peers left
-                    valid_servers = {
-                        peer_id: server_info
-                        for peer_id, server_info in block_info.servers.items()
-                        if peer_id not in self.banned_peers
-                    }
-                    if len(valid_servers) < len(block_info.servers):
-                        if valid_servers:
-                            logger.debug(
-                                f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}"
-                            )
-                            block_info.servers = valid_servers
-                        else:
-                            # If we blacklisted all servers, the error may actually be client-caused
-                            logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist")
-
-                with self.lock_changes:
-                    self.sequence_info.update_(new_block_infos)
-                missing_blocks = [i for i in range(len(self)) if not self.sequence_info.spans_containing_block[i]]
-                if missing_blocks:
-                    raise MissingBlocksError(missing_blocks)
-                self.ready.set()  # if there is an active server for every block, we may begin running
-                break
+        new_block_infos = petals.dht_utils.get_remote_module_infos(
+            self.dht, self.block_uids, latest=self._need_latest_infos
+        )
+        self._need_latest_infos = True  # All future _update() should use latest infos
+
+        for block_info in new_block_infos:
+            if not block_info:
+                continue
+
+            # Apply whitelist, if defined
+            if self.config.allowed_servers is not None:
+                block_info.servers = {
+                    peer_id: server_info
+                    for peer_id, server_info in block_info.servers.items()
+                    if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers
+                }
+
+            # Remove temporarily banned peers, unless there are no peers left
+            valid_servers = {
+                peer_id: server_info
+                for peer_id, server_info in block_info.servers.items()
+                if peer_id not in self.state.banned_peers
+            }
+            if len(valid_servers) < len(block_info.servers):
+                if valid_servers:
+                    logger.debug(
+                        f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}"
+                    )
+                    block_info.servers = valid_servers
+                else:
+                    # If we blacklisted all servers, the error may actually be client-caused
+                    logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist")

-            except Exception as e:
-                delay = self.get_retry_delay(attempt_no)
-                logger.warning(f"Could not find route through the model: {repr(e)} (retry in {delay:.0f} sec)")
-                maybe_log_traceback(e)
-                time.sleep(delay)
+        with self.lock_changes:
+            self.state.sequence_info.update_(new_block_infos)
+        self.ready.set()

-    def on_request_failure(self, peer_id: PeerID):
+    def on_request_failure(self, peer_id: Optional[PeerID]):
         """remove a given peer from the routing table. If the routing is no longer possible, trigger an update"""
-        logger.info(f"Peer {peer_id} did not respond, banning it temporarily")
-        self.banned_peers.register_failure(peer_id)
+        if peer_id is not None:
+            logger.debug(f"Peer {peer_id} did not respond, banning it temporarily")
+            self.state.banned_peers.register_failure(peer_id)
         with self.lock_changes:
             should_update = False
-            for info in self.sequence_info.block_infos:
+            for info in self.state.sequence_info.block_infos:
                 info.servers.pop(peer_id, None)
                 if not info.servers:
                     should_update = True
@@ -232,7 +224,7 @@ class RemoteSequenceManager:

     def on_request_success(self, peer_id: PeerID):
         """if peer has a failure streak, clear that streak"""
-        self.banned_peers.register_success(peer_id)
+        self.state.banned_peers.register_success(peer_id)

     def __len__(self):
         return len(self.block_uids)
@@ -247,57 +239,58 @@ class RemoteSequenceManager:

     @property
     def block_uids(self):
-        return self.sequence_info.block_uids
+        return self.state.sequence_info.block_uids

     @property
     def rpc_info(self):
         """Return the rpc_info queried from one of the servers that hold the first block"""
-        if self._rpc_info is None:
-            with self._thread_start_lock:
-                if not self.is_alive():
-                    self._thread.start()
-
-            for attempt_no in itertools.count():
-                peer_id = None
-                try:
-                    if not self.ready.is_set():
-                        self.update(wait=True)
-
-                    active_servers = [
-                        peer_id
-                        for peer_id, server in self.sequence_info.block_infos[0].servers.items()
-                        if server.state == ServerState.ONLINE
-                    ]
-                    if not active_servers:
-                        raise MissingBlocksError(0)
-                    peer_id = random.choice(active_servers)
-
-                    stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
-                    outputs = RemoteExpertWorker.run_coroutine(
-                        stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
-                    )
-                    self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
-                    self.on_request_success(peer_id)
-                    break
-                except Exception as e:
-                    if peer_id is not None and not isinstance(e, P2PHandlerError):
-                        self.on_request_failure(peer_id)
-                    if attempt_no + 1 == self.max_retries:
-                        raise
-                    delay = self.get_retry_delay(attempt_no)
-                    logger.warning(
-                        f"Caught exception when gathering information from peer {peer_id} "
-                        f"(retry in {delay:.0f} sec): {repr(e)}"
-                    )
-                    maybe_log_traceback(e)
-                    time.sleep(delay)
+        if self.state.rpc_info is not None:
+            return self.state.rpc_info
+
+        with self._thread_start_lock:
+            if not self.is_alive():
+                self._thread.start()
+
+        for attempt_no in itertools.count():
+            peer_id = None
+            try:
+                if not self.ready.is_set():
+                    self.update(wait=True)
+
+                active_servers = [
+                    peer_id
+                    for peer_id, server in self.state.sequence_info.block_infos[0].servers.items()
+                    if server.state == ServerState.ONLINE
+                ]
+                if not active_servers:
+                    raise MissingBlocksError(0)
+                peer_id = random.choice(active_servers)
+
+                stub = TransformerConnectionHandler.get_stub(self.state.p2p, peer_id)
+                outputs = RemoteExpertWorker.run_coroutine(
+                    stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]), timeout=self.config.request_timeout)
+                )
+                self.state.rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+                self.on_request_success(peer_id)
+                break
+            except Exception as e:
+                self.on_request_failure(peer_id)
+                if attempt_no + 1 == self.config.max_retries:
+                    raise
+                delay = self.get_retry_delay(attempt_no)
+                logger.warning(
+                    f"Caught exception when gathering information from peer {peer_id} "
+                    f"(retry in {delay:.0f} sec): {repr(e)}"
+                )
+                maybe_log_traceback(e)
+                time.sleep(delay)

-        return self._rpc_info
+        return self.state.rpc_info

     def get_retry_delay(self, attempt_no: int) -> float:
         if attempt_no == 0:
             return 0
-        return min(self.min_backoff * 2 ** (attempt_no - 1), self.max_backoff)
+        return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)

     def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
         """
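To summarize the refactor above: a `RemoteSequenceManager` is now built from a `SequenceManagerConfig` plus optional `dht` and `state`, and all shared mutable pieces (the P2P handle, block infos, the `rpc_info` cache, and the ban list) live in `SequenceManagerState`, so slices share them with their parent manager. A sketch, reusing the `MODEL_NAME`/`INITIAL_PEERS` placeholders from the sketches above:

```python
from petals.client import DistributedBloomConfig, RemoteSequenceManager
from petals.data_structures import UID_DELIMITER

config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))

manager = RemoteSequenceManager(config, block_uids)  # builds its own client-mode DHT when dht=None

sub_manager = manager[3:6]  # shares SequenceManagerState with the parent
assert sub_manager.state.banned_peers is manager.state.banned_peers

spans = manager.make_sequence(mode="random")  # pick a chain of servers covering all blocks
```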
diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py
index b846dfc..166b93c 100644
--- a/src/petals/client/sequential_autograd.py
+++ b/src/petals/client/sequential_autograd.py
@@ -67,7 +67,7 @@ async def sequential_forward(

                 span = sequences.popleft()

-                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+                stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]

                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
@@ -77,7 +77,7 @@ async def sequential_forward(
                     stub,
                     sequence_manager.rpc_info,
                     *inputs_and_prompts,
-                    timeout=sequence_manager.request_timeout,
+                    timeout=sequence_manager.config.request_timeout,
                     metadata=MSGPackSerializer.dumps(metadata),
                 )

@@ -93,9 +93,8 @@ async def sequential_forward(
                 sequence_manager.on_request_success(span.peer_id)
                 break
             except Exception as e:
-                if span is not None:
-                    sequence_manager.on_request_failure(span.peer_id)
-                if attempt_no + 1 == sequence_manager.max_retries:
+                sequence_manager.on_request_failure(span.peer_id if span is not None else None)
+                if attempt_no + 1 == sequence_manager.config.max_retries:
                     raise
                 delay = sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(
@@ -152,7 +151,7 @@ async def sequential_backward(
                 span = forward_sequences.pop()

                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
-                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+                stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
                 metadata = sequence_manager.get_request_metadata(
                     "rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
                 )
@@ -163,7 +162,7 @@ async def sequential_backward(
                     inputs,
                     grad_outputs,
                     prompts[span.start : span.end],
-                    timeout=sequence_manager.request_timeout,
+                    timeout=sequence_manager.config.request_timeout,
                     metadata=MSGPackSerializer.dumps(metadata),
                 )
                 grad_outputs = [grad_outputs]
@@ -171,9 +170,8 @@ async def sequential_backward(
                 sequence_manager.on_request_success(span.peer_id)
                 break
             except Exception as e:
-                if span is not None:
-                    sequence_manager.on_request_failure(span.peer_id)
-                if attempt_no + 1 == sequence_manager.max_retries:
+                sequence_manager.on_request_failure(span.peer_id if span is not None else None)
+                if attempt_no + 1 == sequence_manager.config.max_retries:
                     raise
                 delay = sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(
diff --git a/src/petals/dht_utils.py b/src/petals/dht_utils.py
index 06c30eb..69cd64f 100644
--- a/src/petals/dht_utils.py
+++ b/src/petals/dht_utils.py
@@ -71,67 +71,6 @@ async def _declare_active_modules(
     )


-def get_remote_sequence(
-    dht: DHT,
-    start: int,
-    stop: int,
-    config: petals.client.DistributedBloomConfig,
-    dht_prefix: Optional[str] = None,
-    return_future: bool = False,
-) -> Union[petals.client.RemoteSequential, MPFuture]:
-    return RemoteExpertWorker.run_coroutine(
-        _get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
-    )
-
-
-async def _get_remote_sequence(
-    dht: DHT,
-    start: int,
-    stop: int,
-    config: petals.client.DistributedBloomConfig,
-    dht_prefix: Optional[str] = None,
-) -> petals.client.RemoteSequential:
-    uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
-    p2p = await dht.replicate_p2p()
-    manager = petals.client.RemoteSequenceManager(dht, uids, p2p)
-    return petals.client.RemoteSequential(config, dht, dht_prefix, p2p, manager)
-
-
-def get_remote_module(
-    dht: DHT,
-    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    config: petals.client.DistributedBloomConfig,
-    dht_prefix: Optional[str] = None,
-    return_future: bool = False,
-) -> Union[Union[petals.client.RemoteTransformerBlock, List[petals.client.RemoteTransformerBlock]], MPFuture]:
-    """
-    :param uid_or_uids: find one or more modules with these ids from across the DHT
-    :param config: model config, usually taken by .from_pretrained(MODEL_NAME)
-    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-    :returns: a list of [RemoteTransformerBlock]
-    """
-    return RemoteExpertWorker.run_coroutine(
-        _get_remote_module(dht, uid_or_uids, config, dht_prefix), return_future=return_future
-    )
-
-
-async def _get_remote_module(
-    dht: DHT,
-    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    config: petals.client.DistributedBloomConfig,
-    dht_prefix: Optional[str] = None,
-) -> Union[petals.client.RemoteTransformerBlock, List[petals.client.RemoteTransformerBlock]]:
-    single_uid = isinstance(uid_or_uids, ModuleUID)
-    uids = [uid_or_uids] if single_uid else uid_or_uids
-    p2p = await dht.replicate_p2p()
-    managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
-    modules = [
-        petals.client.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m)
-        for m in managers
-    ]
-    return modules[0] if single_uid else modules
-
-
 def get_remote_module_infos(
     dht: DHT,
     uids: Sequence[ModuleUID],
diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py
index d2fbdde..4cddfed 100644
--- a/tests/test_block_exact_match.py
+++ b/tests/test_block_exact_match.py
@@ -1,28 +1,24 @@
 import random
 from typing import Union

-import hivemind
 import pytest
 import torch
 from transformers.models.bloom.configuration_bloom import BloomConfig

 from petals.bloom.block import WrappedBloomBlock
 from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
-from petals.client import DistributedBloomConfig
-from petals.client.remote_sequential import RemoteTransformerBlock
+from petals.client import DistributedBloomConfig, RemoteSequential
 from petals.data_structures import UID_DELIMITER
-from petals.dht_utils import get_remote_module
 from test_utils import *


 @pytest.mark.forked
 def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_sequential = RemoteSequential(config)

     for block_index in random.sample(range(config.n_layer), 3):
-        remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config)
-        assert isinstance(remote_block, RemoteTransformerBlock)
+        remote_block = remote_sequential[block_index]

         inputs = torch.randn(1, 8, config.hidden_size)
         outputs_forward = remote_block(inputs)
@@ -36,7 +32,6 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
             with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
                 sess.step(inputs[:, -1:, :])
             assert "Maximum length exceeded" in repr(exc_info.value)
-
         outputs_inference = torch.cat(outputs_inference, dim=1)

         ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py
index 9a619b7..15f3b5c 100644
--- a/tests/test_chained_calls.py
+++ b/tests/test_chained_calls.py
@@ -4,22 +4,19 @@
 # - if you want to figure out chained inference, ask yozh


-import hivemind
 import pytest
 import torch

 from petals.bloom.from_pretrained import load_pretrained_block
 from petals.client import DistributedBloomConfig
 from petals.client.remote_sequential import RemoteSequential
-from petals.dht_utils import get_remote_sequence
 from test_utils import *


 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    remote_blocks = get_remote_sequence(dht, 3, 6, config)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_blocks = RemoteSequential(config, start_block=3, end_block=6)
     assert isinstance(remote_blocks, RemoteSequential)

     ref_blocks = [
@@ -46,10 +43,8 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq

 @pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    remote_blocks = get_remote_sequence(dht, 3, 5, config)
-    assert isinstance(remote_blocks, RemoteSequential)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_blocks = RemoteSequential(config, start_block=3, end_block=5)

     inputs = torch.randn(1, 8, config.hidden_size)
diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py
index a69c42f..d46ca1c 100644
--- a/tests/test_remote_sequential.py
+++ b/tests/test_remote_sequential.py
@@ -20,7 +20,7 @@ def test_remote_sequential():
     test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
     grad_proj = torch.randn(1, 5, config.hidden_size)

-    sequential = RemoteSequential(config, dht)
+    sequential = RemoteSequential(config, dht=dht)

     full_outputs = sequential(test_inputs)
     (full_outputs * grad_proj).sum().backward()
@@ -48,7 +48,7 @@ def test_remote_sequential():
     # test RemoteSequential with lossy compression
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
     lossy_sequential = RemoteSequential(
-        config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p)
+        config, sequence_manager=DummyCustomSequenceManager(config, block_uids, dht=dht)
     )

     test_inputs.grad = None
@@ -85,8 +85,7 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
 @pytest.mark.forked
 def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
-    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
-    remote_sequential = RemoteSequential(config, dht)
+    remote_sequential = RemoteSequential(config)

     inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
     output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1)
diff --git a/tests/test_sequence_manager.py b/tests/test_sequence_manager.py
index f0b61cf..7dbc82f 100644
--- a/tests/test_sequence_manager.py
+++ b/tests/test_sequence_manager.py
@@ -18,15 +18,14 @@ logger = get_logger(__name__)
 def test_sequence_manager_basics(mode: str):
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
-    sequential = RemoteSequential(config, dht)
+    sequential = RemoteSequential(config, dht=dht)
     shutdown_evt = threading.Event()

     # test RemoteSequential with lossy compression
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
     sequential = RemoteSequential(
         config,
-        dht,
-        sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt),
+        sequence_manager=TestSequenceManager(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),
     )

     sequence = sequential.sequence_manager.make_sequence(mode=mode)
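The test above passes a `TestSequenceManager` subclass straight into `RemoteSequential`, which is now the supported way to customize routing. A hypothetical subclass sketch (`LoggingSequenceManager` is illustrative, not part of the patch; `config`, `block_uids`, and `dht` are assumed to be defined as in the sketches above):

```python
from petals.client import RemoteSequenceManager, RemoteSequential

class LoggingSequenceManager(RemoteSequenceManager):
    """Hypothetical subclass: trace the protocol of every outgoing request."""

    def get_request_metadata(self, protocol, *args, **kwargs):
        print(f"sending {protocol} to {kwargs.get('peer_id')}")
        return super().get_request_metadata(protocol, *args, **kwargs)

manager = LoggingSequenceManager(config, block_uids, dht=dht)
sequential = RemoteSequential(config, sequence_manager=manager)
```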
diff --git a/tests/test_server_stats.py b/tests/test_server_stats.py
index 54d6d33..0010167 100644
--- a/tests/test_server_stats.py
+++ b/tests/test_server_stats.py
@@ -4,34 +4,33 @@
 import hivemind
 import pytest
 import torch

-from petals.client import DistributedBloomConfig
+from petals.client import DistributedBloomConfig, RemoteSequential
 from petals.data_structures import UID_DELIMITER
-from petals.dht_utils import get_remote_sequence
 from petals.server.handler import CACHE_TOKENS_AVAILABLE
 from test_utils import *


 @pytest.mark.forked
 def test_server_info(block_from: int = 22, block_to: int = 24, max_length: int = 100, max_length2: int = 50):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to)
+    blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to)

-    blocks1 = get_remote_sequence(dht, block_from, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}")
-    blocks2 = get_remote_sequence(dht, block_to - 1, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}")
     info_before = blocks1.sequence_manager.rpc_info

     with blocks1.inference_session(max_length=max_length) as sess:
         sess.step(torch.randn(1, 1, config.hidden_size))
-        blocks1.sequence_manager._rpc_info = None  # invalidate cache
+        blocks1.sequence_manager.state.rpc_info = None  # invalidate cache
         info_inside = blocks1.sequence_manager.rpc_info

         with blocks2.inference_session(max_length=max_length2) as sess2:
             sess2.step(torch.randn(1, 1, config.hidden_size))
-            blocks2.sequence_manager._rpc_info = None  # invalidate cache
+            blocks2.sequence_manager.state.rpc_info = None  # invalidate cache
             info_inside2 = blocks2.sequence_manager.rpc_info

     time.sleep(0.1)
-    blocks1.sequence_manager._rpc_info = None  # invalidate cache
+    blocks1.sequence_manager.state.rpc_info = None  # invalidate cache
     info_after = blocks1.sequence_manager.rpc_info

     assert info_before[CACHE_TOKENS_AVAILABLE] == info_after[CACHE_TOKENS_AVAILABLE]

From 0a313bf6c5d82b57103b973f6c851e5186f91cb1 Mon Sep 17 00:00:00 2001
From: Alexander Borzunov