causal mask by default

justheuristic 2022-06-20 16:24:29 +03:00
parent 1ab5fb1630
commit 5d8f7be546
3 changed files with 53 additions and 79 deletions

View File

@@ -74,18 +74,19 @@ class BloomAttention(nn.Module):
use_cache=False,
output_attentions=False,
):
if alibi is None: # TODO OPTIMIZE ALIBI CREATION
current_sequence_length = hidden_states.shape[1]
if layer_past is not None:
current_sequence_length += layer_past[0].shape[1]
alibi = build_alibi_tensor(hidden_states.shape[1], n_head=self.num_heads, dtype=hidden_states.dtype)
# hidden_states: [batch_size, seq_length, hidden_size]
# repeat the alibi tensor along the batch dimension
alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device) # TODO eliminate cpu-gpu transfer!
if alibi is None:
current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
alibi = build_alibi_tensor(
current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
)
# hidden_states: [batch_size, seq_length, hidden_size]
# apply preprocessing if the input is padded
if attention_mask is not None and 0 in attention_mask: # TODO REMOVE CUDA SYNC
alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads)
if attention_mask is not None:
alibi = pre_process_alibi_for_pad(alibi, attention_mask)
# otherwise, repeat the alibi tensor along the batch dimension
else:
alibi = alibi.repeat(hidden_states.shape[0], 1, 1)
mixed_x_layer = self.query_key_value(hidden_states)
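A quick toy check (not part of the commit; the cache layout below is a simplified stand-in, only the shape[1] convention read by the line above matters) of the sequence-length bookkeeping introduced here: with a past key/value cache, alibi has to cover past plus current positions.

import torch

hidden_states = torch.randn(2, 1, 64)                        # decoding a single new token
layer_past = (torch.randn(2, 7, 64), torch.randn(2, 7, 64))  # 7 cached positions (toy layout)
current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
assert current_sequence_length == 8  # alibi is built for all 8 key positions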
@@ -115,19 +116,16 @@ class BloomAttention(nn.Module):
# [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
# slice alibi tensor until the query length
sliced_alibi = alibi[: output_size[0] * output_size[1], :, : output_size[3]]
# Raw attention scores. [batch_size * num_heads, q_length, k_length]
beta = 1.0 / self.layer_number
matmul_result = torch.baddbmm(
sliced_alibi,
alibi,
query_layer.transpose(1, 0),
key_layer.transpose(1, 0).transpose(1, 2),
beta=beta,
alpha=(1.0 / self.norm_factor),
) # TODO: if we end up creating alibi inside forward, consider setting out=sliced_alibi for memory efficiency
)
# change view to [batch_size, num_heads, q_length, k_length]
attention_scores = matmul_result.view(*output_size)
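A self-contained sketch with toy sizes and random tensors (an illustration, not the commit's code) of how the alibi bias enters the raw attention scores through torch.baddbmm above: the bias is scaled by beta = 1/layer_number, the query-key product by alpha = 1/norm_factor, and the [b*h, 1, k] bias broadcasts over the query length.

import torch

batch_size, num_heads, q_length, k_length, head_dim = 2, 4, 5, 5, 8
query_layer = torch.randn(q_length, batch_size * num_heads, head_dim)
key_layer = torch.randn(k_length, batch_size * num_heads, head_dim)
alibi = torch.randn(batch_size * num_heads, 1, k_length)  # broadcasts over q_length
layer_number, norm_factor = 1, head_dim ** 0.5            # stand-ins for the module attributes

matmul_result = torch.baddbmm(
    alibi,                                      # bias term, [b*h, 1, k]
    query_layer.transpose(1, 0),                # [b*h, q, d]
    key_layer.transpose(1, 0).transpose(1, 2),  # [b*h, d, k]
    beta=1.0 / layer_number,
    alpha=1.0 / norm_factor,
)
assert matmul_result.shape == (batch_size * num_heads, q_length, k_length)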

View File

@@ -8,6 +8,7 @@ import math
import torch
import torch.autograd
from torch import nn
import torch.nn.functional as F
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
@@ -55,13 +56,14 @@ def attention_mask_func(attention_scores, attention_mask, causal_mask):
)
def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
def build_alibi_tensor(
max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409. The alibi tensor is not causal, as the original paper mentions; it
relies on the translation invariance of softmax for a quick implementation: for a tensor `l` and a fixed value `a`,
`softmax(l + a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
Args:
Returns tensor shaped (n_head, 1, max_seq_len)
max_seq_len: (`int`, *required*):
@@ -70,67 +72,41 @@ def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
number of heads
dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
device of the output alibi tensor
"""
closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
slopes = torch.pow(base, powers)
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if closest_power_of_2 != n_head:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
)
num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
slopes = torch.Tensor(get_slopes(n_head)).unsqueeze(1).unsqueeze(1)
arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
alibi = slopes * arange_tensor.expand(n_head, -1, -1)
alibi = alibi.to(dtype)
return alibi
lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
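A hedged usage sketch for the rewritten build_alibi_tensor above (it assumes only the function defined in this diff): it checks the documented (n_head, 1, max_seq_len) output shape and the fact that each head's bias grows linearly with position.

import torch

n_head, max_seq_len = 16, 10
alibi = build_alibi_tensor(max_seq_len, n_head, dtype=torch.float32)
assert alibi.shape == (n_head, 1, max_seq_len)
# alibi[h, 0, j] == slope_h * j, so the bias at position 3 is three times the bias at position 1
assert torch.allclose(alibi[:, 0, 3], alibi[:, 0, 1] * 3)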
def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
"""
Args:
Pre-process the alibi tensor for padding.
alibi: ([`torch.tensor`], *required*):
alibi tensor to pre-process
attention_mask: ([`torch.tensor`], *required*):
attention mask to pre-process"""
# Sanity check that we are not inferring fewer tokens than the total sequence length
# This usually happens when the inference is done with past_key_values
# In this case we re-create the alibi tensor with the correct sequence length
if attention_mask.shape[-1] != alibi.shape[-1]:
alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat(
attention_mask.shape[0], 1, 1
)
# Get the indexes of the padding tokens
index_x0, index_y0 = torch.where(attention_mask == 0.0)
index_x1, index_y1 = torch.where(attention_mask == 1.0)
# Clone the embeddings - we can detach because the embeddings are not learned
# Get a reference tensor
slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype)
# Loop over the batch entries that contain padding and replace the alibi tensor with the reference tensor,
# but only where there is no padding; padding tokens are replaced by zeros.
# This operation can be seen as a shifting operation.
for i, index in enumerate(torch.unique(index_x0)):
slice_to_modify = torch.zeros_like(slice_reference_alibi)
index_shift = index_y1[index_x1 == index]
shift_value = len(index_shift)
slice_to_modify[:, :, index_shift] = slice_reference_alibi[:, :, :shift_value]
alibi[index * num_heads : (index + 1) * num_heads] = slice_to_modify
return alibi
attention mask to pre-process
"""
assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"
unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
# ^-- [batch, max_len], values correspond to element indices after removing padding
# Shift the alibi tensor and replace all values where attention_mask == 0.0 with 0
alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)
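A standalone toy illustration (one head with slope 1.0 and a hypothetical left-padded mask) of the index-shifting trick used by the new pre_process_alibi_for_pad above: the cumulative sum of the mask minus one gives each token the index it would have after stripping padding, relu clamps padded positions to zero, and multiplying by the mask zeroes them out.

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])                 # batch of 2, first row left-padded
unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)  # [[0, 0, 0, 1, 2], [0, 1, 2, 3, 4]]
alibi = torch.arange(5, dtype=torch.float32).view(1, 1, -1)      # one head with slope 1.0
shifted = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
# shifted[0]: [[0, 0, 0, 1, 2], [0, 1, 2, 3, 4]], i.e. real tokens restart the bias from 0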
def dropout_add(x, residual, prob, training):
"""
@@ -251,17 +227,17 @@ class BloomScaledSoftmax(nn.Module):
if self.scale is not None:
input = input * self.scale
if mask is not None:
mask = mask.to(input.device)
causal_mask = (
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
.view(1, 1, max_positions, max_positions)
.to(input.device)
)
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
else:
probs = nn.functional.softmax(input, dim=-1, dtype=softmax_dtype)
if mask is None:
mask = torch.ones(input.shape[:2], dtype=torch.bool, device=input.device)
mask = mask.to(input.device)
causal_mask = (
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
.view(1, 1, max_positions, max_positions)
.to(input.device)
)
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
if input_in_16bit and self.softmax_in_fp32:
probs = probs.to(dtype=input_dtype)
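A standalone sketch of the "causal mask by default" behaviour this commit introduces, with toy sizes; it applies the tril mask directly instead of going through self.mask_func, so it only illustrates the idea: even when the caller passes no padding mask, future positions still receive zero attention weight.

import torch
import torch.nn.functional as F

max_positions = 4
scores = torch.zeros(1, 1, max_positions, max_positions)  # dummy attention scores
causal_mask = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
    1, 1, max_positions, max_positions
)
probs = F.softmax(scores.masked_fill(~causal_mask, float("-inf")), dim=-1)
# row i spreads its weight uniformly over positions 0..i and gives zero weight to the future
assert torch.all(probs[0, 0, 0] == torch.tensor([1.0, 0.0, 0.0, 0.0]))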

View File

@@ -35,7 +35,7 @@ class TransformerBackend(ModuleBackend):
print("METADATA:", cache_metadata)
assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
print(past_k.shape, past_v.shape)
print('PAST', past_k.shape, past_v.shape)
hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
# todo remove these asserts once we pass all tests