Use 4-bit for llama by default, use bitsandbytes 0.40.0.post3 (#340)

NF4 inference with bitsandbytes 0.40.0.post3 is ~2x faster than INT8 inference, though training is still ~3x slower; see:

- [bitsandbytes 0.40.0 Release notes](https://github.com/TimDettmers/bitsandbytes/releases/tag/0.40.0)
- [RPS benchmarks](https://github.com/bigscience-workshop/petals/pull/333#issuecomment-1614040385)

We've decided to use NF4 by default for LLaMA.
Alexander Borzunov, commit fa095f6461 (parent 158013a671)
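For illustration, here is a minimal, self-contained sketch of how the new default resolves, mirroring the server change below. The `QuantType` enum is a stand-in for the one used in the codebase, and `default_quant_type` is a hypothetical helper added only for this example:

```python
from enum import Enum

import torch


class QuantType(Enum):
    # Stand-in for the quantization modes referenced in the server diff below.
    NONE = 0
    INT8 = 1
    NF4 = 2


def default_quant_type(device: torch.device, model_type: str) -> QuantType:
    # New default: NF4 for LLaMA on CUDA, INT8 for other models on CUDA,
    # and no quantization when running on CPU.
    if device.type == "cuda":
        return QuantType.NF4 if model_type == "llama" else QuantType.INT8
    return QuantType.NONE


# How the default resolves for a few configurations:
assert default_quant_type(torch.device("cuda"), "llama") is QuantType.NF4
assert default_quant_type(torch.device("cuda"), "bloom") is QuantType.INT8
assert default_quant_type(torch.device("cpu"), "llama") is QuantType.NONE
```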

@@ -32,7 +32,7 @@ packages = find:
 python_requires = >=3.7
 install_requires =
     torch>=1.12
-    bitsandbytes==0.39.1
+    bitsandbytes==0.40.0.post3
     accelerate>=0.16.0,<1.0.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
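To double-check that the new pin is what actually got installed (standard library only, Python 3.8+), something like this works:

```python
from importlib.metadata import version

# Should report "0.40.0.post3" once the updated requirement is installed.
print(version("bitsandbytes"))
```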

@@ -166,7 +166,10 @@ class Server:
             check_device_balance(self.tensor_parallel_devices)
         if quant_type is None:
-            quant_type = QuantType.INT8 if device.type == "cuda" else QuantType.NONE
+            if device.type == "cuda":
+                quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8
+            else:
+                quant_type = QuantType.NONE
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
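Note that this branch only runs when `quant_type` is `None`, so an explicitly requested mode is never overridden. A tiny illustration of that guard, reusing `QuantType` and `default_quant_type` from the sketch above (`resolve_quant_type` is likewise a hypothetical helper, not part of the codebase):

```python
from typing import Optional


def resolve_quant_type(
    requested: Optional[QuantType], device: torch.device, model_type: str
) -> QuantType:
    # An explicit request bypasses the default entirely.
    if requested is not None:
        return requested
    return default_quant_type(device, model_type)


assert resolve_quant_type(QuantType.INT8, torch.device("cuda"), "llama") is QuantType.INT8
assert resolve_quant_type(None, torch.device("cuda"), "llama") is QuantType.NF4
```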
