Mirror of https://github.com/bigscience-workshop/petals, synced 2024-10-31 09:20:41 +00:00
8c546d988a
This PR extends CI to:

1. Test the Llama code using [TinyLlama-v0](https://huggingface.co/Maykeye/TinyLLama-v0).
2. Test rebalancing (sets up a situation where the 1st server needs to change its original position).
3. Check that the benchmark scripts still run (in case someone breaks their code). Note that the benchmark results are meaningless here, since they're measured on a tiny swarm of CPU servers with a low `--n_steps`.
4. Test `petals.cli.run_dht`.
5. Increase swap space and watch free RAM (a common issue is that actions are cancelled without explanation when there's not enough RAM, so this is a useful reminder and debugging tool).
6. Fix flapping tests for bloom-560m by increasing tolerance (see the sketch below).

Other minor changes: fix `--help` messages to show defaults, fix docs, tune rebalancing constants.
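For item 6, here is a minimal sketch of what a tolerance-based fix for a flapping output comparison can look like. The helper name, tensor arguments, and `atol` value are illustrative assumptions, not the exact values used in the Petals tests:

```python
import torch

def assert_logits_close(ref_logits: torch.Tensor, distributed_logits: torch.Tensor, atol: float = 5e-4) -> None:
    # Distributed inference accumulates small numerical differences across servers,
    # so an overly strict tolerance makes the comparison flap between CI runs.
    # Relaxing atol (and reporting the observed gap) keeps the test stable.
    max_diff = (ref_logits - distributed_logits).abs().max().item()
    assert torch.allclose(ref_logits, distributed_logits, atol=atol), (
        f"max difference {max_diff:.2e} exceeds atol={atol}"
    )
```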
40 lines
1.8 KiB
Python
import time

import hivemind
import pytest
import torch

from petals import AutoDistributedConfig, RemoteSequential
from petals.server.handler import CACHE_TOKENS_AVAILABLE
from test_utils import *


@pytest.mark.forked
def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50):
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
    config.allowed_servers = ["QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX"]  # PeerID from server2.id

    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
    blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to)
    blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to)

    info_before = blocks1.sequence_manager.rpc_info

    with blocks1.inference_session(max_length=max_length) as sess:
        sess.step(torch.randn(1, 1, config.hidden_size))
        blocks1.sequence_manager.state.rpc_info = None  # invalidate cache
        info_inside = blocks1.sequence_manager.rpc_info

        with blocks2.inference_session(max_length=max_length2) as sess2:
            sess2.step(torch.randn(1, 1, config.hidden_size))
            blocks2.sequence_manager.state.rpc_info = None  # invalidate cache
            info_inside2 = blocks2.sequence_manager.rpc_info

    # Wait for the servers to release the cache reserved by the closed sessions
    time.sleep(0.1)
    blocks1.sequence_manager.state.rpc_info = None  # invalidate cache
    info_after = blocks1.sequence_manager.rpc_info

    # Each open session reserves max_length cache tokens on every block it spans,
    # and the cache is returned once the sessions are closed
    assert info_before[CACHE_TOKENS_AVAILABLE] == info_after[CACHE_TOKENS_AVAILABLE]
    assert info_before[CACHE_TOKENS_AVAILABLE] - info_inside[CACHE_TOKENS_AVAILABLE] == max_length * len(blocks1)
    assert info_inside[CACHE_TOKENS_AVAILABLE] - info_inside2[CACHE_TOKENS_AVAILABLE] == max_length2 * len(blocks2)