Improve block size calculations (#149)
parent
f42e559c77
commit
83d9493b6c
@ -0,0 +1,48 @@
|
|||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
|
||||||
|
from petals.bloom import BloomBlock, BloomConfig
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
|
||||||
|
"""If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
|
||||||
|
|
||||||
|
if dtype == "auto" or dtype is None:
|
||||||
|
dtype = config.torch_dtype
|
||||||
|
if dtype == "auto" or dtype is None:
|
||||||
|
dtype = torch.float32
|
||||||
|
return dtype
|
||||||
|
|
||||||
|
|
||||||
|
def get_block_size(
|
||||||
|
config: BloomConfig,
|
||||||
|
location: str,
|
||||||
|
*,
|
||||||
|
dtype: Optional[Union[str, torch.dtype]] = None,
|
||||||
|
load_in_8bit: Optional[bool] = None,
|
||||||
|
layer_index: int = 0,
|
||||||
|
eps: float = 0.01, # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc.
|
||||||
|
) -> int:
|
||||||
|
if location == "memory":
|
||||||
|
assert (
|
||||||
|
dtype is not None and load_in_8bit is not None
|
||||||
|
), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'
|
||||||
|
|
||||||
|
with init_empty_weights():
|
||||||
|
block = BloomBlock(config, layer_index)
|
||||||
|
n_params = sum(param.numel() for param in block.parameters())
|
||||||
|
|
||||||
|
if location == "memory" and load_in_8bit:
|
||||||
|
# Note: We may need a larger eps here for models of size < 1B
|
||||||
|
return n_params * (1 + eps)
|
||||||
|
|
||||||
|
if location == "memory":
|
||||||
|
dtype = resolve_block_dtype(config, dtype)
|
||||||
|
elif location == "disk":
|
||||||
|
dtype = resolve_block_dtype(config, "auto")
|
||||||
|
else:
|
||||||
|
raise ValueError('get_block_size() expects location to be "memory" or "disk"')
|
||||||
|
|
||||||
|
return round(n_params * torch.finfo(dtype).bits // 8 * (1 + eps))
|
Loading…
Reference in New Issue