@@ -5,20 +5,17 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
import math

import torch
from torch import nn
from torch.nn import LayerNorm
+import torch.nn.quantized.dynamic.modules.linear
+import torch.nn as nn

-from src.ops import BloomScaledSoftmax, attention_mask_func, pre_process_alibi_for_pad, split_tensor_along_last_dim, \
-    dropout_add, BloomGelu
+from src.ops import BloomScaledSoftmax, BloomGelu
+from src.ops import attention_mask_func, pre_process_alibi_for_pad, split_tensor_along_last_dim, dropout_add

class BloomAttention(nn.Module):
    def __init__(self, config, layer_number=None):
        super().__init__()

        self.pretraining_tp = config.pretraining_tp
        self.slow_but_exact = config.slow_but_exact

        self.hidden_size = config.hidden_size
        self.num_heads = config.n_head
        self.head_dim = self.hidden_size // self.num_heads
@@ -45,8 +42,15 @@ class BloomAttention(nn.Module):
            self.layer_number,
        )

-        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
-        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
+        if config.compression == 'qint8':
+            self.query_key_value = nn.quantized.dynamic.modules.Linear(
+                self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8)
+            self.dense = nn.quantized.dynamic.modules.Linear(
+                self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
+        else:
+            self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
+            self.dense = nn.Linear(self.hidden_size, self.hidden_size)

        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def forward(
@@ -149,17 +153,7 @@ class BloomAttention(nn.Module):
        # Output. [q_length, batch_size, hidden_size]

-        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
-        if self.pretraining_tp > 1 and self.slow_but_exact:
-            slices = context_layer.shape[-1] / self.pretraining_tp
-            output_tensor = torch.zeros_like(context_layer)
-            for i in range(self.pretraining_tp):
-                output_tensor = output_tensor + nn.functional.linear(
-                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
-                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
-                )
-        else:
-            output_tensor = self.dense(context_layer)
+        output_tensor = self.dense(context_layer)
        output = output_tensor.transpose(1, 0)

        output = dropout_add(output, residual, self.hidden_dropout, self.training)
@@ -175,30 +169,21 @@ class BloomMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size

        self.pretraining_tp = config.pretraining_tp
        self.slow_but_exact = config.slow_but_exact
-        self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
-        self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
+        if config.compression == 'qint8':
+            self.dense_h_to_4h = nn.quantized.dynamic.modules.Linear(
+                hidden_size, 4 * hidden_size, bias_=True, dtype=torch.qint8)
+            self.dense_4h_to_h = nn.quantized.dynamic.modules.Linear(
+                4 * hidden_size, hidden_size, bias_=True, dtype=torch.qint8)
+        else:
+            self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
+            self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
        self.hidden_dropout = config.hidden_dropout
        self.gelu_impl = BloomGelu()

    def forward(self, hidden_states, residual):
        hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))

-        if self.pretraining_tp > 1 and self.slow_but_exact:
-            intermediate_output = torch.zeros_like(residual)
-            slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
-            for i in range(self.pretraining_tp):
-                intermediate_output = intermediate_output + nn.functional.linear(
-                    hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
-                    self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
-                )
-        else:
-            intermediate_output = self.dense_4h_to_h(hidden_states)
+        intermediate_output = self.dense_4h_to_h(hidden_states)
        output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)

        return output
@@ -207,10 +192,10 @@ class BloomBlock(nn.Module):
        super().__init__()
        hidden_size = config.hidden_size

-        self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.n_head = config.n_head
        self.self_attention = BloomAttention(config, layer_number=layer_number)
-        self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = BloomMLP(config)
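Note on the qint8 branches added above: nn.quantized.dynamic.Linear stores its weights as int8 and quantizes activations on the fly at inference time, so it is a drop-in replacement for an fp32 nn.Linear of the same shape. A minimal sketch of the same idea, separate from the patch itself (the size and the Sequential wrapper below are made up for illustration; torch.quantization.quantize_dynamic is the stock PyTorch helper that performs an equivalent conversion on an existing fp32 module):

import torch
import torch.nn as nn

hidden_size = 64  # illustrative size, not BLOOM's real hidden size

# fp32 projection, analogous to BloomAttention.query_key_value
fp32_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)

# Convert to a dynamically quantized qint8 linear: weights become int8,
# activations are quantized per forward pass, no calibration data needed.
int8_model = torch.quantization.quantize_dynamic(
    nn.Sequential(fp32_proj), {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(2, 8, hidden_size)
# Outputs agree with fp32 up to quantization error.
print((fp32_proj(x) - int8_model(x)).abs().max())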