petals/tests/test_tensor_parallel.py
Alexander Borzunov 8c546d988a
Test Llama, rebalancing, throughput eval, and all CLI scripts (#452)
This PR extends CI to:

1. Test Llama code using [TinyLlama-v0](https://huggingface.co/Maykeye/TinyLLama-v0).
2. Test rebalancing (sets up a situation where the 1st server needs to change its original position).
3. Check that the benchmark scripts run (in case someone breaks their code). Note that the benchmark results are meaningless here, since they are measured on a tiny swarm of CPU servers with a low `--n_steps`.
4. Test `petals.cli.run_dht`.
5. Increase swap space and monitor free RAM. A common issue is that actions are cancelled without explanation when there isn't enough RAM, so this is a useful reminder and debugging tool.
6. Fix flapping tests for bloom-560m by increasing tolerance.

Other minor changes: fix `--help` messages to show defaults, fix docs, tune rebalancing constants.
2023-08-08 19:10:27 +04:00

50 lines
2.0 KiB
Python

import random
import pytest
import torch
import transformers
from tensor_parallel import TensorParallel
from tensor_parallel.slicing_configs import get_bloom_config
from petals.server.from_pretrained import load_pretrained_block
from test_utils import MODEL_NAME
@pytest.mark.forked
@pytest.mark.parametrize("custom_config", [True, False])
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
def test_tp_block(devices, custom_config):
    """Check that a tensor-parallel BLOOM block reproduces the plain block's outputs and gradients.

    A randomly chosen pretrained block of ``MODEL_NAME`` is run twice:

    - as-is (the reference);
    - wrapped in ``TensorParallel`` over ``devices``.

    Each run first processes a prefix with ``use_cache=True``. It then processes a short continuation
    that reuses the returned ``layer_past``. The test asserts that both forward outputs agree, and that
    the gradients w.r.t. the prefix and continuation inputs agree.

    :param devices: tuple of device names to shard the block across
    :param custom_config: if True, pass the slicing config built by ``get_bloom_config``;
        otherwise pass ``config=None`` to ``TensorParallel``
    """
    model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
    if model_config.model_type != "bloom":
        pytest.skip("Tensor parallelism is implemented only for BLOOM for now")

    # Pick one of the first 11 blocks, but never past the model's last layer, so that
    # small test checkpoints with fewer layers do not fail when loading the block.
    block_index = random.randint(0, min(10, model_config.num_hidden_layers - 1))
    # The index is random, so report it on failure to make the run reproducible.
    msg = f"block_index={block_index}, devices={devices}, custom_config={custom_config}"
    block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])

    tp_config = None
    if custom_config:
        tp_config = get_bloom_config(model_config, devices)

    batch_size = 2
    prefix_length = 5
    # Take the width from the config instead of hard-coding 1024 (bloom-560m's hidden size).
    hidden_size = model_config.hidden_size

    test_inputs1 = torch.randn(batch_size, 3, hidden_size, requires_grad=True, device=devices[0])
    test_inputs2 = test_inputs1.detach().clone().requires_grad_(True)
    test_prefix1 = torch.randn(batch_size, prefix_length, hidden_size, requires_grad=True, device=devices[0])
    test_prefix2 = test_prefix1.detach().clone().requires_grad_(True)
    grad_proj = torch.rand_like(test_inputs1)

    # Reference pass: prefix, then continuation on top of the cached prefix.
    y_prefix_ref, layer_past_ref = block(test_prefix1, use_cache=True)
    y_ref, _ = block(test_inputs1, use_cache=True, layer_past=layer_past_ref)
    # Only the continuation output is backpropagated, so the prefix gradient can reach
    # test_prefix1 solely through the cached layer_past.
    y_ref.backward(grad_proj)

    # Tensor-parallel pass with identical (cloned) inputs.
    block_tp = TensorParallel(block, devices, config=tp_config)
    y_prefix, layer_past_tp = block_tp(test_prefix2, use_cache=True)
    y_ours, _ = block_tp(test_inputs2, use_cache=True, layer_past=layer_past_tp)
    y_ours.backward(grad_proj)

    assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-5), msg
    assert torch.allclose(y_ours, y_ref, atol=1e-5), msg
    assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4), msg
    assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4), msg