from typing import Optional, Union

import torch
from accelerate import init_empty_weights
from transformers import BloomConfig

from petals.bloom.block import WrappedBloomBlock


def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
    """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
    if dtype == "auto" or dtype is None:
        dtype = config.torch_dtype
        if dtype == "auto" or dtype is None:
            dtype = torch.float32
    return dtype
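
# Illustrative behavior (hypothetical examples, not exhaustive):
#   resolve_block_dtype(config, torch.float16) -> torch.float16 (returned intact)
#   resolve_block_dtype(config, "auto")        -> config.torch_dtype, falling back
#                                                 to torch.float32 if that is unset
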
def get_block_size(
    config: BloomConfig,
    location: str,
    *,
    dtype: Optional[Union[str, torch.dtype]] = None,
    load_in_8bit: Optional[bool] = None,
    eps: float = 0.01,  # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc.
) -> int:
    """Estimate the size of a single transformer block, in bytes, in memory or on disk."""
    if location == "memory":
        assert (
            dtype is not None and load_in_8bit is not None
        ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'

    # Instantiate the block on the meta device so no real memory is allocated;
    # we only need the parameter count.
    with init_empty_weights(include_buffers=True):
        block = WrappedBloomBlock(config)
        n_params = sum(param.numel() for param in block.parameters())

    if location == "memory" and load_in_8bit:
        # int8 weights take ~1 byte per parameter; round so the annotated int is returned.
        # Note: We may need a larger eps here for models of size < 1B
        return round(n_params * (1 + eps))

    if location == "memory":
        dtype = resolve_block_dtype(config, dtype)
    elif location == "disk":
        dtype = resolve_block_dtype(config, "auto")
    else:
        raise ValueError('get_block_size() expects location to be "memory" or "disk"')

    return round(n_params * torch.finfo(dtype).bits // 8 * (1 + eps))
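

# Usage sketch: a minimal, illustrative driver. The BloomConfig values below
# are arbitrary assumptions chosen for demonstration, not a real model config.
if __name__ == "__main__":
    demo_config = BloomConfig(hidden_size=1024, n_head=16, n_layer=24)

    # On-disk size estimate for one block, in bytes (dtype resolved via "auto")
    print(get_block_size(demo_config, "disk"))

    # In-memory estimates require both dtype and load_in_8bit to be explicit
    print(get_block_size(demo_config, "memory", dtype=torch.bfloat16, load_in_8bit=False))
    print(get_block_size(demo_config, "memory", dtype="auto", load_in_8bit=True))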