|
|
|
@ -171,9 +171,11 @@ class Server:
|
|
|
|
|
|
|
|
|
|
if quant_type is None:
|
|
|
|
|
if device.type == "cuda":
|
|
|
|
|
quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8
|
|
|
|
|
quant_type = QuantType.INT8
|
|
|
|
|
else:
|
|
|
|
|
quant_type = QuantType.NONE
|
|
|
|
|
elif quant_type == QuantType.NF4:
|
|
|
|
|
raise RuntimeError("4-bit quantization (NF4) is not supported on AMD GPUs!")
|
|
|
|
|
self.quant_type = quant_type
|
|
|
|
|
logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
|
|
|
|
|
|
|
|
|
|