You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
269 lines
9.8 KiB
Python
269 lines
9.8 KiB
Python
"""
|
|
Utility operations used in the the BLOOM model
|
|
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
|
|
"""
|
|
import math
|
|
|
|
import torch
|
|
import torch.autograd
|
|
from torch import nn
|
|
|
|
|
|
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """
    Split a tensor into equal parts along its last dimension.

    Args:
        tensor ([`torch.tensor`], *required*):
            input tensor to split
        num_partitions ([`int`], *required*):
            number of partitions to split the tensor into; must divide the last
            dimension evenly
        contiguous_split_chunks ([`bool`], *optional*, default=`False`):
            If True, make each chunk contiguous in memory.

    Raises:
        ValueError: if the last dimension size is not divisible by `num_partitions`.
    """
    last_dim = tensor.dim() - 1
    chunk_size, remainder = divmod(tensor.size()[last_dim], num_partitions)
    if remainder != 0:
        raise ValueError(f"{tensor.size()[last_dim]} is not divisible by {num_partitions}")
    chunks = torch.split(tensor, chunk_size, dim=last_dim)
    # torch.split returns views that are generally not contiguous in memory.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in chunks)
    return chunks
|
|
|
|
|
|
def attention_mask_func(attention_scores, attention_mask, causal_mask):
    """
    Combine the padding mask and the causal mask, then apply them to the attention scores **in-place**.

    Args:
        attention_scores ([`torch.tensor`], *required*):
            attention scores; per the `.size()` calls below, dim 1 is heads, dim 2 is query
            positions and dim 3 is key positions. Modified in-place via `masked_fill_`.
        attention_mask ([`torch.tensor`], *required*):
            padding mask where 1/True marks tokens to keep (it is inverted below before masking)
        causal_mask ([`torch.tensor`], *required*):
            boolean mask where True marks positions a query is allowed to attend to

    Returns:
        tuple of (masked attention scores, combined boolean mask that was applied)
    """
    # Invert so that True marks positions to mask OUT (the incoming mask uses 1/True for "keep").
    if attention_mask.dtype == torch.bool:
        attention_mask_bool = ~attention_mask
    else:
        attention_mask_bool = (1 - attention_mask).bool()

    query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
    # Boolean addition acts as logical OR. The `key_length - query_length` offset keeps only the
    # trailing rows of the causal mask — presumably for incremental decoding where
    # query_length < key_length (TODO confirm against callers).
    padded_causal_mask = (
        attention_mask_bool[:, None, key_length - query_length : key_length, None]
        + ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
    ).bool()
    # Also mask out padded key positions for every query row.
    padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
    # Make use of floats: -10000.0 acts as -inf for a subsequent softmax.
    return (
        attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
        padded_causal_mask,
    )
|
|
|
|
|
|
def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
    """
    Build the ALiBi bias tensor, shaped `(n_head, 1, max_seq_len)`.

    Link to paper: https://arxiv.org/abs/2108.12409. The tensor here is not causal as the
    original paper mentions; it relies on a translation invariance of softmax for a quick
    implementation: with l being a tensor and a a fixed value, `softmax(l + a) = softmax(l)`.
    Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742

    Args:
        max_seq_len (`int`, *required*):
            max sequence length
        n_head (`int`, *required*):
            number of heads
        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
            dtype of the output tensor

    Returns:
        tensor shaped (n_head, 1, max_seq_len)
    """

    def _get_slopes(num):
        # Geometric per-head slopes: for a power-of-two head count the first slope is
        # 2^(-8/num) and head i gets that value raised to (i + 1).
        def _power_of_two_slopes(count):
            first = 2 ** (-(2 ** -(math.log2(count) - 3)))
            return [first * first**i for i in range(count)]

        if math.log2(num).is_integer():
            return _power_of_two_slopes(num)
        # Non-power-of-two head count: use the slopes for the nearest smaller power of two,
        # then fill the remainder with every other slope of the doubled count.
        nearest = 2 ** math.floor(math.log2(num))
        return _power_of_two_slopes(nearest) + _get_slopes(2 * nearest)[0::2][: num - nearest]

    slopes = torch.Tensor(_get_slopes(n_head)).unsqueeze(1).unsqueeze(1)
    positions = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
    # Broadcast: (n_head, 1, 1) * (n_head, 1, max_seq_len) -> (n_head, 1, max_seq_len).
    alibi = slopes * positions.expand(n_head, -1, -1)
    return alibi.to(dtype)
