remove slow_but_exact, add quantization

This commit is contained in:
justheuristic 2022-06-12 04:56:26 +03:00
parent 51e96ac19b
commit 43399d7898
2 changed files with 36 additions and 44 deletions

View File

@ -5,20 +5,17 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
import math
import torch
from torch import nn
from torch.nn import LayerNorm
import torch.nn.quantized.dynamic.modules.linear
import torch.nn as nn
from src.ops import BloomScaledSoftmax, attention_mask_func, pre_process_alibi_for_pad, split_tensor_along_last_dim, \
dropout_add, BloomGelu
from src.ops import BloomScaledSoftmax, BloomGelu
from src.ops import attention_mask_func, pre_process_alibi_for_pad, split_tensor_along_last_dim, dropout_add
class BloomAttention(nn.Module):
def __init__(self, config, layer_number=None):
super().__init__()
self.pretraining_tp = config.pretraining_tp
self.slow_but_exact = config.slow_but_exact
self.hidden_size = config.hidden_size
self.num_heads = config.n_head
self.head_dim = self.hidden_size // self.num_heads
@ -45,8 +42,15 @@ class BloomAttention(nn.Module):
self.layer_number,
)
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
if config.compression == 'qint8':
self.query_key_value = nn.quantized.dynamic.modules.Linear(
self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8)
self.dense = nn.quantized.dynamic.modules.Linear(
self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
else:
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
self.attention_dropout = nn.Dropout(config.attention_dropout)
def forward(
@ -149,17 +153,7 @@ class BloomAttention(nn.Module):
# Output. [q_length, batch_size, hidden_size]
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = context_layer.shape[-1] / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + nn.functional.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
output_tensor = self.dense(context_layer)
output = output_tensor.transpose(1, 0)
output = dropout_add(output, residual, self.hidden_dropout, self.training)
@ -175,30 +169,21 @@ class BloomMLP(nn.Module):
def __init__(self, config):
super().__init__()
hidden_size = config.hidden_size
self.pretraining_tp = config.pretraining_tp
self.slow_but_exact = config.slow_but_exact
self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
if config.compression == 'qint8':
self.dense_h_to_4h = nn.quantized.dynamic.modules.Linear(
self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8)
self.dense_4h_to_h = nn.quantized.dynamic.modules.Linear(
4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
else:
self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
self.hidden_dropout = config.hidden_dropout
self.gelu_impl = BloomGelu()
def forward(self, hidden_states, residual):
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
if self.pretraining_tp > 1 and self.slow_but_exact:
intermediate_output = torch.zeros_like(residual)
slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
for i in range(self.pretraining_tp):
intermediate_output = intermediate_output + nn.functional.linear(
hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
intermediate_output = self.dense_4h_to_h(hidden_states)
intermediate_output = self.dense_4h_to_h(hidden_states)
output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
return output
@ -207,10 +192,10 @@ class BloomBlock(nn.Module):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.n_head = config.n_head
self.self_attention = BloomAttention(config, layer_number=layer_number)
self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config)

View File

@ -11,7 +11,7 @@ from transformers.file_utils import add_code_sample_docstrings, add_start_docstr
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
from src.layer import BloomBlock
from src.ops import build_alibi_tensor
@ -19,8 +19,14 @@ from src.ops import build_alibi_tensor
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "bigscience/Bloom"
_CONFIG_FOR_DOC = "BloomConfig"
_CONFIG_FOR_DOC = "MemoryEfficientBloomConfig"
_TOKENIZER_FOR_DOC = "BloomTokenizer"
_NOT_IMPLEMENTED = 'NOT_IMPLEMENTED'
class MemoryEfficientBloomConfig(_VanillaBloomConfig):
compression: str = 'none'
slow_but_exact = _NOT_IMPLEMENTED
class BloomPreTrainedModel(PreTrainedModel):
@ -30,7 +36,7 @@ class BloomPreTrainedModel(PreTrainedModel):
models.
"""
config_class = BloomConfig
config_class = MemoryEfficientBloomConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["BloomBlock"]
@ -69,7 +75,7 @@ BLOOM_START_DOCSTRING = r"""
and behavior.
Parameters:
config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
config ([`MemoryEfficientBloomConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@ -138,6 +144,7 @@ BLOOM_INPUTS_DOCSTRING = r"""
class BloomModel(BloomPreTrainedModel):
def __init__(self, config):
super().__init__(config)
assert config.slow_but_exact == _NOT_IMPLEMENTED, "slow_but_exact mode was removed for code simplicity"
self.embed_dim = config.hidden_size
self.n_head = config.n_head