mirror of
https://github.com/bigscience-workshop/petals
synced 2024-11-16 06:12:50 +00:00
add quantization script for cpu
This commit is contained in:
parent
ffb56a65ed
commit
05faa0b3c8
0
cli/__init__.py
Normal file
0
cli/__init__.py
Normal file
49
cli/quantize_for_cpu.py
Normal file
49
cli/quantize_for_cpu.py
Normal file
@ -0,0 +1,49 @@
|
||||
import argparse
|
||||
import copy
|
||||
import os
|
||||
|
||||
import psutil
|
||||
import torch.backends.quantized
|
||||
import transformers
|
||||
from hivemind.utils.logging import get_logger
|
||||
from tqdm.auto import trange
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
|
||||
parser.add_argument("--output_path", required=True, type=str, help="Save quantized layers to this folder")
|
||||
parser.add_argument("--model", type=str, default="bigscience/bloom", help="Model name for from_pretrained")
|
||||
parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
|
||||
parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
|
||||
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
|
||||
args = parser.parse_args()
|
||||
|
||||
free_ram_gb = psutil.virtual_memory().available / 2**30
|
||||
if free_ram_gb < 400:
|
||||
logger.warning(f"ACHTUNG! converting bloom-176b will use up 370-400GB RAM, you have {free_ram_gb:.3f} free")
|
||||
|
||||
assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
|
||||
if os.path.exists(args.output_path) and (
|
||||
len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
|
||||
):
|
||||
raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
|
||||
|
||||
model = transformers.BloomForCausalLM.from_pretrained(
|
||||
args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
|
||||
)
|
||||
|
||||
qconfig = torch.quantization.get_default_qconfig("fbgemm")
|
||||
torch.backends.quantized.engine = "fbgemm"
|
||||
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
|
||||
for i in trange(len(model.transformer.h)):
|
||||
layer_fp32 = copy.deepcopy(model.transformer.h[i]).float()
|
||||
layer_quantized = torch.quantization.quantize_dynamic(
|
||||
layer_fp32, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True
|
||||
)
|
||||
torch.save(layer_quantized.state_dict(), os.path.join(args.output_path, f"block_{i}_qint8.pth"))
|
32
src/block.py
32
src/block.py
@ -6,11 +6,17 @@ See commit history for authorship.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.quantized.dynamic.modules.linear
|
||||
import torch.nn as nn
|
||||
import torch.nn.quantized.dynamic.modules.linear
|
||||
|
||||
from src.ops import BloomScaledSoftmax, BloomGelu
|
||||
from src.ops import attention_mask_func, pre_process_alibi_for_pad, split_tensor_along_last_dim, dropout_add
|
||||
from src.ops import (
|
||||
BloomGelu,
|
||||
BloomScaledSoftmax,
|
||||
attention_mask_func,
|
||||
dropout_add,
|
||||
pre_process_alibi_for_pad,
|
||||
split_tensor_along_last_dim,
|
||||
)
|
||||
|
||||
|
||||
class BloomAttention(nn.Module):
|
||||
@ -43,11 +49,13 @@ class BloomAttention(nn.Module):
|
||||
self.layer_number,
|
||||
)
|
||||
|
||||
if config.compression == 'qint8':
|
||||
if config.compression == "qint8":
|
||||
self.query_key_value = nn.quantized.dynamic.modules.Linear(
|
||||
self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8)
|
||||
self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8
|
||||
)
|
||||
self.dense = nn.quantized.dynamic.modules.Linear(
|
||||
self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
|
||||
self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
|
||||
)
|
||||
else:
|
||||
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
|
||||
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
|
||||
@ -120,9 +128,7 @@ class BloomAttention(nn.Module):
|
||||
|
||||
# attention scores and attention mask [b, np, sq, sk]
|
||||
max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
|
||||
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(
|
||||
value_layer.dtype
|
||||
)
|
||||
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
if head_mask is not None:
|
||||
@ -170,11 +176,13 @@ class BloomMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
if config.compression == 'qint8':
|
||||
if config.compression == "qint8":
|
||||
self.dense_h_to_4h = nn.quantized.dynamic.modules.Linear(
|
||||
self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8)
|
||||
self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8
|
||||
)
|
||||
self.dense_4h_to_h = nn.quantized.dynamic.modules.Linear(
|
||||
4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
|
||||
4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
|
||||
)
|
||||
else:
|
||||
self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
|
||||
self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
|
||||
|
11
src/model.py
11
src/model.py
@ -10,12 +10,15 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, LayerNorm
|
||||
|
||||
from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from transformers.file_utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import logging
|
||||
from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
from src.block import BloomBlock
|
||||
from src.ops import build_alibi_tensor
|
||||
@ -28,7 +31,7 @@ _TOKENIZER_FOR_DOC = "BloomTokenizer"
|
||||
|
||||
|
||||
class MemoryEfficientBloomConfig(_VanillaBloomConfig):
|
||||
compression: str = 'none'
|
||||
compression: str = "none"
|
||||
slow_but_exact: bool = False
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user