|
|
|
|
|
|
def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
    """
    Pre-process the alibi tensor for padding: shift each batch item's alibi values onto its
    non-padded token positions and zero out the padded positions.

    Args:
        alibi ([`torch.tensor`], *required*):
            alibi tensor to pre-process; rows are grouped per batch item in chunks of
            `num_heads` (see the slice-assignment in the loop below)
        attention_mask ([`torch.tensor`], *required*):
            attention mask to pre-process; 1.0 marks real tokens, 0.0 marks padding
            (see the `torch.where` comparisons below)
        num_heads ([`int`], *required*):
            number of attention heads; passed to `build_alibi_tensor` and used as the
            per-batch-item row stride into `alibi`
    """

    # Sanity check if we are not inferring less tokens than the total sequence length
    # This usually happens when the inference is done with past_key_values
    # In this case we re-create the alibi tensor with the correct sequence length
    if attention_mask.shape[-1] != alibi.shape[-1]:
        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat(
            attention_mask.shape[0], 1, 1
        )
    # Get the indexes of the padding tokens (mask == 0) and real tokens (mask == 1);
    # index_*0/index_*1 are (batch, position) coordinate pairs.
    index_x0, index_y0 = torch.where(attention_mask == 0.0)
    index_x1, index_y1 = torch.where(attention_mask == 1.0)

    # Clone the embeddings - we can detach because the embeddings are not learned
    # Get a reference tensor (an un-padded alibi for a single batch item)
    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype)

    # Loop over the batch where the padding is and replace the alibi tensor by the reference tensor
    # Only where you do not have padding. Replace padding tokens by zeros
    # This operation can be seen as a shifting operation.
    for i, index in enumerate(torch.unique(index_x0)):
        slice_to_modify = torch.zeros_like(slice_reference_alibi)
        # Positions of the real (non-padded) tokens for this batch item.
        index_shift = index_y1[index_x1 == index]
        shift_value = len(index_shift)
        # Scatter the first `shift_value` alibi columns onto the real-token positions.
        slice_to_modify[:, :, index_shift] = slice_reference_alibi[:, :, :shift_value]
        alibi[index * num_heads : (index + 1) * num_heads] = slice_to_modify
    return alibi
|
|
|
|
|
|
def dropout_add(x, residual, prob, training):
    """
    Apply dropout to `x`, then add the residual connection.

    Args:
        x (`torch.tensor`, *required*):
            input tensor
        residual (`torch.tensor`, *required*):
            residual tensor
        prob (`float`, *required*):
            dropout probability
        training (`bool`, *required*):
            training mode (dropout is a no-op when False)
    """
    dropped = nn.functional.dropout(x, p=prob, training=training)
    return residual + dropped
|
|
|
|
|
|
def bloom_gelu_forward(x):
    """
    Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple
    implementation (inference) to make the model jitable.

    Args:
        x (`torch.tensor`, *required*):
            input hidden states
    """
    # 0.79788456 ~= sqrt(2/pi); tanh approximation of GELU.
    inner = 0.79788456 * x * (1 + 0.044715 * x * x)
    return x * 0.5 * (1.0 + torch.tanh(inner))
