diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml
index 3bccda3..735fd2a 100644
--- a/.github/workflows/run-tests.yaml
+++ b/.github/workflows/run-tests.yaml
@@ -10,10 +10,20 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [ '3.8', '3.9', '3.10', '3.11' ]
+        include:
+          - { model: 'bigscience/bloom-560m', python-version: '3.8' }
+          - { model: 'bigscience/bloom-560m', python-version: '3.9' }
+          - { model: 'bigscience/bloom-560m', python-version: '3.10' }
+          - { model: 'bigscience/bloom-560m', python-version: '3.11' }
+          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.8' }
+          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.11' }
       fail-fast: false
     timeout-minutes: 15
     steps:
+      - name: Increase swap space
+        uses: pierotofy/set-swap-space@master
+        with:
+          swap-size-gb: 10
       - name: Checkout
         uses: actions/checkout@v3
       - name: Set up Python
@@ -31,44 +41,77 @@ jobs:
           pip install .[dev]
       - name: Test
         run: |
-          export MODEL_NAME=bigscience/bloom-560m
-          export REF_NAME=bigscience/bloom-560m
-          export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft
-
-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
-            --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
-            --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
-            --adapters $ADAPTER_NAME &> server1.log &
-          SERVER1_PID=$!
+          export MODEL_NAME="${{ matrix.model }}"
+          export REF_NAME="${{ matrix.model }}"
+          export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
+          export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
+
+          # [Step 1] Watch free RAM (lack of RAM is a common issue in CI)
+
+          bash -c 'while true; do free -h && sleep 30s; done' &
+          RAM_WATCH_PID=$!

-          sleep 5 # wait for the first server to initialize DHT
+          # [Step 2] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
+
+          python -m petals.cli.run_dht \
+            --identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &
+          BOOTSTRAP_PID=$!

           export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
-          # ^-- server 1 multiaddr is determined by --identity and --host_maddrs
+          # ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs

-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server2.log &
-          SERVER2_PID=$!
+          sleep 5 # wait for DHT init
+
+          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
+            --mean_balance_check_period 10 \
+            --initial_peers $INITIAL_PEERS --throughput 1 &> server1.log &
+          SERVER1_PID=$!
+          # ^-- rebalancing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there

-          sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve
+          sleep 10 # wait for the 1st server to choose blocks

-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:15 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server3.log &
+          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \
+            --identity_path tests/server2.id \
+            --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
+          SERVER2_PID=$!
+
+          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \
+            --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
+            --initial_peers $INITIAL_PEERS --throughput auto &> server3.log &
           SERVER3_PID=$!
+          # ^-- chunking test

-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server4.log &
+          python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \
+            --initial_peers $INITIAL_PEERS --throughput auto &> server4.log &
           SERVER4_PID=$!
+          # ^-- tensor parallelism test (not compatible with adapters yet)

-          tail -n 100 -f server*.log &
+          sleep 5 # wait for the log files to appear
+
+          tail -n 100 -f bootstrap.log server*.log &
           LOGGER_PID=$!

-          sleep 30 # wait for servers to download layers
-          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived init
+          sleep 30 # wait for servers to eval throughput, download layers, and rebalance
+          kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived init
+
+          # [Step 3] Run PyTest

           pytest tests --durations=0 --durations-min=1.0 -v

-          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived tests
+          # [Step 4] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
+
+          python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
+            --seq_len 3
+          python benchmarks/benchmark_forward.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
+            --seq_len 3 --batch_size 3 --n_steps 1
+          python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
+            --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task cls
+          python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
+            --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm
+
+          # [Step 5] Clean up
+
+          kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived tests

-          kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
+          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID $RAM_WATCH_PID
           echo "Done!"
diff --git a/benchmarks/benchmark_forward.py b/benchmarks/benchmark_forward.py
index e95c5ec..bf547ec 100755
--- a/benchmarks/benchmark_forward.py
+++ b/benchmarks/benchmark_forward.py
@@ -15,15 +15,15 @@ logger = get_logger()


 def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model", type=str, default="bigscience/bloom")
