import random

import pytest
import torch
import transformers
from tensor_parallel import TensorParallel
from tensor_parallel.slicing_configs import get_bloom_config


Add LLaMA support (#323)

This PR:
1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.
- BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
   - LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accommodate blocks from different repos (there are a few such repos with minor differences, such as `Llama` vs. `LLaMA` in the class name).
2. **Refactors the client to generalize it for multiple models.** Now, we have the `petals.models` package, whose subpackages contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. the CPU-efficient LM head and p-tuning) is kept in `petals.client`.
3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).
4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers (see the sketch after this list).
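To make the new flow concrete, here is a minimal sketch, not code from this PR: the checkpoint names and classes are the ones mentioned above, `load_pretrained_block` is used with the signature from the test below, and the assumption that `AutoDistributedConfig` and `DistributedLlamaForCausalLM` can be imported from the top-level `petals` package is mine.

```python
# A sketch under the assumptions stated above, not the exact API of this PR.
import torch

from petals import AutoDistributedConfig, DistributedLlamaForCausalLM  # assumed import path
from petals.server.from_pretrained import load_pretrained_block

# Client side: the auto class picks DistributedBloomConfig or DistributedLlamaConfig
# based on the checkpoint's model type.
config = AutoDistributedConfig.from_pretrained("bigscience/bloom")

# Client side: a LLaMA checkpoint from any repo; its blocks are announced under
# the DHT prefix "llama-65b-hf" (without the username).
model = DistributedLlamaForCausalLM.from_pretrained("username/llama-65b-hf")

# Server side: download and load a single transformer block directly from the
# original repository, with no converted checkpoint involved (block_index is arbitrary).
block = load_pretrained_block("bigscience/bloom", block_index=3, torch_dtype=torch.bfloat16)
```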
Upgrade instructions:
- Remove disk caches for blocks in the old (converted) format to save disk space. That is, remove the `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present); a sketch of this cleanup is shown below.
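A minimal Python sketch of that cleanup step, assuming the default cache location under `~/.cache/petals` mentioned above (the shell equivalent is simply deleting the two directories):

```python
import shutil
from pathlib import Path

# Remove caches left over from the old, converted block format (if present).
for name in ("model--bigscience--bloom-petals", "model--bigscience--bloomz-petals"):
    shutil.rmtree(Path.home() / ".cache" / "petals" / name, ignore_errors=True)
```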
from petals.server.from_pretrained import load_pretrained_block
from test_utils import MODEL_NAME


@pytest.mark.forked
@pytest.mark.parametrize("custom_config", [True, False])
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
def test_tp_block(devices, custom_config):
    model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
    if model_config.model_type != "bloom":
        pytest.skip("Tensor parallelism is implemented only for BLOOM for now")

    # Pick a random transformer block and load it directly from the original repo
    block_index = random.randint(0, 10)
    block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])

    tp_config = None
    if custom_config:
        tp_config = get_bloom_config(model_config, devices)

    batch_size = 2
    prefix_length = 5

    test_inputs1 = torch.randn(batch_size, 3, 1024, requires_grad=True, device=devices[0])
    test_inputs2 = test_inputs1.detach().clone().requires_grad_(True)
    test_prefix1 = torch.randn(batch_size, prefix_length, 1024, requires_grad=True, device=devices[0])
    test_prefix2 = test_prefix1.detach().clone().requires_grad_(True)
    grad_proj = torch.rand_like(test_inputs1)

    # Reference: run the original block on a single device, reusing the prefix KV cache
    y_prefix_ref, layer_past = block(test_prefix1, use_cache=True)
    y_ref, cache_ref = block(test_inputs1, use_cache=True, layer_past=layer_past)
    y_ref.backward(grad_proj)

    # Same computation with the block sharded across `devices` via tensor parallelism
    block_tp = TensorParallel(block, devices, config=tp_config)
    y_prefix, layer_past = block_tp(test_prefix2, use_cache=True)
    y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
    y_ours.backward(grad_proj)

    # Forward activations and input gradients must match the single-device reference
    assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-5)
    assert torch.allclose(y_ours, y_ref, atol=1e-5)
    assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)
    assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)
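For reference, the same wrapping can be used outside of pytest. The sketch below reuses only the calls exercised in the test above and relies on the same imports; the device tuple and block index are arbitrary, and MODEL_NAME is assumed to be a BLOOM checkpoint.

# Standalone sketch reusing the calls from the test above; device names and the
# block index are arbitrary, and MODEL_NAME is assumed to point to a BLOOM checkpoint.
devices = ("cpu", "cpu")  # e.g. ("cuda:0", "cuda:1") on a multi-GPU machine
model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
block = load_pretrained_block(MODEL_NAME, block_index=0, torch_dtype=torch.float32).to(devices[0])
tp_block = TensorParallel(block, devices, config=get_bloom_config(model_config, devices))
hidden_states = torch.randn(1, 8, model_config.hidden_size, device=devices[0])
outputs, cache = tp_block(hidden_states, use_cache=True)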