From f53c581690f6bcc565b1325097157c17b5baac6d Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Wed, 12 Jul 2023 10:39:49 +0000
Subject: [PATCH] Resolve review comments

---
 src/petals/server/server.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 5f2f3c8..10adb00 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -173,10 +173,14 @@ class Server:
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
-        if self.block_config.model_type == "llama" and torch_dtype == torch.bfloat16 and quant_type != QuantType.NF4:
+        if (
+            self.block_config.torch_dtype == torch.float16  # If weights are in float16
+            and torch_dtype == torch.bfloat16  # but we load them in bfloat16
+            and quant_type != QuantType.NF4
+        ):
             logger.warning(
-                "LLaMA is loaded in bfloat16 for compatibility with --quant_type nf4 servers (default). "
-                "If you use a private swarm without such servers, use --torch_dtype float16 to force the original float16 dtype"
+                "LLaMA is loaded in bfloat16 for compatibility with NF4 servers holding Guanaco adapters. "
+                "If you want to run it in float16, use --torch_dtype float16"
             )
 
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
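
For review context, the new condition can be read in isolation as the sketch below. This is a minimal, self-contained reconstruction and not part of the patch: should_warn_about_dtype is a hypothetical helper, checkpoint_dtype stands in for self.block_config.torch_dtype, and the simplified QuantType enum only assumes the NF4 member actually referenced in the diff.

    # Hypothetical standalone sketch of the new warning condition (not in the patch).
    from enum import Enum

    import torch


    class QuantType(Enum):
        # Assumed simplified stand-in for petals' quantization modes; only NF4
        # appears in the diff, the other members are illustrative.
        NONE = 0
        INT8 = 1
        NF4 = 2


    def should_warn_about_dtype(
        checkpoint_dtype: torch.dtype, requested_dtype: torch.dtype, quant_type: QuantType
    ) -> bool:
        """Warn only when float16 checkpoint weights are upcast to bfloat16 outside the NF4 path."""
        return (
            checkpoint_dtype == torch.float16  # weights are stored in float16
            and requested_dtype == torch.bfloat16  # but will be loaded in bfloat16
            and quant_type != QuantType.NF4  # NF4 servers want bfloat16, so no warning there
        )


    # float16 weights upcast to bfloat16 without NF4: warn
    assert should_warn_about_dtype(torch.float16, torch.bfloat16, QuantType.NONE)
    # same upcast but under NF4: expected, no warning
    assert not should_warn_about_dtype(torch.float16, torch.bfloat16, QuantType.NF4)
    # weights already stored in bfloat16: no warning
    assert not should_warn_about_dtype(torch.bfloat16, torch.bfloat16, QuantType.NONE)

The key change relative to the removed line is that the check now keys on the checkpoint's stored dtype (self.block_config.torch_dtype) rather than on model_type == "llama", so the warning fires precisely when float16 weights would be loaded in bfloat16, instead of for every LLaMA model loaded in bfloat16.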