Fix use_chunked_forward="auto" on non-x86_64 machines (#267)

Importing cpufeature may crash on non-x86_64 machines, so this PR makes the client import it only when necessary.
Alexander Borzunov authored 1 year ago, committed by GitHub
parent a2e7f27a5a
commit fd9400b392
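The fix follows a standard guarded-import pattern: cpufeature is imported only after checking platform.machine(), so non-x86_64 machines never load the package at all. A minimal sketch of the pattern, assuming only the CPUFeature keys used in the diff below (the helper name cpu_supports_avx512 is hypothetical, not part of this PR):

```python
import platform


def cpu_supports_avx512() -> bool:
    """Guarded-import sketch: only touch cpufeature on x86_64."""
    if platform.machine() != "x86_64":
        # cpufeature may crash on import here, so answer without loading it
        return False
    from cpufeature import CPUFeature  # deferred import, reached only on x86_64

    # True only if the CPU has AVX512F and the OS saves/restores AVX512 state
    return bool(CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
```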

@@ -4,11 +4,12 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
 See commit history for authorship.
 """
+import platform
+
 import psutil
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from cpufeature import CPUFeature
 from hivemind import get_logger
 from torch import nn
 from transformers import BloomConfig
@@ -29,9 +30,15 @@ class LMHead(nn.Module):
         self.use_chunked_forward = config.use_chunked_forward
         if self.use_chunked_forward == "auto":
-            # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
-            # Otherwise, it's ~8x slower.
-            self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
+            if platform.machine() == "x86_64":
+                # Import of cpufeature may crash on non-x86_64 machines
+                from cpufeature import CPUFeature
+
+                # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
+                # Otherwise, it's ~8x slower.
+                self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
+            else:
+                self.use_chunked_forward = True
         self.chunked_forward_step = config.chunked_forward_step
         self._bf16_warning_shown = False
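Pulled out of the class for illustration, the "auto" resolution above behaves as follows. This standalone mirror is a sketch (resolve_use_chunked_forward is a hypothetical name, not part of the PR), but its branches match the diff:

```python
import platform


def resolve_use_chunked_forward(setting):
    """Sketch mirroring the "auto" branch in LMHead above (hypothetical helper)."""
    if setting != "auto":
        return setting  # an explicit True/False from the config is kept as-is
    if platform.machine() == "x86_64":
        from cpufeature import CPUFeature

        # With AVX512, plain bfloat16 is ~10x faster than chunked_forward();
        # otherwise it's ~8x slower, so fall back to chunked_forward()
        return not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
    # Non-x86_64 (e.g. Apple Silicon or ARM servers): never import cpufeature
    return True
```

On an AVX512-capable x86_64 host the flag resolves to False (plain bfloat16 forward); everywhere else, including all non-x86_64 machines, it resolves to True (chunked forward).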