-    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
-    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
-    parser.add_argument("--n_processes", type=str, default=1)
-    parser.add_argument("--seq_len", type=int, default=128)
-    parser.add_argument("--n_steps", type=int, default=100)
-    parser.add_argument("--batch_size", type=int, required=True)
-    parser.add_argument("--warmup_steps", type=int, default=1)
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("--model", type=str, required=True, help="Model")
+    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
+    parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype")
+    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
+    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
+    parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps")
+    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
+    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
     args = parser.parse_args()

     if args.n_processes == "n_gpus":
diff --git a/benchmarks/benchmark_inference.py b/benchmarks/benchmark_inference.py
index 607ff88..e894bb1 100755
--- a/benchmarks/benchmark_inference.py
+++ b/benchmarks/benchmark_inference.py
@@ -16,13 +16,13 @@ logger = get_logger()


 def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model", type=str, default="bigscience/bloom")
-    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
-    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
-    parser.add_argument("--n_processes", type=str, default=1)
-    parser.add_argument("--seq_len", type=int, default=2048)
-    parser.add_argument("--warmup_steps", type=int, default=1)
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("--model", type=str, required=True, help="Model")
+    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
+    parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype")
+    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
+    parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
+    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
     args = parser.parse_args()

     if args.n_processes == "n_gpus":
diff --git a/benchmarks/benchmark_training.py b/benchmarks/benchmark_training.py
index 0853dfc..85061a3 100755
--- a/benchmarks/benchmark_training.py
+++ b/benchmarks/benchmark_training.py
@@ -15,18 +15,18 @@ logger = get_logger()


 def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model", type=str, default="bigscience/bloom")
-    parser.add_argument("--device", type=str, default="cpu")
-    parser.add_argument("--task", type=str, default="cls")
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS) - parser.add_argument("--torch_dtype", type=str, default="bfloat16") - parser.add_argument("--n_processes", type=str, default=1) - parser.add_argument("--seq_len", type=int, default=128) - parser.add_argument("--pre_seq_len", type=int, default=16) - parser.add_argument("--n_steps", type=int, default=10) - parser.add_argument("--batch_size", type=int, required=True) - parser.add_argument("--warmup_steps", type=int, default=1) + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--model", type=str, required=True, help="Model") + parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client") + parser.add_argument("--task", type=str, default="cls", help="Training task type") + parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers") + parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype") + parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes") + parser.add_argument("--seq_len", type=int, default=128, help="Sequence length") + parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens") + parser.add_argument("--n_steps", type=int, default=10, help="Number of benchmark steps") + parser.add_argument("--batch_size", type=int, required=True, help="Batch size") + parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps") args = parser.parse_args() assert args.task in ["cls", "causal_lm"] diff --git a/src/petals/cli/run_dht.py b/src/petals/cli/run_dht.py index 2f30516..777d9d0 100644 --- a/src/petals/cli/run_dht.py +++ b/src/petals/cli/run_dht.py @@ -7,8 +7,8 @@ This script may be used for launching lightweight CPU machines serving as bootst This may be eventually merged to the hivemind upstream. """ +import argparse import time -from argparse import ArgumentParser from secrets import token_hex from hivemind.dht import DHT, DHTNode @@ -35,7 +35,7 @@ async def report_status(dht: DHT, node: DHTNode): def main(): - parser = ArgumentParser() + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "--initial_peers", nargs="*", @@ -73,7 +73,9 @@ def main(): help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)", ) parser.add_argument( - "--use_auto_relay", action="store_true", help="Look for libp2p relays to reach peers behind NATs/firewalls" + "--use_auto_relay", + action="store_true", + help="Look for libp2p relays to become reachable if we are behind NAT/firewall", ) parser.add_argument( "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT" diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index c82ff44..d85c8ac 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -158,7 +158,7 @@ def main(): "when connecting to the public swarm. If you connect to a private swarm, " "the check is skipped by default. 
                              "the check is skipped by default. Use this option only if you know what you are doing")

