Fix use_chunked_forward="auto" on non-x86_64 machines (#267)

Import of cpufeature may crash on non-x86_64 machines, so this PR makes the client import it only if necessary.
pull/275/head
Alexander Borzunov 1 year ago committed by GitHub
parent a2e7f27a5a
commit fd9400b392
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,11 +4,12 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
See commit history for authorship.
"""
import platform
import psutil
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from cpufeature import CPUFeature
from hivemind import get_logger
from torch import nn
from transformers import BloomConfig
@ -29,9 +30,15 @@ class LMHead(nn.Module):
self.use_chunked_forward = config.use_chunked_forward
if self.use_chunked_forward == "auto":
# If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
# Otherwise, it's ~8x slower.
self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
if platform.machine() == "x86_64":
# Import of cpufeature may crash on non-x86_64 machines
from cpufeature import CPUFeature
# If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
# Otherwise, it's ~8x slower.
self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
else:
self.use_chunked_forward = True
self.chunked_forward_step = config.chunked_forward_step
self._bf16_warning_shown = False

Loading…
Cancel
Save