diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index af6299b..eb9c988 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -21,11 +21,11 @@ jobs: uses: actions/cache@v2 with: path: ~/.cache/pip - key: Key-v1-py3.9-${{ hashFiles('setup.cfg') }} + key: Key-v1-3.9-${{ hashFiles('setup.cfg') }} - name: Install dependencies run: | python -m pip install --upgrade pip - pip install . + pip install .[dev] - name: Delete any test models older than 1 week run: | python tests/scripts/remove_old_models.py --author bloom-testing --use_auth_token $BLOOM_TESTING_WRITE_TOKEN diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 2b282bc..441b9d4 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -114,6 +114,8 @@ class RemoteSequenceManager: current_index = start_index while current_index < end_index: candidate_spans = self.sequence_info.spans_containing_block[current_index] + if not candidate_spans: + raise MissingBlocksError(current_index) if mode == "random": chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing elif mode == "fastest": @@ -186,7 +188,7 @@ class RemoteSequenceManager: self.sequence_info.update_(new_block_infos) missing_blocks = [i for i in range(len(self)) if not self.sequence_info.spans_containing_block[i]] if missing_blocks: - raise MissingBlocksError(f"no servers holding blocks {missing_blocks}") + raise MissingBlocksError(missing_blocks) self.ready.set() # if there is an active server for every block, we may begin running break @@ -245,7 +247,7 @@ class RemoteSequenceManager: if server.state == ServerState.ONLINE ] if not active_servers: - raise MissingBlocksError("no servers holding the first block are online") + raise MissingBlocksError(0) peer_id = random.choice(active_servers) stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id) @@ -334,6 +336,11 @@ def maybe_log_traceback(exc: Exception): logger.log(traceback_level, "See detailed traceback below:", exc_info=True) -class MissingBlocksError(Exception): - def __repr__(self): - return self.args[0] +class MissingBlocksError(RuntimeError): + def __init__(self, block_indices: Union[int, Sequence[int]]): + super().__init__( + f"No servers holding blocks {block_indices} are online.\n" + f"You can check the public swarm's state at http://health.petals.ml\n\n" + f"If there are not enough servers, please consider connecting your own GPU:\n" + f"https://github.com/bigscience-workshop/petals#connect-your-gpu-and-increase-petals-capacity" + )