@@ -173,10 +173,14 @@ class Server:
         self.quant_type = quant_type
logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
if self.block_config.model_type == "llama" and torch_dtype == torch.bfloat16 and quant_type != QuantType.NF4:
+        if (
+            self.block_config.torch_dtype == torch.float16  # If weights are in float16
+            and torch_dtype == torch.bfloat16  # but we load them in bfloat16
+            and quant_type != QuantType.NF4
+        ):
             logger.warning(
"LLaMA is loaded in bfloat16 for compatibility with --quant_type nf4 servers (default). "
|
|
|
|
|
"If you use a private swarm without such servers, use --torch_dtype float16 to force the original float16 dtype"
|
|
|
|
|
"LLaMA is loaded in bfloat16 for compatibility with NF4 servers holding Guanaco adapters. "
|
|
|
|
|
"If you want to run it in float16, use --torch_dtype float16"
             )
 
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
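
To sanity-check the new condition outside the server, here is a minimal sketch. The `QuantType` enum below is a simplified stand-in for the project's real `QuantType` (only the `NF4` member is taken from the diff), and `should_warn` is a hypothetical helper name introduced for illustration:

```python
import torch
from enum import Enum

class QuantType(Enum):
    # Simplified stand-in for the project's QuantType; only NF4 is
    # guaranteed to correspond to a member referenced by the diff.
    NONE = 0
    NF4 = 1

def should_warn(config_dtype: torch.dtype, load_dtype: torch.dtype, quant_type: QuantType) -> bool:
    # Mirrors the new condition: warn only when float16 checkpoint weights
    # are being loaded as bfloat16 without NF4 quantization.
    return (
        config_dtype == torch.float16
        and load_dtype == torch.bfloat16
        and quant_type != QuantType.NF4
    )

assert should_warn(torch.float16, torch.bfloat16, QuantType.NONE)
assert not should_warn(torch.float16, torch.bfloat16, QuantType.NF4)    # NF4 servers are exempt
assert not should_warn(torch.bfloat16, torch.bfloat16, QuantType.NONE)  # native bfloat16 weights
```

Unlike the removed `model_type == "llama"` check, this fires for any model whose checkpoint dtype is float16, which is why the exemptions above matter.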
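The unchanged `cache_values_per_block` context line counts `2 * hidden_size` values per cached token (one `hidden_size`-sized vector for keys, one for values). A quick back-of-the-envelope check; the numbers here are illustrative assumptions, not values from this PR:

```python
import torch

def attn_cache_bytes_per_block(hidden_size: int, attn_cache_tokens: int, dtype: torch.dtype) -> int:
    # Same formula as the diff: keys + values (factor of 2), one
    # hidden_size-sized vector each per cached token, times the
    # element size of the cache dtype.
    cache_values_per_block = 2 * hidden_size * attn_cache_tokens
    return cache_values_per_block * torch.empty((), dtype=dtype).element_size()

# Hypothetical example: hidden_size=8192, 8192 cached tokens, bfloat16 cache
print(attn_cache_bytes_per_block(8192, 8192, torch.bfloat16) // 2**20, "MiB per block")  # 256 MiB
```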