@@ -4,11 +4,12 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
 See commit history for authorship.
 """
+import platform
+
 import psutil
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from cpufeature import CPUFeature
 from hivemind import get_logger
 from torch import nn
 from transformers import BloomConfig
 
@@ -29,9 +30,15 @@ class LMHead(nn.Module):
 
         self.use_chunked_forward = config.use_chunked_forward
         if self.use_chunked_forward == "auto":
-            # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
-            # Otherwise, it's ~8x slower.
-            self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
+            if platform.machine() == "x86_64":
+                # Import of cpufeature may crash on non-x86_64 machines
+                from cpufeature import CPUFeature
+
+                # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
+                # Otherwise, it's ~8x slower.
+                self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
+            else:
+                self.use_chunked_forward = True
         self.chunked_forward_step = config.chunked_forward_step
         self._bf16_warning_shown = False
 
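For reference, the heuristic introduced by the second hunk can be exercised on its own. Below is a minimal sketch of the same "auto" resolution logic; the cpufeature keys ("AVX512f", "OS_AVX512") are the ones used in the patch, while the helper name resolve_use_chunked_forward() is hypothetical and not part of the patch.

import platform


def resolve_use_chunked_forward(use_chunked_forward="auto") -> bool:
    """Decide whether an LM head should use chunked_forward() on this machine."""
    if use_chunked_forward != "auto":
        return use_chunked_forward
    if platform.machine() == "x86_64":
        # Import of cpufeature may crash on non-x86_64 machines
        from cpufeature import CPUFeature

        # With AVX512 support, plain bfloat16 matmuls beat chunked_forward();
        # without it, chunked_forward() is faster.
        return not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
    # cpufeature cannot be relied on off x86_64, so default to chunked_forward()
    return True


if __name__ == "__main__":
    print("use_chunked_forward =", resolve_use_chunked_forward())

Guarding both the import and the lookup behind platform.machine() keeps non-x86_64 hosts (e.g. ARM) from crashing at import time, at the cost of always using chunked_forward() there.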