Test Llama, rebalancing, throughput eval, and all CLI scripts (#452)

This PR extends CI to:

1. Test the Llama code path using [TinyLlama-v0](https://huggingface.co/Maykeye/TinyLLama-v0) (see the client-side sketch after this summary).
2. Test rebalancing (sets up a situation where the 1st server has to move away from its original position).
3. Check that the benchmark scripts still run (so we notice if someone breaks them). Note that the benchmark results are meaningless here, since they're measured on a tiny swarm of CPU servers with a low `--n_steps`.
4. Test `petals.cli.run_dht`.
5. Increase swap space and watch free RAM (a common issue is that jobs get cancelled without explanation when the runner runs out of RAM, so this is a useful reminder and debugging tool).
6. Fix flaky tests for bloom-560m by increasing tolerances.

Other minor changes: fix `--help` messages to show defaults, fix docs, tune rebalancing constants.
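
To make the scope concrete, here is a minimal client-side sketch of the kind of check the updated tests run against the tiny CI swarm, using the model-agnostic classes that the tests below switch to (`AutoDistributedModelForCausalLM`, `AutoTokenizer` with `use_fast=False`). It assumes a local swarm is already serving TinyLlama-v0 at the bootstrap multiaddr used in CI; the prompt and token count are arbitrary, and this snippet is not part of the PR itself:

```python
import torch
import transformers
from petals import AutoDistributedModelForCausalLM

# Bootstrap multiaddr of the local test swarm (fixed identity + host maddr, see the workflow below)
INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g"]
MODEL_NAME = "Maykeye/TinyLLama-v0"

# use_fast=False because LlamaTokenizerFast is slow to load (same choice as in the tests)
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
model = AutoDistributedModelForCausalLM.from_pretrained(
    MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=4)
print(tokenizer.decode(outputs[0]))
```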
Commit 8c546d988a (parent df6fdd2d0b) by Alexander Borzunov, committed via GitHub 9 months ago.

@@ -10,10 +10,20 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [ '3.8', '3.9', '3.10', '3.11' ]
+        include:
+          - { model: 'bigscience/bloom-560m', python-version: '3.8' }
+          - { model: 'bigscience/bloom-560m', python-version: '3.9' }
+          - { model: 'bigscience/bloom-560m', python-version: '3.10' }
+          - { model: 'bigscience/bloom-560m', python-version: '3.11' }
+          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.8' }
+          - { model: 'Maykeye/TinyLLama-v0', python-version: '3.11' }
       fail-fast: false
     timeout-minutes: 15
     steps:
+      - name: Increase swap space
+        uses: pierotofy/set-swap-space@master
+        with:
+          swap-size-gb: 10
      - name: Checkout
        uses: actions/checkout@v3
      - name: Set up Python
@@ -31,44 +41,77 @@ jobs:
          pip install .[dev]
      - name: Test
        run: |
-          export MODEL_NAME=bigscience/bloom-560m
-          export REF_NAME=bigscience/bloom-560m
-          export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft
+          export MODEL_NAME="${{ matrix.model }}"
+          export REF_NAME="${{ matrix.model }}"
+          export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
+          export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"

-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
-            --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
-            --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
-            --adapters $ADAPTER_NAME &> server1.log &
-          SERVER1_PID=$!
+          # [Step 1] Watch free RAM (lack of RAM is a common issue in CI)

-          sleep 5  # wait for the first server to initialize DHT
+          bash -c 'while true; do free -h && sleep 30s; done' &
+          RAM_WATCH_PID=$!
+
+          # [Step 2] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)
+
+          python -m petals.cli.run_dht \
+            --identity_path tests/bootstrap.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> bootstrap.log &
+          BOOTSTRAP_PID=$!

           export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
-          # ^-- server 1 multiaddr is determined by --identity and --host_maddrs
+          # ^-- multiaddr in INITIAL_PEERS is determined by --identity_path and --host_maddrs
+
+          sleep 5  # wait for DHT init

-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server2.log &
-          SERVER2_PID=$!
+          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
+            --mean_balance_check_period 10 \
+            --initial_peers $INITIAL_PEERS --throughput 1 &> server1.log &
+          SERVER1_PID=$!
+          # ^-- rebalancing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there

-          sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve
+          sleep 10  # wait for the 1st server to choose blocks

-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:15 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server3.log &
-          SERVER3_PID=$!
+          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \
+            --identity_path tests/server2.id \
+            --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
+          SERVER2_PID=$!

-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server4.log &
-          SERVER4_PID=$!
+          python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \
+            --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
+            --initial_peers $INITIAL_PEERS --throughput auto &> server3.log &
+          SERVER3_PID=$!
+          # ^-- chunking test

-          tail -n 100 -f server*.log &
+          python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \
+            --initial_peers $INITIAL_PEERS --throughput auto &> server4.log &
+          SERVER4_PID=$!
+          # ^-- tensor parallelism test (not compatible with adapters yet)
+
+          sleep 5  # wait for the log files to appear
+          tail -n 100 -f bootstrap.log server*.log &
           LOGGER_PID=$!
-          sleep 30  # wait for servers to download layers

-          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived init
+          sleep 30  # wait for servers to eval throughput, download layers, and rebalance
+          kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived init

+          # [Step 3] Run PyTest
           pytest tests --durations=0 --durations-min=1.0 -v

-          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived tests
+          # [Step 4] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
+          python benchmarks/benchmark_inference.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
+            --seq_len 3
+          python benchmarks/benchmark_forward.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
+            --seq_len 3 --batch_size 3 --n_steps 1
+          python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
+            --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task cls
+          python benchmarks/benchmark_training.py --model $MODEL_NAME --initial_peers $INITIAL_PEERS --torch_dtype float32 \
+            --seq_len 3 --batch_size 3 --pre_seq_len 1 --n_steps 1 --task causal_lm

-          kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
+          # [Step 5] Clean up
+          kill -0 $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all peers survived tests
+          kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID $RAM_WATCH_PID
           echo "Done!"

@@ -15,15 +15,15 @@ logger = get_logger()
 def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model", type=str, default="bigscience/bloom")
-    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
-    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
-    parser.add_argument("--n_processes", type=str, default=1)
-    parser.add_argument("--seq_len", type=int, default=128)
-    parser.add_argument("--n_steps", type=int, default=100)
-    parser.add_argument("--batch_size", type=int, required=True)
-    parser.add_argument("--warmup_steps", type=int, default=1)
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("--model", type=str, required=True, help="Model")
+    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
+    parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype")
+    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
+    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
+    parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps")
+    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
+    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
     args = parser.parse_args()

     if args.n_processes == "n_gpus":
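
Regarding the `--help` fix mentioned in the description: `argparse.ArgumentDefaultsHelpFormatter` appends each option's default value to its help text, which is also why short `help=` strings are added above. A minimal standalone sketch (the arguments here are illustrative, not tied to any particular benchmark):

```python
import argparse

# Switching the formatter class is the whole "--help shows defaults" change.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype")

# `python script.py --help` now prints something like:
#   --seq_len SEQ_LEN     Sequence length (default: 128)
#   --torch_dtype TORCH_DTYPE
#                         Torch dtype (default: bfloat16)
parser.parse_args()
```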

@@ -16,13 +16,13 @@ logger = get_logger()
 def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model", type=str, default="bigscience/bloom")
-    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
-    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
-    parser.add_argument("--n_processes", type=str, default=1)
-    parser.add_argument("--seq_len", type=int, default=2048)
-    parser.add_argument("--warmup_steps", type=int, default=1)
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("--model", type=str, required=True, help="Model")
+    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
+    parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype")
+    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
+    parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
+    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
     args = parser.parse_args()

     if args.n_processes == "n_gpus":

@@ -15,18 +15,18 @@ logger = get_logger()
 def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model", type=str, default="bigscience/bloom")
-    parser.add_argument("--device", type=str, default="cpu")
-    parser.add_argument("--task", type=str, default="cls")
-    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
-    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
-    parser.add_argument("--n_processes", type=str, default=1)
-    parser.add_argument("--seq_len", type=int, default=128)
-    parser.add_argument("--pre_seq_len", type=int, default=16)
-    parser.add_argument("--n_steps", type=int, default=10)
-    parser.add_argument("--batch_size", type=int, required=True)
-    parser.add_argument("--warmup_steps", type=int, default=1)
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("--model", type=str, required=True, help="Model")
+    parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client")
+    parser.add_argument("--task", type=str, default="cls", help="Training task type")
+    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
+    parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype")
+    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
+    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
+    parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens")
+    parser.add_argument("--n_steps", type=int, default=10, help="Number of benchmark steps")
+    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
+    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
     args = parser.parse_args()

     assert args.task in ["cls", "causal_lm"]

@@ -7,8 +7,8 @@ This script may be used for launching lightweight CPU machines serving as bootst
 This may be eventually merged to the hivemind upstream.
 """

+import argparse
 import time
-from argparse import ArgumentParser
 from secrets import token_hex

 from hivemind.dht import DHT, DHTNode

@@ -35,7 +35,7 @@ async def report_status(dht: DHT, node: DHTNode):

 def main():
-    parser = ArgumentParser()
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument(
         "--initial_peers",
         nargs="*",

@@ -73,7 +73,9 @@ def main():
         help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
     )
     parser.add_argument(
-        "--use_auto_relay", action="store_true", help="Look for libp2p relays to reach peers behind NATs/firewalls"
+        "--use_auto_relay",
+        action="store_true",
+        help="Look for libp2p relays to become reachable if we are behind NAT/firewall",
     )
     parser.add_argument(
         "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"

@@ -158,7 +158,7 @@ def main():
                         "when connecting to the public swarm. If you connect to a private swarm, "
                         "the check is skipped by default. Use this option only if you know what you are doing")

-    parser.add_argument("--adapters", nargs='+', default=(),
+    parser.add_argument("--adapters", nargs='*', default=(),
                         help="List of pre-loaded LoRA adapters that can be used for inference or training")
    # fmt:on

@@ -78,7 +78,7 @@ class Server:
         sender_threads: int = 1,
         balance_quality: float = 0.75,
         mean_balance_check_period: float = 120,
-        mean_block_selection_delay: float = 2.5,
+        mean_block_selection_delay: float = 5,
         token: Optional[Union[str, bool]] = None,
         quant_type: Optional[QuantType] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,

@@ -201,6 +201,8 @@ class Server:
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
             num_blocks = self._choose_num_blocks()
+        if num_blocks is not None:
+            num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
         if block_indices is not None:
             try:
                 first_block_index, last_block_index = block_indices.split(":")

@@ -295,7 +297,7 @@ class Server:
         num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
         logger.info(
-            f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
+            f"Server will fill your GPU memory with {num_blocks} transformer blocks. "
             f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
         )
         return num_blocks
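
For reference, the clamp added above just caps the requested block count at the model depth, so asking for more blocks than a tiny model has is harmless. A minimal sketch of that behavior with made-up numbers (not taken from the PR):

```python
# Hypothetical values: the user requests more blocks than the model has.
num_blocks = 14          # e.g. from --num_blocks 14
num_hidden_layers = 8    # stand-in for self.block_config.num_hidden_layers

if num_blocks is not None:
    num_blocks = min(num_blocks, num_hidden_layers)

assert num_blocks == 8  # the server simply serves every block of the model
```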

Binary file not shown.

@@ -29,6 +29,9 @@ def test_bnb_not_imported_when_unnecessary():
 @pytest.mark.parametrize("tensor_parallel", [False, True])
 def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool):
     config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
+    if tensor_parallel and config.model_type != "bloom":
+        pytest.skip("Tensor parallelism is implemented only for BLOOM for now")
+
     tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
     compute_rps = measure_compute_rps(
         config,

@@ -3,14 +3,14 @@ import random
 import pytest
 import torch

-from petals import DistributedBloomConfig, RemoteSequential
+from petals import AutoDistributedConfig, RemoteSequential
 from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *


 @pytest.mark.forked
 def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     remote_sequential = RemoteSequential(config)

     for block_index in random.sample(range(config.num_hidden_layers), 3):

@@ -7,7 +7,7 @@
 import pytest
 import torch

-from petals import DistributedBloomConfig
+from petals import AutoDistributedConfig
 from petals.client.remote_sequential import RemoteSequential
 from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *

@@ -15,7 +15,7 @@ from test_utils import *

 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     remote_blocks = RemoteSequential(config, start_block=3, end_block=6)
     assert isinstance(remote_blocks, RemoteSequential)

@@ -43,7 +43,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq

 @pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     remote_blocks = RemoteSequential(config, start_block=3, end_block=5)

     inputs = torch.randn(1, 8, config.hidden_size)

@@ -3,29 +3,31 @@ import pytest
 import torch
 import transformers
 from hivemind import get_logger
-from transformers.generation import BeamSearchScorer
-from transformers.models.bloom import BloomForCausalLM
+from transformers.generation import BeamSearchScorer, GenerationMixin as HfGenerationMixin

-from petals import DistributedBloomForCausalLM
+from petals import AutoDistributedModelForCausalLM
 from test_utils import *

 logger = get_logger(__name__)


+@pytest.fixture
+def tokenizer():
+    # We set use_fast=False since LlamaTokenizerFast is slow on load
+    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
+
+
 @pytest.mark.forked
 @pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
-def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
+def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
+    model = AutoDistributedModelForCausalLM.from_pretrained(
         MODEL_NAME,
         initial_peers=INITIAL_PEERS,
-        low_cpu_mem_usage=True,
         torch_dtype=torch.float32,
         active_adapter=ADAPTER_NAME if use_peft else None,
     )
     config = model.config
-    assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.num_hidden_layers

     test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"]

@@ -63,7 +65,7 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f
     del model, embs, recurrent_outputs

     if REF_NAME:
-        ref_model = transformers.BloomForCausalLM.from_pretrained(
+        ref_model = transformers.AutoModelForCausalLM.from_pretrained(
             REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
         )
         if use_peft:

@@ -86,27 +88,29 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f

 @pytest.mark.forked
-def test_greedy_generation(max_new_tokens=4):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+def test_greedy_generation(tokenizer, max_new_tokens=4):
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
     )
     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
     remote_outputs = model.generate(
         inputs,
         max_new_tokens=max_new_tokens,
     )
-    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
+    hf_outputs = HfGenerationMixin.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
     assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"

+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
     inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
         "input_ids"
     ]
     remote_outputs_batch = model.generate(
         inputs_batch,
         max_new_tokens=max_new_tokens,
     )
-    hf_outputs_batch = BloomForCausalLM.greedy_search(
+    hf_outputs_batch = HfGenerationMixin.greedy_search(
         model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
     )
     assert torch.allclose(

@@ -117,13 +121,13 @@ def test_greedy_generation(max_new_tokens=4):

 @pytest.mark.forked
 @pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
 @pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
-def test_sampling(sampling_options, max_new_tokens=4):
+def test_sampling(tokenizer, sampling_options, max_new_tokens=4):
     torch.manual_seed(0)
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
     )
-    logits_warper = BloomForCausalLM._get_logits_warper(model, num_beams=1, **sampling_options)
+    logits_warper = HfGenerationMixin._get_logits_warper(model, num_beams=1, **sampling_options)
     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
     with torch.random.fork_rng():
         remote_outputs = model.generate(

@@ -133,7 +137,7 @@ def test_sampling(sampling_options, max_new_tokens=4):
             **sampling_options,
         )
     with torch.random.fork_rng():
-        hf_outputs = BloomForCausalLM.sample(
+        hf_outputs = HfGenerationMixin.sample(
             model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
         )
     assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"

@@ -149,7 +153,7 @@ def test_sampling(sampling_options, max_new_tokens=4):
             **sampling_options,
         )
     with torch.random.fork_rng():
-        hf_outputs_batch = BloomForCausalLM.sample(
+        hf_outputs_batch = HfGenerationMixin.sample(
             model,
             input_ids=inputs_batch,
             max_length=inputs_batch.size(1) + max_new_tokens,

@@ -161,10 +165,9 @@ def test_sampling(sampling_options, max_new_tokens=4):

 @pytest.mark.forked
-def test_beam_search_generation(max_new_tokens=4, num_beams=2):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+def test_beam_search_generation(tokenizer, max_new_tokens=4, num_beams=2):
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
     )
     text = "A cat sat on a mat"
     inputs = tokenizer(text, return_tensors="pt")["input_ids"]

@@ -181,7 +184,7 @@ def test_beam_search_generation(max_new_tokens=4, num_beams=2):
         do_early_stopping=False,
     )
     hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
-    hf_outputs = BloomForCausalLM.beam_search(
+    hf_outputs = HfGenerationMixin.beam_search(
         model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
     )
     assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"

@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from hivemind import DHT, BatchTensorDescriptor, get_logger
 from hivemind.proto import runtime_pb2

-from petals import DistributedBloomConfig
+from petals import AutoDistributedConfig
 from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.data_structures import UID_DELIMITER
 from petals.server.from_pretrained import load_pretrained_block

@@ -15,7 +15,7 @@ logger = get_logger(__name__)

 @pytest.mark.forked
 def test_remote_sequential():
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
     test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
     grad_proj = torch.randn(1, 5, config.hidden_size)

@@ -40,10 +40,10 @@ def test_remote_sequential():
     assert hidden.shape == test_inputs.shape
     assert hidden.requires_grad
     second_half_outputs = second_half(hidden)
-    assert torch.allclose(second_half_outputs, full_outputs, atol=1e-4)
+    assert torch.allclose(second_half_outputs, full_outputs, atol=3e-4)

     (second_half_outputs * grad_proj).sum().backward()
-    assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3)
+    assert torch.allclose(test_inputs.grad, full_grad, atol=1e-2)

     # test RemoteSequential with lossy compression
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]

@@ -56,7 +56,7 @@ def test_remote_sequential():
     (approx_outputs * grad_proj).sum().backward()

     assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
-    assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
+    assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-3), "compression was not used"
     assert abs(approx_outputs - full_outputs).mean() < 0.01
     absmax = abs(full_grad).max()
     assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.05

@@ -87,7 +87,7 @@ class DummyCustomSequenceManager(RemoteSequenceManager):

 @pytest.mark.forked
 def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     remote_sequential = RemoteSequential(config)

     inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)

@@ -5,7 +5,7 @@ import pytest
 import torch
 from hivemind import DHT, get_logger

-from petals import DistributedBloomConfig
+from petals import AutoDistributedConfig
 from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.data_structures import UID_DELIMITER
 from test_utils import *

@@ -16,7 +16,7 @@ logger = get_logger(__name__)
 @pytest.mark.forked
 @pytest.mark.parametrize("mode", ["max_throughput", "min_latency"])
 def test_sequence_manager_basics(mode: str):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
     sequential = RemoteSequential(config, dht=dht)
     shutdown_evt = threading.Event()

@@ -4,14 +4,16 @@ import hivemind
 import pytest
 import torch

-from petals import DistributedBloomConfig, RemoteSequential
+from petals import AutoDistributedConfig, RemoteSequential
 from petals.server.handler import CACHE_TOKENS_AVAILABLE
 from test_utils import *


 @pytest.mark.forked
-def test_server_info(block_from: int = 22, block_to: int = 24, max_length: int = 100, max_length2: int = 50):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50):
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
+    config.allowed_servers = ["QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX"]  # PeerID from server2.id
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)

     blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to)
     blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to)

@@ -14,8 +14,11 @@ from test_utils import MODEL_NAME
 @pytest.mark.parametrize("custom_config", [True, False])
 @pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
 def test_tp_block(devices, custom_config):
-    block_index = random.randint(0, 10)
     model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
+    if model_config.model_type != "bloom":
+        pytest.skip("Tensor parallelism is implemented only for BLOOM for now")
+
+    block_index = random.randint(0, 10)
     block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])

     tp_config = None
