From fa095f6461c50f600f950ba18deaf633d804c68e Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Tue, 11 Jul 2023 18:53:17 +0400
Subject: [PATCH] Use 4-bit for llama by default, use bitsandbytes 0.40.0.post3 (#340)

NF4 inference with bitsandbytes 0.40.0.post3 is ~2x faster than int8
inference, though training is still ~3x slower, see:

- [bitsandbytes 0.40.0 Release notes](https://github.com/TimDettmers/bitsandbytes/releases/tag/0.40.0)
- [RPS benchmarks](https://github.com/bigscience-workshop/petals/pull/333#issuecomment-1614040385)

We've decided to use NF4 by default for LLaMA.
---
 setup.cfg                   | 2 +-
 src/petals/server/server.py | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index fb1fa23..76185eb 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -32,7 +32,7 @@ packages = find:
 python_requires = >=3.7
 install_requires =
     torch>=1.12
-    bitsandbytes==0.39.1
+    bitsandbytes==0.40.0.post3
     accelerate>=0.16.0,<1.0.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 894e9ea..eddb76e 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -166,7 +166,10 @@ class Server:
             check_device_balance(self.tensor_parallel_devices)

         if quant_type is None:
-            quant_type = QuantType.INT8 if device.type == "cuda" else QuantType.NONE
+            if device.type == "cuda":
+                quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8
+            else:
+                quant_type = QuantType.NONE
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
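
For reference, the default chosen by the new branch in `Server.__init__` can be summarized as a small standalone function. This is only an illustrative sketch: `QuantType` is re-declared here as a stand-in enum rather than imported from petals (its real import path may differ between versions), and `default_quant_type` is a hypothetical helper that mirrors the logic added in the hunk above.

```python
# Minimal, self-contained sketch of the default-selection rule from this patch.
# QuantType below is a stand-in for petals' own enum; the branch logic mirrors
# the new code in src/petals/server/server.py.
import enum


class QuantType(enum.Enum):
    NONE = 0
    INT8 = 1
    NF4 = 2


def default_quant_type(device_type: str, model_type: str) -> QuantType:
    """Quantization picked when no explicit quant_type is passed to the server."""
    if device_type == "cuda":
        # LLaMA-family blocks now default to NF4; other models keep INT8.
        return QuantType.NF4 if model_type == "llama" else QuantType.INT8
    # Non-CUDA devices (e.g. CPU) are served without quantization.
    return QuantType.NONE


assert default_quant_type("cuda", "llama") is QuantType.NF4
assert default_quant_type("cuda", "bloom") is QuantType.INT8
assert default_quant_type("cpu", "llama") is QuantType.NONE
```

Since NF4 training remains ~3x slower than int8 (per the release notes linked above), operators whose LLaMA servers mostly handle fine-tuning traffic may prefer to keep int8 by passing an explicit quant_type when launching the server, rather than relying on this new default; the exact option name depends on the petals version you run.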