add quantization script for cpu

justheuristic 2022-06-12 05:59:11 +03:00
parent ffb56a65ed
commit 05faa0b3c8
4 changed files with 76 additions and 16 deletions

cli/__init__.py (new, empty file)

cli/quantize_for_cpu.py (new file, 49 lines)

@@ -0,0 +1,49 @@
import argparse
import copy
import os

import psutil
import torch.backends.quantized
import transformers
from hivemind.utils.logging import get_logger
from tqdm.auto import trange

logger = get_logger(__file__)

DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
    parser.add_argument("--output_path", required=True, type=str, help="Save quantized layers to this folder")
    parser.add_argument("--model", type=str, default="bigscience/bloom", help="Model name for from_pretrained")
    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
    args = parser.parse_args()

    free_ram_gb = psutil.virtual_memory().available / 2**30
    if free_ram_gb < 400:
        logger.warning(f"ACHTUNG! converting bloom-176b will use up 370-400GB RAM, you have {free_ram_gb:.3f} free")

    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
    if os.path.exists(args.output_path) and (
        not os.path.isdir(args.output_path) or len(os.listdir(args.output_path)) != 0
    ):
        raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")

    model = transformers.BloomForCausalLM.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )

    qconfig = torch.quantization.get_default_qconfig("fbgemm")
    torch.backends.quantized.engine = "fbgemm"  # x86 backend for quantized CPU kernels

    os.makedirs(args.output_path, exist_ok=True)
    for i in trange(len(model.transformer.h)):
        # quantize each transformer block separately: its nn.Linear submodules become dynamic qint8 Linear layers
        layer_fp32 = copy.deepcopy(model.transformer.h[i]).float()
        layer_quantized = torch.quantization.quantize_dynamic(
            layer_fp32, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True
        )
        torch.save(layer_quantized.state_dict(), os.path.join(args.output_path, f"block_{i}_qint8.pth"))
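
For reference, below is a small, self-contained sketch (not part of this commit) of the kind of dynamic quantization the script applies to each transformer block, using the simpler set form of qconfig_spec instead of the explicit fbgemm qconfig; the toy module and layer sizes are invented for illustration.

import torch
import torch.nn as nn

torch.backends.quantized.engine = "fbgemm"  # same CPU backend the script selects

# stand-in for one transformer block: any module whose nn.Linear children should become qint8
toy_block = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
quantized_block = torch.quantization.quantize_dynamic(toy_block, {nn.Linear}, dtype=torch.qint8)

# the nn.Linear children are now torch.nn.quantized.dynamic.modules.Linear; their state_dict
# holds packed int8 weights plus scales, which is what the script saves as block_{i}_qint8.pth
torch.save(quantized_block.state_dict(), "toy_block_qint8.pth")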

@@ -6,11 +6,17 @@ See commit history for authorship.
 import math
 
 import torch
-import torch.nn.quantized.dynamic.modules.linear
 import torch.nn as nn
+import torch.nn.quantized.dynamic.modules.linear
 
-from src.ops import BloomScaledSoftmax, BloomGelu
-from src.ops import attention_mask_func, pre_process_alibi_for_pad, split_tensor_along_last_dim, dropout_add
+from src.ops import (
+    BloomGelu,
+    BloomScaledSoftmax,
+    attention_mask_func,
+    dropout_add,
+    pre_process_alibi_for_pad,
+    split_tensor_along_last_dim,
+)
 
 
 class BloomAttention(nn.Module):
@@ -43,11 +49,13 @@ class BloomAttention(nn.Module):
             self.layer_number,
         )
 
-        if config.compression == 'qint8':
+        if config.compression == "qint8":
             self.query_key_value = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8)
+                self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8
+            )
             self.dense = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
+                self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
+            )
         else:
             self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
             self.dense = nn.Linear(self.hidden_size, self.hidden_size)
@@ -120,9 +128,7 @@ class BloomAttention(nn.Module):
 
         # attention scores and attention mask [b, np, sq, sk]
         max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
-        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(
-            value_layer.dtype
-        )
+        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
         attention_probs = self.attention_dropout(attention_probs)
 
         if head_mask is not None:
@@ -170,11 +176,13 @@ class BloomMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.hidden_size = config.hidden_size
-        if config.compression == 'qint8':
+        if config.compression == "qint8":
             self.dense_h_to_4h = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8)
+                self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8
+            )
             self.dense_4h_to_h = nn.quantized.dynamic.modules.Linear(
-                4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
+                4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
+            )
         else:
             self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
             self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
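
A standalone illustration (not from this commit) of the layer type these qint8 branches construct by hand; hidden_size and the input shape are made up, and the explicit submodule import mirrors the one added at the top of this file.

import torch
import torch.nn as nn
import torch.nn.quantized.dynamic.modules.linear  # ensures nn.quantized.dynamic.modules is loaded, as in the import above

hidden_size = 1024  # invented for the example
query_key_value = nn.quantized.dynamic.modules.Linear(hidden_size, 3 * hidden_size, bias_=True, dtype=torch.qint8)

x = torch.randn(2, 8, hidden_size)  # [batch, seq_len, hidden]
print(query_key_value(x).shape)     # torch.Size([2, 8, 3072]); the weight is stored as packed int8

The intended use, presumably, is to fill such modules from the block_{i}_qint8.pth state dicts written by cli/quantize_for_cpu.py rather than to train them from scratch.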

@@ -10,12 +10,15 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from transformers.file_utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
+from transformers.utils import logging
 
 from src.block import BloomBlock
 from src.ops import build_alibi_tensor
@@ -28,7 +31,7 @@ _TOKENIZER_FOR_DOC = "BloomTokenizer"
 class MemoryEfficientBloomConfig(_VanillaBloomConfig):
-    compression: str = 'none'
+    compression: str = "none"
     slow_but_exact: bool = False
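
A minimal usage sketch (not part of this commit): the import path src.model is an assumption about where this config class lives, and the checkpoint name is only illustrative.

from src.model import MemoryEfficientBloomConfig  # assumed module path

config = MemoryEfficientBloomConfig.from_pretrained("bigscience/bloom")
config.compression = "qint8"  # routes BloomAttention / BloomMLP to the dynamic qint8 Linear branch shown above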