From fd9400b392d57d6ef16253e74aab0c60c82227a5 Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Tue, 21 Feb 2023 06:11:53 +0400
Subject: [PATCH] Fix use_chunked_forward="auto" on non-x86_64 machines (#267)

Import of cpufeature may crash on non-x86_64 machines, so this PR makes
the client import it only if necessary.
---
 src/petals/bloom/modeling_utils.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/petals/bloom/modeling_utils.py b/src/petals/bloom/modeling_utils.py
index cb069b8..eddbb9d 100644
--- a/src/petals/bloom/modeling_utils.py
+++ b/src/petals/bloom/modeling_utils.py
@@ -4,11 +4,12 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
 See commit history for authorship.
 """
 
+import platform
+
 import psutil
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from cpufeature import CPUFeature
 from hivemind import get_logger
 from torch import nn
 from transformers import BloomConfig
@@ -29,9 +30,15 @@ class LMHead(nn.Module):
 
         self.use_chunked_forward = config.use_chunked_forward
         if self.use_chunked_forward == "auto":
-            # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
-            # Otherwise, it's ~8x slower.
-            self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
+            if platform.machine() == "x86_64":
+                # Import of cpufeature may crash on non-x86_64 machines
+                from cpufeature import CPUFeature
+
+                # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
+                # Otherwise, it's ~8x slower.
+                self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
+            else:
+                self.use_chunked_forward = True
         self.chunked_forward_step = config.chunked_forward_step
         self._bf16_warning_shown = False
 
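Note: for readers who want to try the platform-gated import outside of Petals, below
is a minimal standalone sketch of the same pattern the patch applies. The helper name
resolve_use_chunked_forward is hypothetical and not part of the patch; it assumes only
the standard-library platform module and the optional cpufeature package, and mirrors
the AVX512 heuristic from the diff above.

    import platform


    def resolve_use_chunked_forward(setting):
        """Resolve a use_chunked_forward setting of "auto" (hypothetical helper)."""
        if setting != "auto":
            return setting
        if platform.machine() == "x86_64":
            # cpufeature is imported only on x86_64, where it is safe; importing it
            # on other architectures may crash, which is what the patch avoids.
            from cpufeature import CPUFeature

            # With AVX512, plain bfloat16 beats chunked_forward(), so skip chunking.
            return not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
        # On non-x86_64 machines, fall back to the chunked forward pass.
        return True

Calling resolve_use_chunked_forward("auto") on, say, an ARM machine returns True without
ever importing cpufeature, which is exactly the crash path the patch closes.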