Disable chunked_forward() on AVX512 CPUs (#179)

Alexander Borzunov 1 year ago committed by GitHub
parent 6948a0c5ee
commit 55698381d0

@@ -42,6 +42,7 @@ install_requires =
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2
+    cpufeature>=0.2.0
 
 [options.extras_require]
 dev =
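
The cpufeature dependency added here drives the AVX512 auto-detection used below. A minimal standalone sketch of that same check, assuming cpufeature>=0.2.0 is installed:

# Both the hardware flag (AVX512f) and the OS-support flag (OS_AVX512) must be set
# for AVX512 instructions to be usable from user code.
from cpufeature import CPUFeature

has_usable_avx512 = bool(CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
print(f"AVX512 usable: {has_usable_avx512}")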

@@ -4,9 +4,11 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
 See commit history for authorship.
 """
+import psutil
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from cpufeature import CPUFeature
 from hivemind import get_logger
 from torch import nn
 from transformers import BloomConfig
@@ -24,7 +26,14 @@ class LMHead(nn.Module):
     def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding):
         super().__init__()
         self.word_embeddings = word_embeddings
-        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
+        self.use_chunked_forward = config.use_chunked_forward
+        if self.use_chunked_forward == "auto":
+            # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
+            # Otherwise, it's ~8x slower.
+            self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
+        self.chunked_forward_step = config.chunked_forward_step
+        self._bf16_warning_shown = False
 
     @property
     def in_features(self) -> int:
@@ -46,9 +55,9 @@ class LMHead(nn.Module):
         word_embeddings = self.word_embeddings.weight
 
         if (
-            self.chunk_size is not None
-            and word_embeddings.dtype in [torch.float16, torch.bfloat16]
+            word_embeddings.dtype in [torch.float16, torch.bfloat16]
             and word_embeddings.device.type == "cpu"
+            and self.use_chunked_forward
         ):
             lm_logits = self.chunked_forward(hidden_states)
         else:
@@ -59,9 +68,17 @@ class LMHead(nn.Module):
     def chunked_forward(self, hidden_states):
         """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
-        chunk_size: provides trade-off between efficiency and extra memory consumption.
+        chunked_forward_step: provides trade-off between efficiency and extra memory consumption.
         """
-        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
+        assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"
+
+        if not self._bf16_warning_shown:
+            if self.word_embeddings.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
+                logger.warning(
+                    "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
+                    "Consider loading the model with torch_dtype='float32'"
+                )
+            self._bf16_warning_shown = True
 
         word_embeddings = self.word_embeddings.weight
         num_embeddings = self.word_embeddings.num_embeddings
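
As a rough worked example of what the numel() * 4 < 0.9 * total-RAM check above amounts to (sizes below are illustrative, using BLOOM-176B's vocabulary of 250880 and hidden size of 14336):

import psutil

vocab_size, hidden_size = 250880, 14336      # word embedding shape (illustrative)
fp32_bytes = vocab_size * hidden_size * 4    # ~14.4 GB for an fp32 copy of the embeddings
# The warning suggesting torch_dtype='float32' is shown only if this fp32 copy
# would plausibly fit in RAM, i.e. it needs less than 90% of total memory.
would_fit = fp32_bytes < 0.9 * psutil.virtual_memory().total
print(f"{fp32_bytes / 1e9:.1f} GB needed; fits: {would_fit}")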
@@ -69,7 +86,7 @@ class LMHead(nn.Module):
         hidden_states = hidden_states.float()
         output = torch.empty(*hidden_states.shape[:-1], num_embeddings)
-        for i in range(0, num_embeddings, self.chunk_size):
-            chunk = word_embeddings[i : i + self.chunk_size].float()
-            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
+        for i in range(0, num_embeddings, self.chunked_forward_step):
+            chunk = word_embeddings[i : i + self.chunked_forward_step].float()
+            output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)
         return output
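
For context, the idea behind chunked_forward() can be reproduced in isolation: slice the bf16 weight into blocks of chunked_forward_step rows, upcast each block to fp32, and run the matmul block by block. A small self-contained sketch with toy sizes (not the library code):

import torch
import torch.nn.functional as F

def chunked_linear(hidden_states: torch.Tensor, weight_bf16: torch.Tensor, step: int) -> torch.Tensor:
    # hidden_states: [..., hidden]; weight_bf16: [vocab, hidden] in bfloat16 on CPU
    hidden_states = hidden_states.float()
    vocab = weight_bf16.shape[0]
    output = torch.empty(*hidden_states.shape[:-1], vocab)
    for i in range(0, vocab, step):
        chunk = weight_bf16[i : i + step].float()  # upcast one block of rows to fp32
        output[..., i : i + step] = F.linear(hidden_states, chunk)
    return output

x = torch.randn(2, 8, 64)
w = torch.randn(1000, 64).to(torch.bfloat16)
out = chunked_linear(x, w, step=256)
ref = F.linear(x.float(), w.float())
print(torch.allclose(out, ref, atol=1e-5))  # True: same result, smaller fp32 peak memory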

@@ -1,6 +1,6 @@
 import os
 from contextlib import contextmanager
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import hivemind
 import torch
@@ -34,11 +34,15 @@ class DistributedBloomConfig(BloomConfig):
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     daemon_startup_timeout: int = 30
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
-    chunk_size_for_efficient_fp16_on_cpu: Optional[int] = 10000
-    # Chunk size for efficient half-precision on CPU in the LM head. Set to None if your CPU works fast with bfloat16.
+    request_timeout: int = 30  # a number of seconds for waiting result from each node
+
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
     tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
-    request_timeout: int = 30  # a number of seconds for waiting result from each node
+
+    # These settings matter for running the client with dtype bfloat16 on CPU.
+    # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
+    use_chunked_forward: Union[str, bool] = "auto"
+    chunked_forward_step: int = 16384
 
 
 original_register_parameter = nn.Module.register_parameter
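
With the new config fields in place, a client can override the auto-detection explicitly. A hypothetical usage sketch — the import path and model name are assumptions for illustration, and it relies on from_pretrained forwarding extra kwargs to DistributedBloomConfig:

import torch
from petals import DistributedBloomForCausalLM  # import path assumed

# Force chunked_forward() off (e.g. on an AVX512-capable CPU), or pass True / "auto".
model = DistributedBloomForCausalLM.from_pretrained(
    "bigscience/bloom-petals",   # model name assumed for illustration
    torch_dtype=torch.bfloat16,
    use_chunked_forward=False,
    chunked_forward_step=16384,
)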
