Mirror of https://github.com/bigscience-workshop/petals (synced 2024-11-16 06:12:50 +00:00)
remove slow_but_exact, add quantization
This commit is contained in:
parent 51e96ac19b
commit 43399d7898
src/layer.py (65 lines changed)
@@ -5,20 +5,17 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
 import math
 
 import torch
-from torch import nn
-from torch.nn import LayerNorm
+import torch.nn.quantized.dynamic.modules.linear
+import torch.nn as nn
 
-from src.ops import BloomScaledSoftmax, attention_mask_func, pre_process_alibi_for_pad, split_tensor_along_last_dim, \
-    dropout_add, BloomGelu
+from src.ops import BloomScaledSoftmax, BloomGelu
+from src.ops import attention_mask_func, pre_process_alibi_for_pad, split_tensor_along_last_dim, dropout_add
 
 
 class BloomAttention(nn.Module):
     def __init__(self, config, layer_number=None):
         super().__init__()
 
-        self.pretraining_tp = config.pretraining_tp
-        self.slow_but_exact = config.slow_but_exact
-
         self.hidden_size = config.hidden_size
         self.num_heads = config.n_head
         self.head_dim = self.hidden_size // self.num_heads
@@ -45,8 +42,15 @@ class BloomAttention(nn.Module):
             self.layer_number,
         )
 
-        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
-        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
+        if config.compression == 'qint8':
+            self.query_key_value = nn.quantized.dynamic.modules.Linear(
+                self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8)
+            self.dense = nn.quantized.dynamic.modules.Linear(
+                self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
+        else:
+            self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
+            self.dense = nn.Linear(self.hidden_size, self.hidden_size)
 
         self.attention_dropout = nn.Dropout(config.attention_dropout)
 
     def forward(
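As a point of reference (not part of this commit): PyTorch's stock helper for this kind of swap is torch.quantization.quantize_dynamic, which converts nn.Linear submodules to the same dynamic qint8 Linear while keeping their trained weights. Constructing nn.quantized.dynamic.modules.Linear directly, as the diff does, starts from freshly initialized weights, so the checkpoint still has to be loaded into the quantized modules afterwards. A minimal sketch:

    import torch
    import torch.nn as nn

    hidden_size = 64  # toy size for illustration; not a BLOOM dimension

    # An fp32 projection shaped like query_key_value, wrapped so the helper can swap it.
    model = nn.Sequential(nn.Linear(hidden_size, 3 * hidden_size, bias=True))
    qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

    x = torch.randn(2, 8, hidden_size)
    print(qmodel(x).shape)  # same interface as the fp32 layer: [2, 8, 192]
    print(qmodel[0])        # dynamic quantized Linear holding int8 weights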
@@ -149,17 +153,7 @@ class BloomAttention(nn.Module):
         # Output. [q_length, batch_size, hidden_size]
 
-        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
-        if self.pretraining_tp > 1 and self.slow_but_exact:
-            slices = context_layer.shape[-1] / self.pretraining_tp
-            output_tensor = torch.zeros_like(context_layer)
-            for i in range(self.pretraining_tp):
-                output_tensor = output_tensor + nn.functional.linear(
-                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
-                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
-                )
-        else:
-            output_tensor = self.dense(context_layer)
-
+        output_tensor = self.dense(context_layer)
         output = output_tensor.transpose(1, 0)
 
         output = dropout_add(output, residual, self.hidden_dropout, self.training)
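Why this simplification is numerically safe (a minimal standalone sketch, not from the repo): slow_but_exact only reproduced the pretraining tensor-parallel summation order, and the sliced partial products sum to the same matmul up to float rounding. Bias is omitted here to keep the identity exact:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    dense = nn.Linear(16, 16, bias=False)
    x = torch.randn(2, 4, 16)

    tp = 4                      # stand-in for pretraining_tp
    step = x.shape[-1] // tp
    acc = torch.zeros(2, 4, 16)
    for i in range(tp):
        # partial product over one slice of input features / weight columns
        acc = acc + nn.functional.linear(
            x[:, :, i * step : (i + 1) * step],
            dense.weight[:, i * step : (i + 1) * step],
        )

    print(torch.allclose(acc, dense(x), atol=1e-5))  # True: same result, different summation order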
@@ -175,30 +169,21 @@ class BloomMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         hidden_size = config.hidden_size
 
-        self.pretraining_tp = config.pretraining_tp
-        self.slow_but_exact = config.slow_but_exact
-        self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
-        self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
+        if config.compression == 'qint8':
+            self.dense_h_to_4h = nn.quantized.dynamic.modules.Linear(
+                self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8)
+            self.dense_4h_to_h = nn.quantized.dynamic.modules.Linear(
+                4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
+        else:
+            self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
+            self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
         self.hidden_dropout = config.hidden_dropout
         self.gelu_impl = BloomGelu()
 
     def forward(self, hidden_states, residual):
         hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
 
-        if self.pretraining_tp > 1 and self.slow_but_exact:
-            intermediate_output = torch.zeros_like(residual)
-            slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
-            for i in range(self.pretraining_tp):
-                intermediate_output = intermediate_output + nn.functional.linear(
-                    hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
-                    self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
-                )
-        else:
-            intermediate_output = self.dense_4h_to_h(hidden_states)
-
+        intermediate_output = self.dense_4h_to_h(hidden_states)
         output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
 
         return output
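One wrinkle: the quantized branch reads self.hidden_size, but BloomMLP only defines the local variable hidden_size, so the local name is presumably what is meant. The payoff of the qint8 path is weight storage: dynamic quantization keeps weights in int8 and only dequantizes on the fly. A back-of-the-envelope sketch, assuming BLOOM-176B's published hidden size of 14336:

    # Rough per-layer MLP weight memory, fp32 vs dynamically quantized int8.
    hidden_size = 14336                                # BLOOM-176B hidden size
    mlp_params = hidden_size * 4 * hidden_size * 2     # dense_h_to_4h + dense_4h_to_h weights

    print(f"fp32:  {mlp_params * 4 / 2**30:.1f} GiB")  # ~6.1 GiB per layer
    print(f"qint8: {mlp_params * 1 / 2**30:.1f} GiB")  # ~1.5 GiB per layer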
@@ -207,10 +192,10 @@ class BloomBlock(nn.Module):
         super().__init__()
         hidden_size = config.hidden_size
 
-        self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.n_head = config.n_head
         self.self_attention = BloomAttention(config, layer_number=layer_number)
-        self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
         self.mlp = BloomMLP(config)
src/model.py (15 lines changed)
@@ -11,7 +11,7 @@ from transformers.file_utils import add_code_sample_docstrings, add_start_docstr
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
-from transformers.models.bloom.configuration_bloom import BloomConfig
+from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
 
 from src.layer import BloomBlock
 from src.ops import build_alibi_tensor
@@ -19,8 +19,14 @@ from src.ops import build_alibi_tensor
 logger = logging.get_logger(__name__)
 
 _CHECKPOINT_FOR_DOC = "bigscience/Bloom"
-_CONFIG_FOR_DOC = "BloomConfig"
+_CONFIG_FOR_DOC = "MemoryEfficientBloomConfig"
 _TOKENIZER_FOR_DOC = "BloomTokenizer"
+_NOT_IMPLEMENTED = 'NOT_IMPLEMENTED'
+
+
+class MemoryEfficientBloomConfig(_VanillaBloomConfig):
+    compression: str = 'none'
+    slow_but_exact = _NOT_IMPLEMENTED
 
 
 class BloomPreTrainedModel(PreTrainedModel):
@@ -30,7 +36,7 @@ class BloomPreTrainedModel(PreTrainedModel):
     models.
     """
 
-    config_class = BloomConfig
+    config_class = MemoryEfficientBloomConfig
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
     _no_split_modules = ["BloomBlock"]
@@ -69,7 +75,7 @@ BLOOM_START_DOCSTRING = r"""
     and behavior.
 
     Parameters:
-        config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
+        config ([`MemoryEfficientBloomConfig`]): Model configuration class with all the parameters of the model.
             Initializing with a config file does not load the weights associated with the model, only the
             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 """
@@ -138,6 +144,7 @@ BLOOM_INPUTS_DOCSTRING = r"""
 class BloomModel(BloomPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
+        assert config.slow_but_exact == _NOT_IMPLEMENTED, "slow_but_exact mode was removed for code simplicity"
 
         self.embed_dim = config.hidden_size
         self.n_head = config.n_head
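Taken together, the new config is the single switch for the quantized path, and the assert above makes configs that still set slow_but_exact fail fast. A hypothetical usage sketch (class names as defined in this diff; default BloomConfig sizes, so weights are randomly initialized rather than loaded from a checkpoint):

    from src.model import BloomModel, MemoryEfficientBloomConfig

    config = MemoryEfficientBloomConfig()  # compression defaults to 'none'
    config.compression = 'qint8'           # BloomAttention/BloomMLP build dynamic-quantized Linears
    model = BloomModel(config)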