Mirror of https://github.com/bigscience-workshop/petals, synced 2024-10-31 09:20:41 +00:00
8c546d988a
This PR extends CI to:

1. Test Llama code using [TinyLlama-v0](https://huggingface.co/Maykeye/TinyLLama-v0).
2. Test rebalancing (sets up a situation where the 1st server needs to change its original position).
3. Check that benchmark scripts run (in case someone breaks their code). Note that the benchmark results are meaningless here, since they're measured on a tiny swarm of CPU servers with a low `--n_steps`.
4. Test `petals.cli.run_dht`.
5. Increase swap space and watch free RAM (a common issue is that actions are cancelled without explanation when there's not enough RAM, so this is a useful reminder and debugging tool).
6. Fix flapping tests for bloom-560m by increasing tolerance (see the sketch below).

Other minor changes: fix `--help` messages to show defaults, fix docs, tune rebalancing constants.
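For context, the exact-match pattern these tests rely on (and that item 6 loosens for bloom-560m) compares a remote block against a local float32 copy with `torch.allclose(..., rtol=0, atol=...)`. Below is a minimal sketch of that pattern, roughly mirroring the existing `test_block_exact_match` mentioned in the file's header; it assumes `MODEL_NAME` and `INITIAL_PEERS` come from `test_utils` as in the file below, and the helper name `check_block_exact_match` is hypothetical.

import torch

from petals import AutoDistributedConfig
from petals.client.remote_sequential import RemoteSequential
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *  # assumed to provide MODEL_NAME and INITIAL_PEERS


def check_block_exact_match(block_index=3, atol=1e-4, seq_length=4):
    # Route a random input through one remote block of the test swarm...
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    remote_block = RemoteSequential(config, start_block=block_index, end_block=block_index + 1)
    inputs = torch.randn(1, seq_length, config.hidden_size)
    outputs_rpc = remote_block(inputs)

    # ...and through a local float32 copy of the same block as the reference.
    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
    with torch.no_grad():
        outputs_ref = ref_block(inputs)[0]

    # rtol=0 makes the check purely absolute; raising atol is what reduces
    # flakiness in the bloom-560m comparisons mentioned above.
    assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol)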
75 lines
2.9 KiB
Python
######
# Warning: this test is a work in progress. It will be modified soon.
# - if you want more stable tests, see test_block_exact_match
# - if you want to figure out chained inference, ask yozh


import pytest
import torch

from petals import AutoDistributedConfig
from petals.client.remote_sequential import RemoteSequential
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *


@pytest.mark.forked
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    remote_blocks = RemoteSequential(config, start_block=3, end_block=6)
    assert isinstance(remote_blocks, RemoteSequential)

    ref_blocks = [
        load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
        load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
        load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32),
    ]
    # Forward and backward through the remote (RPC) blocks
    inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)
    outputs_rpc = remote_blocks.forward(inputs)
    outputs_rpc.sum().backward()
    grads_rpc = inputs.grad

    # Recompute the same blocks locally as the reference
    inputs.grad = None
    hidden_states = inputs
    for ref_block in ref_blocks:
        hidden_states = ref_block.forward(hidden_states)[0]
    outputs_ref = hidden_states
    outputs_ref.sum().backward()
    grads_ref = inputs.grad

    assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward)
    assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)


@pytest.mark.forked
def test_chained_inference_exact_match(atol_inference=1e-4):
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    remote_blocks = RemoteSequential(config, start_block=3, end_block=5)

    inputs = torch.randn(1, 8, config.hidden_size)

    # Token-by-token inference through the remote blocks
    outputs_inference = []
    with remote_blocks.inference_session(max_length=inputs.shape[1]) as sess:
        for i in range(inputs.shape[1]):
            outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
    outputs_inference = torch.cat(outputs_inference, dim=1)

    # Reference: run the same blocks locally, threading the attention cache by hand
    ref_blocks = [
        load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
        load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
    ]
    outputs_ref = []
    caches = [None, None]
    for i in range(inputs.shape[1]):
        new_caches = []
        hidden_states = inputs[:, i : i + 1, :]
        for ref_block, cache in zip(ref_blocks, caches):
            with torch.no_grad():
                hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache)
            new_caches.append(new_cache)

        outputs_ref.append(hidden_states)
        caches = new_caches
    outputs_ref = torch.cat(outputs_ref, dim=1)
    assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference)