Mirror of https://github.com/bigscience-workshop/petals, synced 2024-10-31 09:20:41 +00:00
8c546d988a
This PR extends CI to:

1. Test Llama code using [TinyLlama-v0](https://huggingface.co/Maykeye/TinyLLama-v0).
2. Test rebalancing (sets up a situation where the 1st server needs to change its original position).
3. Check that the benchmark scripts run (in case someone breaks their code). Note that the benchmark results are meaningless here, since they're measured on a tiny swarm of CPU servers with a low `--n_steps`.
4. Test `petals.cli.run_dht`.
5. Increase swap space and watch free RAM (a common issue is that jobs get cancelled without explanation when there's not enough RAM, so this is a useful reminder and debugging tool).
6. Fix flaky tests for bloom-560m by increasing tolerance (see the sketch below).

Other minor changes: fix `--help` messages to show defaults, fix docs, tune rebalancing constants.
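For item 6, the tolerance change amounts to loosening the numeric comparison in the output-matching tests. The helper below is a minimal sketch of that idea only, not code from this PR; the function name and the exact `atol` value are illustrative.

```python
import torch


def assert_outputs_close(reference: torch.Tensor, actual: torch.Tensor, atol: float = 5e-3) -> None:
    """Illustrative helper: compare remote-inference outputs against a local reference.

    CPU-only CI servers accumulate slightly different floating-point rounding than the
    reference run, so a modestly larger absolute tolerance avoids flaky failures without
    hiding real regressions.
    """
    max_abs_diff = (reference - actual).abs().max().item()
    assert torch.allclose(reference, actual, rtol=0, atol=atol), f"max abs diff {max_abs_diff:.2e} exceeds atol={atol}"
```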
57 lines
1.9 KiB
Python
import threading
import time

import pytest
import torch
from hivemind import DHT, get_logger

from petals import AutoDistributedConfig
from petals.client import RemoteSequenceManager, RemoteSequential
from petals.data_structures import UID_DELIMITER
from test_utils import *

logger = get_logger(__name__)


@pytest.mark.forked
@pytest.mark.parametrize("mode", ["max_throughput", "min_latency"])
def test_sequence_manager_basics(mode: str):
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
    sequential = RemoteSequential(config, dht=dht)
    shutdown_evt = threading.Event()

    # rebuild RemoteSequential with a sequence manager that signals when it is shut down
    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
    sequential = RemoteSequential(
        config,
        sequence_manager=RemoteSequenceManagerWithChecks(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),
    )

    sequence = sequential.sequence_manager.make_sequence(mode=mode)
    assert all(sequence[i].peer_id != sequence[i + 1].peer_id for i in range(len(sequence) - 1))

    assert sequential.sequence_manager.is_alive()
    assert sequential.sequence_manager._thread.ready.is_set()
    assert not shutdown_evt.is_set()
    sequential(torch.randn(1, 2, config.hidden_size))

    sequential.sequence_manager.shutdown()
    del sequential
    time.sleep(1)  # give the sequence manager thread time to terminate

    assert shutdown_evt.is_set()


class RemoteSequenceManagerWithChecks(RemoteSequenceManager):
    """A sequence manager that signals if it was shut down"""

    def __init__(self, *args, _was_shut_down: threading.Event, **kwargs):
        super().__init__(*args, **kwargs)
        self._was_shut_down = _was_shut_down

    def shutdown(self):
        super().shutdown()
        assert not self.is_alive()
        self._was_shut_down.set()
|