|
|
|
|
|
|
def bloom_gelu_back(g, x):
    """
    Gradient of the tanh approximation of GELU. The gradient of the exact GELU is:
    0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)

    Args:
        g (`torch.tensor`, *required*):
            gradient output tensor
        x (`torch.tensor`, *required*):
            input tensor (a 1-tuple of saved tensors; unwrapped below)
    """
    inp = x[0]  # x is a tuple of 1 element, needs to unpack it first
    tanh_val = torch.tanh(0.79788456 * inp * (1 + 0.044715 * inp * inp))
    sech_sq = 1 - tanh_val * tanh_val
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    grad_in = 0.5 * inp * (sech_sq * (0.79788456 + 0.1070322243 * inp * inp)) + 0.5 * (1 + tanh_val)
    return grad_in * g
|
|
|
|
|
|
class GeLUFunction(torch.autograd.Function):
    """Autograd function pairing the hand-written BLOOM GELU forward with its exact gradient."""

    @staticmethod
    def forward(ctx, input):
        # Stash the activation input; backward needs it to evaluate the gradient.
        ctx.save_for_backward(input)
        return bloom_gelu_forward(input)

    @staticmethod
    def backward(ctx, grad_output):
        # ctx.saved_tensors is a tuple; bloom_gelu_back unwraps its first element.
        return bloom_gelu_back(grad_output, ctx.saved_tensors)
|
|
|
|
|
|
class BloomGelu(nn.Module):
    """
    BloomBiasGelu wrapper module. In training mode it routes through the custom autograd
    function to get accurate gradients; in inference mode it uses the simple functional
    form so the model stays torchscriptable. Partly copied from Megatron-DeepSpeed code
    and adapted for our needs.

    See here why autograd functions are not torchscriptable:
    https://github.com/pytorch/pytorch/issues/22329
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # self.training is toggled by nn.Module.train()/eval().
        return GeLUFunction.apply(x) if self.training else bloom_gelu_forward(x)
|
|
|
|
|
|
class BloomScaledSoftmax(nn.Module):
    """
    Fused operation: scaling + mask + softmax.

    Args:
        scaled_masked_softmax_fusion (`bool`, *required*):
            flag to indicate user wants to use softmax fusion
        mask_func (`function`, *required*):
            mask function to be applied; called as `mask_func(scores, mask, causal_mask)`
            and expected to return the masked scores plus the boolean mask it applied
        softmax_in_fp32 (`bool`, *required*):
            if true, softmax is performed at fp32 precision.
        scale (`float`, *required*):
            scaling factor used in input tensor scaling; must be `None` unless
            `softmax_in_fp32` is true.

    Raises:
        ValueError: if `scale` is set while `softmax_in_fp32` is false.
    """

    def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
        super().__init__()
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        # Scaling relies on the extra precision of an fp32 softmax.
        if not (self.scale is None or softmax_in_fp32):
            raise ValueError("softmax should be in fp32 when scaled")

    def forward(self, input, mask, max_positions):
        """
        Apply optional scaling, optional masking and softmax to the attention scores.

        Args:
            input (`torch.tensor`, *required*):
                attention scores; softmax is taken over the last dimension
            mask (`torch.tensor` or `None`, *required*):
                padding mask forwarded to `mask_func`; if `None`, a plain softmax is applied
            max_positions (`int`, *required*):
                side length of the square causal mask to build
        """
        input_dtype = input.dtype
        input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
        softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype

        if self.scale is not None:
            input = input * self.scale

        if mask is not None:
            mask = mask.to(input.device)
            # Build the lower-triangular causal mask directly on the input's device
            # (previously built on CPU and then copied over on every call).
            causal_mask = torch.tril(
                torch.ones((max_positions, max_positions), dtype=torch.bool, device=input.device)
            ).view(1, 1, max_positions, max_positions)
            mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
            # Zero out masked probabilities: the large negative fill used by mask_func makes
            # them tiny after softmax, but not exactly zero.
            probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
        else:
            probs = nn.functional.softmax(input, dim=-1, dtype=softmax_dtype)

        # Cast back so downstream ops keep the caller's 16-bit dtype.
        if input_in_16bit and self.softmax_in_fp32:
            probs = probs.to(dtype=input_dtype)

        return probs
|