-    parser.add_argument("--adapters", nargs='+', default=(),
+    parser.add_argument("--adapters", nargs='*', default=(),
                         help="List of pre-loaded LoRA adapters that can be used for inference or training")
     # fmt:on
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 7772fa6..bf7470a 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -78,7 +78,7 @@ class Server:
         sender_threads: int = 1,
         balance_quality: float = 0.75,
         mean_balance_check_period: float = 120,
-        mean_block_selection_delay: float = 2.5,
+        mean_block_selection_delay: float = 5,
         token: Optional[Union[str, bool]] = None,
         quant_type: Optional[QuantType] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
@@ -201,6 +201,8 @@ class Server:
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
             num_blocks = self._choose_num_blocks()
+        if num_blocks is not None:
+            num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
         if block_indices is not None:
             try:
                 first_block_index, last_block_index = block_indices.split(":")
@@ -295,7 +297,7 @@ class Server:
         num_blocks = min(num_blocks, self.block_config.num_hidden_layers)

         logger.info(
-            f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
+            f"Server will fill your GPU memory with {num_blocks} transformer blocks. "
             f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
         )
         return num_blocks
diff --git a/tests/test.id b/tests/bootstrap.id
similarity index 100%
rename from tests/test.id
rename to tests/bootstrap.id
diff --git a/tests/server2.id b/tests/server2.id
new file mode 100644
index 0000000..2615557
Binary files /dev/null and b/tests/server2.id differ
diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py
index 64c9c6a..f75281e 100644
--- a/tests/test_aux_functions.py
+++ b/tests/test_aux_functions.py
@@ -29,6 +29,9 @@ def test_bnb_not_imported_when_unnecessary():
 @pytest.mark.parametrize("tensor_parallel", [False, True])
 def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool):
     config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
+    if tensor_parallel and config.model_type != "bloom":
+        pytest.skip("Tensor parallelism is implemented only for BLOOM for now")
+
     tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
     compute_rps = measure_compute_rps(
         config,
diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py
index 62c4e89..d98918b 100644
--- a/tests/test_block_exact_match.py
+++ b/tests/test_block_exact_match.py
@@ -3,14 +3,14 @@ import random
 import pytest
 import torch

-from petals import DistributedBloomConfig, RemoteSequential
+from petals import AutoDistributedConfig, RemoteSequential
 from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *


 @pytest.mark.forked
 def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     remote_sequential = RemoteSequential(config)

     for block_index in random.sample(range(config.num_hidden_layers), 3):
diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py
index d20f654..d4b012c 100644
--- a/tests/test_chained_calls.py
+++ b/tests/test_chained_calls.py
@@ -7,7 +7,7 @@ import pytest
 import torch

-from petals import DistributedBloomConfig
+from petals import AutoDistributedConfig
 from petals.client.remote_sequential import RemoteSequential
 from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *

@@ -15,7 +15,7 @@ from test_utils import *


 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     remote_blocks = RemoteSequential(config, start_block=3, end_block=6)
     assert isinstance(remote_blocks, RemoteSequential)
@@ -43,7 +43,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq

 @pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     remote_blocks = RemoteSequential(config, start_block=3, end_block=5)

     inputs = torch.randn(1, 8, config.hidden_size)
diff --git a/tests/test_full_model.py b/tests/test_full_model.py
index 511604b..dc2f3d7 100644
--- a/tests/test_full_model.py
+++ b/tests/test_full_model.py
@@ -3,29 +3,31 @@ import pytest
 import torch
 import transformers
 from hivemind import get_logger
-from transformers.generation import BeamSearchScorer
-from transformers.models.bloom import BloomForCausalLM
+from transformers.generation import BeamSearchScorer, GenerationMixin as HfGenerationMixin

-from petals import DistributedBloomForCausalLM
+from petals import AutoDistributedModelForCausalLM
 from test_utils import *

 logger = get_logger(__name__)


+@pytest.fixture
+def tokenizer():
+    # We set use_fast=False since LlamaTokenizerFast is slow on load
+    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
+
+
 @pytest.mark.forked
 @pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
-def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
+def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
+    model = AutoDistributedModelForCausalLM.from_pretrained(
         MODEL_NAME,
         initial_peers=INITIAL_PEERS,
-        low_cpu_mem_usage=True,
         torch_dtype=torch.float32,
         active_adapter=ADAPTER_NAME if use_peft else None,
     )
     config = model.config
-    assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.num_hidden_layers

     test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"]
@@ -63,7 +65,7 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f
     del model, embs, recurrent_outputs

     if REF_NAME:
-        ref_model = transformers.BloomForCausalLM.from_pretrained(
+        ref_model = transformers.AutoModelForCausalLM.from_pretrained(
             REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
         )
         if use_peft:
@@ -86,27 +88,29 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f


 @pytest.mark.forked
-def test_greedy_generation(max_new_tokens=4):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+def test_greedy_generation(tokenizer, max_new_tokens=4):
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
     )
     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
     remote_outputs = model.generate(
         inputs,
         max_new_tokens=max_new_tokens,
     )
-    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
+    hf_outputs = HfGenerationMixin.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
     assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"

+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
     inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
         "input_ids"
     ]
