Use bnb==0.40.0.post4 to fix bias bug, use bfloat16 by default

fix-nf4-and-dtypes
Aleksandr Borzunov 11 months ago
parent b28f5016ea
commit b0d55ee655

@@ -32,7 +32,7 @@ packages = find:
 python_requires = >=3.7
 install_requires =
     torch>=1.12
-    bitsandbytes==0.40.0.post3
+    bitsandbytes==0.40.0.post4
     accelerate>=0.16.0,<1.0.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3

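The pin moves from 0.40.0.post3 to 0.40.0.post4, which carries the bias fix mentioned in the commit title. A minimal sketch (not part of this diff; assumes Python >= 3.8 for importlib.metadata) to confirm the pinned release is what actually got installed:

from importlib.metadata import version

installed = version("bitsandbytes")
assert installed == "0.40.0.post4", f"expected the bias-fix release, got {installed}"
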
@@ -29,9 +29,8 @@ class FromPretrainedMixin:
         if low_cpu_mem_usage is None:
             low_cpu_mem_usage = True
         if torch_dtype is None:
-            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
-            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
-            torch_dtype = "auto"
+            # torch_dtype=None gives torch.float32 in transformers>=4.26.0
+            torch_dtype = torch.bfloat16
 
         with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
             return super().from_pretrained(

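The client previously mapped torch_dtype=None to "auto"; it now maps it to torch.bfloat16 explicitly. A minimal sketch of the resulting default logic (pick_torch_dtype is a hypothetical helper used only for illustration, not a Petals function):

import torch

def pick_torch_dtype(torch_dtype=None):
    # Mirrors the hunk above: None no longer means "auto" (and hence float32
    # in transformers>=4.26.0); it now resolves to bfloat16 explicitly.
    return torch.bfloat16 if torch_dtype is None else torch_dtype

assert pick_torch_dtype() is torch.bfloat16              # new default
assert pick_torch_dtype(torch.float16) is torch.float16  # an explicit dtype still wins
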
@@ -11,8 +11,6 @@ def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]
     """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
     if dtype not in ("auto", None):
         return dtype
-    if config.torch_dtype not in ("auto", None):
-        return config.torch_dtype
     return torch.bfloat16

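With the config.torch_dtype fallback removed, the server-side function reduces to the sketch below (reconstructed from the hunk above; imports added to make it self-contained):

from typing import Union

import torch
from transformers import PretrainedConfig


def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
    """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
    if dtype not in ("auto", None):
        return dtype
    # "auto"/None now always resolves to bfloat16, regardless of config.torch_dtype
    return torch.bfloat16
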
@@ -173,6 +173,12 @@ class Server:
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
+        if self.block_config.model_type == "llama" and torch_dtype == torch.bfloat16 and quant_type != QuantType.NF4:
+            logger.warning(
+                "LLaMA is loaded in bfloat16 for compatibility with --quant_type nf4 servers (default). "
+                "If you use a private swarm without such servers, use --torch_dtype float16 to force the original float16 dtype"
+            )
+
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
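For reference, the last two context lines size the attention cache; a worked example of that formula (the hidden_size and attn_cache_tokens values are illustrative assumptions, not taken from this diff):

import torch

hidden_size = 8192        # assumed block width, e.g. a LLaMA-65B-sized model
attn_cache_tokens = 8192  # assumed per-block token budget for the attention cache
torch_dtype = torch.bfloat16

# 2x covers both attention keys and values for every cached token
cache_values_per_block = 2 * hidden_size * attn_cache_tokens
cache_bytes_per_block = cache_values_per_block * torch.finfo(torch_dtype).bits // 8
print(f"{cache_bytes_per_block / 2**20:.0f} MiB per block")  # 256 MiB with these numbers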
