Improve bfloat16 & no-AVX-512 warning

pull/464/head
Aleksandr Borzunov 10 months ago
parent c69e9f02e7
commit d19bac4962

@@ -70,8 +70,8 @@ class LMHead(nn.Module):
if not self._bf16_warning_shown:
if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
logger.warning(
-"Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
-"Consider loading the model with torch_dtype='float32'"
+"Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
+"To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
)
self._bf16_warning_shown = True

Loading…
Cancel
Save