+
     remote_outputs_batch = model.generate(
         inputs_batch,
         max_new_tokens=max_new_tokens,
     )
-    hf_outputs_batch = BloomForCausalLM.greedy_search(
+    hf_outputs_batch = HfGenerationMixin.greedy_search(
         model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
     )
     assert torch.allclose(
@@ -117,13 +121,13 @@ def test_greedy_generation(max_new_tokens=4):
 @pytest.mark.forked
 @pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
 @pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
-def test_sampling(sampling_options, max_new_tokens=4):
+def test_sampling(tokenizer, sampling_options, max_new_tokens=4):
     torch.manual_seed(0)
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
     )
-    logits_warper = BloomForCausalLM._get_logits_warper(model, num_beams=1, **sampling_options)
+    logits_warper = HfGenerationMixin._get_logits_warper(model, num_beams=1, **sampling_options)
     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
     with torch.random.fork_rng():
         remote_outputs = model.generate(
@@ -133,7 +137,7 @@ def test_sampling(sampling_options, max_new_tokens=4):
             **sampling_options,
         )
     with torch.random.fork_rng():
-        hf_outputs = BloomForCausalLM.sample(
+        hf_outputs = HfGenerationMixin.sample(
             model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
         )
     assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"
@@ -149,7 +153,7 @@ def test_sampling(sampling_options, max_new_tokens=4):
             **sampling_options,
         )
     with torch.random.fork_rng():
-        hf_outputs_batch = BloomForCausalLM.sample(
+        hf_outputs_batch = HfGenerationMixin.sample(
             model,
             input_ids=inputs_batch,
             max_length=inputs_batch.size(1) + max_new_tokens,
@@ -161,10 +165,9 @@ def test_sampling(sampling_options, max_new_tokens=4):


 @pytest.mark.forked
-def test_beam_search_generation(max_new_tokens=4, num_beams=2):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+def test_beam_search_generation(tokenizer, max_new_tokens=4, num_beams=2):
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
     )
     text = "A cat sat on a mat"
     inputs = tokenizer(text, return_tensors="pt")["input_ids"]
@@ -181,7 +184,7 @@ def test_beam_search_generation(max_new_tokens=4, num_beams=2):
         do_early_stopping=False,
     )
     hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
-    hf_outputs = BloomForCausalLM.beam_search(
+    hf_outputs = HfGenerationMixin.beam_search(
         model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
     )
     assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"
diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py
index 3c8a48f..9189e68 100644
--- a/tests/test_remote_sequential.py
+++ b/tests/test_remote_sequential.py
@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from hivemind import DHT, BatchTensorDescriptor, get_logger
 from hivemind.proto import runtime_pb2

-from petals import DistributedBloomConfig
+from petals import AutoDistributedConfig
 from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.data_structures import UID_DELIMITER
 from petals.server.from_pretrained import load_pretrained_block
@@ -15,7 +15,7 @@ logger = get_logger(__name__)

 @pytest.mark.forked
 def test_remote_sequential():
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
     test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
     grad_proj = torch.randn(1, 5, config.hidden_size)
@@ -40,10 +40,10 @@ def test_remote_sequential():
     assert hidden.shape == test_inputs.shape
     assert hidden.requires_grad
     second_half_outputs = second_half(hidden)
-    assert torch.allclose(second_half_outputs, full_outputs, atol=1e-4)
+    assert torch.allclose(second_half_outputs, full_outputs, atol=3e-4)

     (second_half_outputs * grad_proj).sum().backward()
-    assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3)
+    assert torch.allclose(test_inputs.grad, full_grad, atol=1e-2)

     # test RemoteSequential with lossy compression
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
@@ -56,7 +56,7 @@ def test_remote_sequential():
     (approx_outputs * grad_proj).sum().backward()

     assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
-    assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
+    assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-3), "compression was not used"
     assert abs(approx_outputs - full_outputs).mean() < 0.01
     absmax = abs(full_grad).max()
     assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.05
