petals/tests/test_remote_sequential.py
Alexander Borzunov 8f6342a861
Refactor RemoteSequenceManager (#309)
This PR:

1. **Extracts `SequenceManagerConfig` and `SequenceManagerState` subclasses.**

    The config is provided by the caller and is never changed from inside `RemoteSequenceManager`. The state is the part of `RemoteSequenceManager`'s state that is shared between the main manager and its slices. We fix some slicing bugs along the way.

2. **Removes `dht_prefix` and `p2p` arguments, makes `dht` argument optional.**

    `dht_prefix` can always be overridden using `config.dht_prefix`. `p2p` is actually needed only under the hood of `RemoteSequenceManager`, so the manager can extract it by itself without exposing this low-level class to callers. If strictly necessary, a caller can provide `p2p` as a part of `SequenceManagerState`. `dht` is also needed only by `RemoteSequenceManager`, so we can make it optional in the parent classes and create it automatically when it's not provided.

3. **Simplifies retry logic.**

    Previously, we could have "nested" retry loops: one in `._update()`, another in inference/forward/backward steps. The loop in `._update()` could introduce issues to concurrent inference/forward/backward calls, since it blocks the entire class if its delay period becomes too high. Now this logic is simplified: `._update()` performs only one attempt to fetch the DHT info, any retries are triggered by the inference/forward/backward steps.

4. **Removes deprecated `RemoteTransformerBlock`.**

    `RemoteTransformerBlock` was deprecated a long time ago, before Petals 1.0.0. Its removal is long overdue.

5. **Removes `dht_utils.get_remote_module()`, `dht_utils.get_remote_sequence()`.**

    These functions duplicate the functionality of the `RemoteSequential` constructor.

6. (minor) **Removes `RemoteSequential.is_subsequence` flag.**

    This flag worked incorrectly and was never used. I am removing it for the sake of simplicity.
2023-05-07 13:41:13 +04:00

127 lines
5.5 KiB
Python

import pytest
import torch
import torch.nn.functional as F
from hivemind import DHT, BatchTensorDescriptor, get_logger
from hivemind.proto import runtime_pb2
from petals.bloom.from_pretrained import load_pretrained_block
from petals.client import RemoteSequenceManager, RemoteSequential
from petals.client.remote_model import DistributedBloomConfig
from petals.data_structures import UID_DELIMITER
from test_utils import *
# Module-level logger named after this test module.
logger = get_logger(__name__)
@pytest.mark.forked
def test_remote_sequential():
    """Check RemoteSequential forward/backward, slicing, and a compressing sequence manager.

    The test runs the full remote chain once to obtain reference outputs and
    input gradients, then verifies that running the two halves of the chain
    back to back reproduces them. Finally it repeats the pass through a
    RemoteSequential built on DummyCustomSequenceManager and asserts that the
    results differ from the reference (so compression was applied) while
    staying within a loose error bound.
    """
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)

    io_shape = (1, 5, config.hidden_size)
    test_inputs = torch.randn(*io_shape, requires_grad=True)
    grad_proj = torch.randn(*io_shape)

    sequential = RemoteSequential(config, dht=dht)

    # Reference pass through the whole chain.
    full_outputs = sequential(test_inputs)
    reference_loss = (full_outputs * grad_proj).sum()
    reference_loss.backward()
    assert test_inputs.grad is not None
    full_grad = test_inputs.grad.clone()
    test_inputs.grad.data.zero_()

    # Slice the chain in the middle; the halves must cover every block exactly once.
    midpoint = config.n_layer // 2
    first_half = sequential[:midpoint]
    second_half = sequential[midpoint:]
    assert len(first_half) + len(second_half) == len(sequential)
    assert abs(len(first_half) - len(second_half)) == config.n_layer % 2

    for module in (sequential, first_half, second_half):
        assert isinstance(repr(module), str)

    # Chaining the halves must reproduce the full pass, forward and backward.
    hidden = first_half(test_inputs)
    assert isinstance(hidden, torch.Tensor)
    assert hidden.shape == test_inputs.shape
    assert hidden.requires_grad
    second_half_outputs = second_half(hidden)
    assert torch.allclose(second_half_outputs, full_outputs, atol=1e-4)

    halves_loss = (second_half_outputs * grad_proj).sum()
    halves_loss.backward()
    assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3)

    # test RemoteSequential with lossy compression
    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
    lossy_manager = DummyCustomSequenceManager(config, block_uids, dht=dht)
    lossy_sequential = RemoteSequential(config, sequence_manager=lossy_manager)

    test_inputs.grad = None
    approx_outputs = lossy_sequential(test_inputs)
    lossy_loss = (approx_outputs * grad_proj).sum()
    lossy_loss.backward()

    # Results must differ from the reference, otherwise compression had no effect...
    assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
    assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
    # ...yet stay close to it on average.
    assert (approx_outputs - full_outputs).abs().mean() < 0.01
    absmax = full_grad.abs().max()
    assert (test_inputs.grad / absmax - full_grad / absmax).abs().mean() < 0.05
