Determine block dtype in a unified manner (#325)

* Extract backend_dtype, remove duplicate DTYPE_MAP

* Use bfloat16 as the default dtype, resolve dtype in load_pretrained_block
Max Ryabinin authored 11 months ago, committed via GitHub
parent 3e7ae5116d
commit c839173e57
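
For context, the unified rule this commit introduces in `resolve_block_dtype` (see the diff below) is: an explicitly requested dtype wins; otherwise fall back to `config.torch_dtype`; otherwise default to bfloat16 instead of float32. A minimal standalone sketch of that behavior, with a `SimpleNamespace` standing in for a real `BloomConfig` (the `_sketch` name is illustrative, not part of the codebase):

```python
import torch
from types import SimpleNamespace

def resolve_block_dtype_sketch(config, dtype):
    # Mirrors the new resolve_block_dtype: explicit dtype > config.torch_dtype > bfloat16 default.
    if dtype not in ("auto", None):
        return dtype
    if config.torch_dtype not in ("auto", None):
        return config.torch_dtype
    return torch.bfloat16

config = SimpleNamespace(torch_dtype="auto")  # stand-in for a BloomConfig
assert resolve_block_dtype_sketch(config, torch.float16) is torch.float16  # explicit dtype is kept
assert resolve_block_dtype_sketch(config, "auto") is torch.bfloat16        # new default (was float32)
config.torch_dtype = torch.float32
assert resolve_block_dtype_sketch(config, "auto") is torch.float32         # config's dtype wins over the default
```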

@@ -21,7 +21,7 @@ from transformers.models.bloom.configuration_bloom import BloomConfig
 from transformers.utils import get_file_from_repo
 from petals.bloom.block import WrappedBloomBlock
-from petals.server.block_utils import get_block_size
+from petals.server.block_utils import get_block_size, resolve_block_dtype
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
 logger = get_logger(__name__)
@@ -41,6 +41,7 @@ def load_pretrained_block(
 ) -> WrappedBloomBlock:
     """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
     assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+    torch_dtype = resolve_block_dtype(config, torch_dtype)
     if config is None:
         config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
@@ -66,7 +67,7 @@ def load_pretrained_block(
     for param_name, _ in block.named_parameters():
         assert param_name in state_dict, f"{param_name} not in state dict"
         param = state_dict[param_name]
-        if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
             param = param.to(torch_dtype)
         set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)

@@ -10,13 +10,11 @@ from huggingface_hub import HfApi, Repository
 from tqdm.auto import tqdm
 from transformers.models.bloom.modeling_bloom import BloomModel
-from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
+from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH, DTYPE_MAP
 from petals.client import DistributedBloomConfig
 logger = get_logger(__name__)
-DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
 def main():
     parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")

@@ -7,14 +7,13 @@ from transformers import BloomConfig
 from petals.bloom.block import WrappedBloomBlock
-def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
+def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
     """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
-    if dtype == "auto" or dtype is None:
-        dtype = config.torch_dtype
-        if dtype == "auto" or dtype is None:
-            dtype = torch.float32
-    return dtype
+    if dtype not in ("auto", None):
+        return dtype
+    if config.torch_dtype not in ("auto", None):
+        return config.torch_dtype
+    return torch.bfloat16
 def get_block_size(

@@ -22,7 +22,7 @@ from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
-from petals.server.block_utils import get_block_size
+from petals.server.block_utils import get_block_size, resolve_block_dtype
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
@@ -151,7 +151,7 @@ class Server:
         if isinstance(torch_dtype, str):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        self.torch_dtype = torch_dtype
+        self.torch_dtype = resolve_block_dtype(self.block_config, torch_dtype)
         if tensor_parallel_devices is None:
             tensor_parallel_devices = (device,)
@@ -182,6 +182,7 @@ class Server:
         if attn_cache_size is None:
             # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
             attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
+        self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout
         logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
@@ -404,22 +405,21 @@ class ModuleContainer(threading.Thread):
         )
         block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
-        backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
         blocks[module_uid] = TransformerBackend(
             module_uid,
             block,
             config=block_config,
             memory_cache=memory_cache,
-            backend_dtype=backend_dtype,
+            backend_dtype=torch_dtype,
             args_schema=(
                 BatchTensorDescriptor(
-                    1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
+                    1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
                 ),
             ),
             kwargs_schema={},
             outputs_schema=(
                 BatchTensorDescriptor(
-                    1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
+                    1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
                 ),
             ),
             min_batch_size=min_batch_size,

@@ -0,0 +1,17 @@
+import pytest
+import torch
+
+from petals.bloom.from_pretrained import load_pretrained_block
+from petals.client import DistributedBloomConfig
+from petals.server.block_utils import resolve_block_dtype
+from test_utils import MODEL_NAME
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.float16, "auto"])
+def test_backend_dtype(torch_dtype):
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    block = load_pretrained_block(MODEL_NAME, 0, config, torch_dtype=torch_dtype)
+    backend_dtype = resolve_block_dtype(config, torch_dtype)
+    other_backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
+    assert backend_dtype == other_backend_dtype
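
A usage note, not part of the diff: with this change, requesting `torch_dtype="auto"` when the model config does not pin a dtype should now yield bfloat16 parameters instead of float32, while an explicitly passed dtype behaves as before. A hedged example, reusing the `bigscience/bloom-petals` name from the comment above as a placeholder:

```python
from petals.bloom.from_pretrained import load_pretrained_block
from petals.client import DistributedBloomConfig

MODEL = "bigscience/bloom-petals"  # example converted model; substitute your own

config = DistributedBloomConfig.from_pretrained(MODEL)
block = load_pretrained_block(MODEL, 0, config, torch_dtype="auto")

# If config.torch_dtype is "auto" or unset, parameters now load in bfloat16 (previously float32);
# an explicit torch_dtype (e.g. torch.float16) still takes precedence.
print(next(block.parameters()).dtype)
```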