Speed up loading blocks using init with meta weights (#285)

* Init WrappedBloomBlock with meta weights

---------

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Max Ryabinin, 1 year ago, committed by GitHub
commit 793726b041 (parent c519bffc59)
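
The change in a nutshell: instead of constructing each WrappedBloomBlock with real, randomly initialized CPU tensors and then overwriting them from the checkpoint, the block is now built on PyTorch's meta device via accelerate's `init_empty_weights`, and the checkpoint tensors are attached afterwards with `set_module_tensor_to_device`. A simplified sketch of that pattern (the `fast_load` helper and its arguments are illustrative, not part of Petals):

```python
# A simplified sketch of the meta-weights loading pattern (not the exact Petals code).
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device


def fast_load(module_cls, config, state_dict, dtype=torch.bfloat16):
    with init_empty_weights():
        module = module_cls(config)  # parameters live on the "meta" device: no RAM, no random init
    for name, tensor in state_dict.items():
        if dtype is not None and tensor.is_floating_point():
            tensor = tensor.to(dtype)  # cast floats only; integer/bool buffers keep their dtype
        # swap the meta placeholder for a real CPU tensor holding the checkpoint value
        set_module_tensor_to_device(module, name, "cpu", value=tensor)
    return module
```

The win is that the block's `__init__` no longer spends time (and peak memory) initializing weights that are immediately thrown away once the checkpoint is loaded.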

@@ -14,4 +14,5 @@ profile = "black"
line_length = 120
combine_as_imports = true
combine_star = true
known_local_folder = ["tests", "cli"]
known_first_party = ["test_utils"]

@@ -13,6 +13,8 @@ import time
from typing import Optional, OrderedDict, Union
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from hivemind.utils.logging import get_logger
from transformers.modeling_utils import WEIGHTS_NAME
from transformers.models.bloom.configuration_bloom import BloomConfig
@@ -38,13 +40,16 @@ def load_pretrained_block(
max_disk_space: Optional[int] = None,
) -> WrappedBloomBlock:
"""Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
if config is None:
config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR
block = WrappedBloomBlock(config)
with init_empty_weights():
block = WrappedBloomBlock(config)
state_dict = _load_state_dict(
converted_model_name_or_path,
block_index,
@@ -54,16 +59,17 @@ def load_pretrained_block(
max_disk_space=max_disk_space,
)
if torch_dtype == "auto":
with torch.no_grad():
for name, param in block.named_parameters():
assert name in state_dict, f"{name} not in state dict"
param.data = param.data.to(state_dict[name].dtype)
else:
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
block = block.to(dtype=torch_dtype)
# dummy load, check that keys match
report = block.load_state_dict(state_dict, strict=True)
assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
for param_name, _ in block.named_parameters():
assert param_name in state_dict, f"{param_name} not in state dict"
param = state_dict[param_name]
if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
param = param.to(torch_dtype)
set_module_tensor_to_device(block, param_name, "cpu", value=param)
logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
return block
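
For reference, calling the updated loader looks the same as before; a hedged usage sketch (the checkpoint name is only an example):

```python
import torch
from petals.bloom.from_pretrained import load_pretrained_block

# The checkpoint name is just an example; any converted BLOOM model works.
block = load_pretrained_block("bigscience/bloom-petals", block_index=3, torch_dtype=torch.bfloat16)
assert not any(p.is_meta for p in block.parameters())  # every tensor was materialized as a real CPU tensor
```

Note the `str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool"))` guard above: non-floating-point tensors from the state dict are never cast to `torch_dtype`.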

@@ -30,7 +30,7 @@ def get_block_size(
dtype is not None and load_in_8bit is not None
), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'
with init_empty_weights():
with init_empty_weights(include_buffers=True):
block = WrappedBloomBlock(config)
n_params = sum(param.numel() for param in block.parameters())
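
The only change here is `include_buffers=True`, which makes `init_empty_weights` place registered buffers on the meta device as well, so the size probe allocates nothing at all. A sketch of the same idea (`estimate_block_bytes` is a made-up helper, not `get_block_size` itself, and it ignores quantization and activation overhead):

```python
import torch
from accelerate import init_empty_weights
from petals.bloom.block import WrappedBloomBlock


def estimate_block_bytes(config, dtype=torch.bfloat16):
    with init_empty_weights(include_buffers=True):  # buffers also land on the meta device
        block = WrappedBloomBlock(config)           # nothing is actually allocated
    n_params = sum(param.numel() for param in block.parameters())
    return n_params * torch.finfo(dtype).bits // 8  # bytes for the parameters alone
```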

@@ -1,9 +1,9 @@
import pytest
import torch
from test_utils import MODEL_NAME
from petals.client import DistributedBloomConfig
from petals.server.throughput import measure_compute_rps, measure_network_rps
from petals.server.throughput import measure_compute_rps
from test_utils import MODEL_NAME
@pytest.mark.forked

@@ -1,15 +1,18 @@
import random
from typing import Union
import hivemind
import pytest
import torch
from test_utils import *
from transformers.models.bloom.configuration_bloom import BloomConfig
from petals.bloom.from_pretrained import load_pretrained_block
from petals.bloom.block import WrappedBloomBlock
from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
from petals.client import DistributedBloomConfig
from petals.client.remote_sequential import RemoteTransformerBlock
from petals.data_structures import UID_DELIMITER
from petals.dht_utils import get_remote_module
from test_utils import *
@pytest.mark.forked
@@ -41,3 +44,47 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
def _old_load_pretrained_block(
converted_model_name_or_path: str,
block_index: int,
torch_dtype: Union[torch.dtype, str] = "auto",
) -> WrappedBloomBlock:
"""Load the BLOOM block by directly initializing the weights.
This test is used to check consistency with the previous implementation and can be removed in the future."""
config = BloomConfig.from_pretrained(converted_model_name_or_path)
block = WrappedBloomBlock(config)
state_dict = _load_state_dict(
converted_model_name_or_path,
block_index,
config,
cache_dir=None,
)
if torch_dtype == "auto":
with torch.no_grad():
for name, param in block.named_parameters():
assert name in state_dict, f"{name} not in state dict"
param.data = param.data.to(state_dict[name].dtype)
else:
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
block = block.to(dtype=torch_dtype)
block.load_state_dict(state_dict, strict=True)
return block
@pytest.mark.forked
def test_init_pretrained_block(torch_dtype=torch.float32, atol_forward=1e-8):
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
torch.random.manual_seed(0)
inputs = torch.randn(1, 16, config.hidden_size, dtype=torch_dtype)
block = load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
ref_block = _old_load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
outputs = block.forward(inputs)[0]
outputs_ref = ref_block.forward(inputs)[0]
assert torch.allclose(outputs, outputs_ref, rtol=0, atol=atol_forward)
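
Because both loaders materialize the same state dict into the same WrappedBloomBlock, their forward outputs are expected to match essentially exactly; the atol=1e-8 is only a small safety margin, and _old_load_pretrained_block exists solely as a reference for this consistency check.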

@@ -7,12 +7,12 @@
import hivemind
import pytest
import torch
from test_utils import *
from petals.bloom.from_pretrained import load_pretrained_block
from petals.client import DistributedBloomConfig
from petals.client.remote_sequential import RemoteSequential
from petals.dht_utils import get_remote_sequence
from test_utils import *
@pytest.mark.forked

@@ -2,11 +2,11 @@ import pytest
import torch
import transformers
from hivemind import get_logger
from test_utils import *
from transformers.generation import BeamSearchScorer
from transformers.models.bloom import BloomForCausalLM
from petals.client.remote_model import DistributedBloomForCausalLM
from test_utils import *
logger = get_logger(__name__)

@@ -1,14 +1,14 @@
import pytest
import torch
import torch.nn.functional as F
from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler
from hivemind import DHT, BatchTensorDescriptor, get_logger
from hivemind.proto import runtime_pb2
from test_utils import *
from petals.bloom.from_pretrained import load_pretrained_block
from petals.client import RemoteSequenceManager, RemoteSequential
from petals.client.remote_model import DistributedBloomConfig
from petals.data_structures import UID_DELIMITER
from test_utils import *
logger = get_logger(__name__)

@@ -4,11 +4,11 @@ import time
import pytest
import torch
from hivemind import DHT, get_logger
from test_utils import *
from petals.client import RemoteSequenceManager, RemoteSequential
from petals.client.remote_model import DistributedBloomConfig
from petals.data_structures import UID_DELIMITER
from test_utils import *
logger = get_logger(__name__)

@@ -3,12 +3,12 @@ import time
import hivemind
import pytest
import torch
from test_utils import *
from petals.client import DistributedBloomConfig
from petals.data_structures import UID_DELIMITER
from petals.dht_utils import get_remote_sequence
from petals.server.handler import CACHE_TOKENS_AVAILABLE
from test_utils import *
@pytest.mark.forked

@@ -5,9 +5,9 @@ import torch
import transformers
from tensor_parallel import TensorParallel
from tensor_parallel.slicing_configs import get_bloom_config
from test_utils import MODEL_NAME
from petals.bloom.from_pretrained import load_pretrained_block
from test_utils import MODEL_NAME
@pytest.mark.forked
