petals/tests/test_sequence_manager.py
Alexander Borzunov 8c546d988a
Test Llama, rebalancing, throughput eval, and all CLI scripts (#452)
This PR extends CI to:

1. Test Llama code using [TinyLlama-v0](https://huggingface.co/Maykeye/TinyLLama-v0).
2. Test rebalancing (sets up a situation where the 1st server has to move from its original position).
3. Check that the benchmark scripts run (in case someone breaks their code). Note that the benchmark results themselves are meaningless here, since they're measured on a tiny swarm of CPU servers with a low `--n_steps`.
4. Test `petals.cli.run_dht`.
5. Increase swap space and monitor free RAM (a common issue is that CI actions are cancelled without explanation when there's not enough RAM, so this serves as both a reminder and a debugging tool).
6. Fix flaky tests for bloom-560m by increasing the comparison tolerance (see the sketch after this list).
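
On item 6, the fix amounts to loosening the tolerance used when comparing outputs computed by remote servers against a local reference. A minimal illustrative sketch of that kind of check — the tensors and tolerance values here are assumptions for demonstration, not the PR's exact diff:

```python
import torch

# `ref` stands in for locally computed (reference) outputs; `out` for outputs
# from remote CPU servers, which carry small numerical noise.
ref = torch.randn(1, 2, 1024)
out = ref + 1e-4 * torch.randn_like(ref)

# A too-strict atol makes this assert fail intermittently (a flaky test);
# raising it to cover the expected noise makes the comparison deterministic.
assert torch.allclose(out, ref, atol=1e-3, rtol=0)
```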

Other minor changes: fix `--help` messages to show defaults, fix docs, tune rebalancing constants.
2023-08-08 19:10:27 +04:00


import threading
import time

import pytest
import torch
from hivemind import DHT, get_logger

from petals import AutoDistributedConfig
from petals.client import RemoteSequenceManager, RemoteSequential
from petals.data_structures import UID_DELIMITER
from test_utils import *

logger = get_logger(__name__)
@pytest.mark.forked
@pytest.mark.parametrize("mode", ["max_throughput", "min_latency"])
def test_sequence_manager_basics(mode: str):
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
    sequential = RemoteSequential(config, dht=dht)
    shutdown_evt = threading.Event()

    # rebuild RemoteSequential with a sequence manager that signals when it is shut down
    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
    sequential = RemoteSequential(
        config,
        sequence_manager=RemoteSequenceManagerWithChecks(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),
    )

    sequence = sequential.sequence_manager.make_sequence(mode=mode)
    assert all(sequence[i].peer_id != sequence[i + 1].peer_id for i in range(len(sequence) - 1))

    assert sequential.sequence_manager.is_alive()
    assert sequential.sequence_manager._thread.ready.is_set()
    assert not shutdown_evt.is_set()

    sequential(torch.randn(1, 2, config.hidden_size))  # a forward pass should succeed without triggering shutdown

    sequential.sequence_manager.shutdown()
    del sequential
    time.sleep(1)

    assert shutdown_evt.is_set()
class RemoteSequenceManagerWithChecks(RemoteSequenceManager):
    """A sequence manager that signals if it was shut down"""

    def __init__(self, *args, _was_shut_down: threading.Event, **kwargs):
        super().__init__(*args, **kwargs)
        self._was_shut_down = _was_shut_down

    def shutdown(self):
        super().shutdown()
        assert not self.is_alive()
        self._was_shut_down.set()
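
The `_was_shut_down` event pattern used by `RemoteSequenceManagerWithChecks` generalizes to any background worker whose teardown a test needs to observe. Below is a self-contained sketch of the same idea with a hypothetical `Worker` class, so it runs without petals or a live swarm:

```python
import threading


class Worker:
    """A toy background worker that signals an event once it is shut down."""

    def __init__(self, _was_shut_down: threading.Event):
        self._was_shut_down = _was_shut_down
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self):
        # Wake up periodically until asked to stop; real work would go here.
        while not self._stop.wait(timeout=0.05):
            pass

    def is_alive(self) -> bool:
        return self._thread.is_alive()

    def shutdown(self):
        self._stop.set()
        self._thread.join()
        assert not self.is_alive()
        self._was_shut_down.set()  # same signal the test above waits for


def test_worker_shutdown_is_observable():
    evt = threading.Event()
    worker = Worker(_was_shut_down=evt)
    assert worker.is_alive() and not evt.is_set()
    worker.shutdown()
    assert evt.is_set()
```

In this sketch, `shutdown()` blocks on `join()`, so by the time the event is set the worker thread is provably gone; the petals test above instead sleeps for a second after `del sequential` to give any asynchronous teardown time to finish before asserting.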