|
|
|
@ -173,6 +173,12 @@ class Server:
|
|
|
|
|
self.quant_type = quant_type
|
|
|
|
|
logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
|
|
|
|
|
|
|
|
|
|
if self.block_config.model_type == "llama" and torch_dtype == torch.bfloat16 and quant_type != QuantType.NF4:
|
|
|
|
|
logger.warning(
|
|
|
|
|
"LLaMA is loaded in bfloat16 for compatibility with --quant_type nf4 servers (default). "
|
|
|
|
|
"If you use a private swarm without such servers, use --torch_dtype float16 to force the original float16 dtype"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
|
|
|
|
|
self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
|
|
|
|
|
|
|
|
|
|