|
|
|
@ -70,8 +70,8 @@ class LMHead(nn.Module):
|
|
|
|
|
if not self._bf16_warning_shown:
|
|
|
|
|
if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
|
|
|
|
|
logger.warning(
|
|
|
|
|
"Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
|
|
|
|
|
"Consider loading the model with torch_dtype='float32'"
|
|
|
|
|
"Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
|
|
|
|
|
"To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
|
|
|
|
|
)
|
|
|
|
|
self._bf16_warning_shown = True
|
|
|
|
|
|
|
|
|
|