class DummyCustomSequenceManager(RemoteSequenceManager):
    """A sequence manager that compresses inputs/outputs during forward and backward pass."""

    @property
    def rpc_info(self):
        """Return the parent's rpc_info with "forward_schema" replaced by a FLOAT16-compressed descriptor."""
        info = super().rpc_info
        fp16_input_descriptor = BatchTensorDescriptor(
            (2048, 1024), compression=runtime_pb2.CompressionType.FLOAT16
        )
        # The schema is a pair: (positional args descriptors, keyword args descriptors).
        info["forward_schema"] = ((fp16_input_descriptor,), {})
        return info

    def get_request_metadata(self, protocol: str, *args, **kwargs):
        """Extend the parent's request metadata with an "output_compression" entry.

        "rpc_forward" requests ask for FLOAT16 outputs and "rpc_backward" requests
        ask for BLOCKWISE_8BIT outputs; metadata for any other protocol is
        returned as produced by the parent class.
        """
        request_metadata = super().get_request_metadata(protocol, *args, **kwargs)
        compression_types = runtime_pb2.CompressionType

        if protocol == "rpc_forward":
            requested = compression_types.FLOAT16
        elif protocol == "rpc_backward":
            requested = compression_types.BLOCKWISE_8BIT
        else:
            return request_metadata

        request_metadata["output_compression"] = (requested,)
        return request_metadata
@pytest.mark.forked
def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
    """Compare RemoteSequential with per-block prompts against locally loaded blocks.

    The remote pass feeds inputs concatenated with trainable input prompts and
    passes one trainable intermediate prompt per block via ``prompts=``. The
    reference pass loads every block locally, adds the matching intermediate
    prompt to the leading positions before each block, and must reproduce both
    the outputs and the gradients of the two prompt tensors.

    Args:
        batch_size: number of sequences in the batch.
        seq_len: number of regular (non-prompt) positions per sequence.
        pre_seq_len: number of prompt positions per sequence.
    """
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    remote_sequential = RemoteSequential(config)

    inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
    output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1)
    input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True), dim=-1)
    intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)

    # F.normalize of a tensor that requires grad yields a non-leaf tensor, and
    # backward() only populates .grad on leaves. Re-create both prompt tensors
    # as leaves so their gradients can be inspected below.
    input_prompts = input_prompts.detach().requires_grad_(True)
    intermediate_prompts = intermediate_prompts.detach().requires_grad_(True)

    inputs_with_prompts = torch.cat([inputs, input_prompts], dim=1)
    assert inputs_with_prompts.shape == (batch_size, seq_len + pre_seq_len, config.hidden_size)

    outputs = remote_sequential(inputs_with_prompts, prompts=intermediate_prompts)
    (outputs * output_proj).sum().backward()
    # Check both gradients here so that a missing one fails with a clear
    # assertion instead of a TypeError inside torch.allclose further down.
    assert input_prompts.grad is not None
    assert intermediate_prompts.grad is not None

    # Independent copies for the local reference pass; they start without gradients.
    input_prompts_ref = input_prompts.clone().detach().requires_grad_(True)
    intermediate_prompts_ref = intermediate_prompts.clone().detach().requires_grad_(True)
    assert input_prompts_ref.grad is None
    assert intermediate_prompts_ref.grad is None

    outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1)
    for block_index in range(config.n_layer):
        # Add this block's prompt to the leading positions, then run the local block.
        block_prompt = intermediate_prompts_ref[block_index]
        outputs_ref[:, : block_prompt.shape[1]] += block_prompt

        block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
        (outputs_ref,) = block(outputs_ref)

    assert torch.allclose(outputs_ref, outputs, atol=1e-3)

    (outputs_ref * output_proj).sum().backward()
    assert input_prompts_ref.grad is not None
    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2)
    assert intermediate_prompts_ref.grad is not None
    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)