@@ -87,7 +87,7 @@ class DummyCustomSequenceManager(RemoteSequenceManager):

 @pytest.mark.forked
 def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     remote_sequential = RemoteSequential(config)

     inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
diff --git a/tests/test_sequence_manager.py b/tests/test_sequence_manager.py
index 03e17e3..6e3d3d3 100644
--- a/tests/test_sequence_manager.py
+++ b/tests/test_sequence_manager.py
@@ -5,7 +5,7 @@ import pytest
 import torch
 from hivemind import DHT, get_logger

-from petals import DistributedBloomConfig
+from petals import AutoDistributedConfig
 from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.data_structures import UID_DELIMITER
 from test_utils import *
@@ -16,7 +16,7 @@ logger = get_logger(__name__)
 @pytest.mark.forked
 @pytest.mark.parametrize("mode", ["max_throughput", "min_latency"])
 def test_sequence_manager_basics(mode: str):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
     sequential = RemoteSequential(config, dht=dht)
     shutdown_evt = threading.Event()
diff --git a/tests/test_server_stats.py b/tests/test_server_stats.py
index 5de3393..c8e6ab6 100644
--- a/tests/test_server_stats.py
+++ b/tests/test_server_stats.py
@@ -4,14 +4,16 @@ import hivemind
 import pytest
 import torch

-from petals import DistributedBloomConfig, RemoteSequential
+from petals import AutoDistributedConfig, RemoteSequential
 from petals.server.handler import CACHE_TOKENS_AVAILABLE
 from test_utils import *


 @pytest.mark.forked
-def test_server_info(block_from: int = 22, block_to: int = 24, max_length: int = 100, max_length2: int = 50):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50):
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
+    config.allowed_servers = ["QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX"]  # PeerID from server2.id
+
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to)
     blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to)
diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py
index 408a261..5630f4a 100644
--- a/tests/test_tensor_parallel.py
+++ b/tests/test_tensor_parallel.py
@@ -14,8 +14,11 @@ from test_utils import MODEL_NAME

 @pytest.mark.parametrize("custom_config", [True, False])
 @pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
 def test_tp_block(devices, custom_config):
-    block_index = random.randint(0, 10)
     model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
+    if model_config.model_type != "bloom":
+        pytest.skip("Tensor parallelism is implemented only for BLOOM for now")
+
+    block_index = random.randint(0, 10)
     block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])
     tp_config = None