mirror of
https://github.com/bigscience-workshop/petals
synced 2024-11-19 21:25:38 +00:00
causal mask by default
This commit is contained in:
parent
1ab5fb1630
commit
5d8f7be546
@ -74,18 +74,19 @@ class BloomAttention(nn.Module):
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
):
|
||||
if alibi is None: # TODO OPTIMIZE ALIBI CREATION
|
||||
current_sequence_length = hidden_states.shape[1]
|
||||
if layer_past is not None:
|
||||
current_sequence_length += layer_past[0].shape[1]
|
||||
alibi = build_alibi_tensor(hidden_states.shape[1], n_head=self.num_heads, dtype=hidden_states.dtype)
|
||||
# hidden_states: [batch_size, seq_length, hidden_size]
|
||||
# repeat alibi tensor with the batch size
|
||||
alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device) # TODO eliminate cpu-gpu transfer!
|
||||
if alibi is None:
|
||||
current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
|
||||
alibi = build_alibi_tensor(
|
||||
current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
|
||||
# hidden_states: [batch_size, seq_length, hidden_size]
|
||||
# apply preprocessing if the input is padded
|
||||
if attention_mask is not None and 0 in attention_mask: # TODO REMOVE CUDA SYNC
|
||||
alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads)
|
||||
if attention_mask is not None:
|
||||
alibi = pre_process_alibi_for_pad(alibi, attention_mask)
|
||||
# otherwise repeat alibi tensor with the batch size
|
||||
else:
|
||||
alibi = alibi.repeat(hidden_states.shape[0], 1, 1)
|
||||
|
||||
mixed_x_layer = self.query_key_value(hidden_states)
|
||||
|
||||
@ -115,19 +116,16 @@ class BloomAttention(nn.Module):
|
||||
# [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
|
||||
key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
|
||||
|
||||
# slice alibi tensor until the query length
|
||||
sliced_alibi = alibi[: output_size[0] * output_size[1], :, : output_size[3]]
|
||||
|
||||
# Raw attention scores. [batch_size * num_heads, q_length, k_length]
|
||||
beta = 1.0 / self.layer_number
|
||||
|
||||
matmul_result = torch.baddbmm(
|
||||
sliced_alibi,
|
||||
alibi,
|
||||
query_layer.transpose(1, 0),
|
||||
key_layer.transpose(1, 0).transpose(1, 2),
|
||||
beta=beta,
|
||||
alpha=(1.0 / self.norm_factor),
|
||||
) # TODO if end up creating alibi inside forward, consider setting out=sliced_alibi for memory efficiency
|
||||
)
|
||||
|
||||
# change view to [batch_size, num_heads, q_length, k_length]
|
||||
attention_scores = matmul_result.view(*output_size)
|
||||
|
102
src/bloom/ops.py
102
src/bloom/ops.py
@ -8,6 +8,7 @@ import math
|
||||
import torch
|
||||
import torch.autograd
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
|
||||
@ -55,13 +56,14 @@ def attention_mask_func(attention_scores, attention_mask, causal_mask):
|
||||
)
|
||||
|
||||
|
||||
def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
|
||||
def build_alibi_tensor(
|
||||
max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
|
||||
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
|
||||
`softmax(l+a) = softmax(l)`. Based on
|
||||
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
|
||||
|
||||
Args:
|
||||
Returns tensor shaped (n_head, 1, max_seq_len)
|
||||
max_seq_len: (`int`, *required*):
|
||||
@ -70,67 +72,41 @@ def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
|
||||
number of heads
|
||||
dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
|
||||
dtype of the output tensor
|
||||
device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
|
||||
device of the output alibi tensor
|
||||
"""
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
|
||||
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
def get_slopes(n):
|
||||
def get_slopes_power_of_2(n):
|
||||
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
|
||||
ratio = start
|
||||
return [start * ratio**i for i in range(n)]
|
||||
if closest_power_of_2 != n_head:
|
||||
extra_base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
|
||||
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
|
||||
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
|
||||
if math.log2(n).is_integer():
|
||||
return get_slopes_power_of_2(n)
|
||||
else:
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(n))
|
||||
return (
|
||||
get_slopes_power_of_2(closest_power_of_2)
|
||||
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
|
||||
)
|
||||
|
||||
slopes = torch.Tensor(get_slopes(n_head)).unsqueeze(1).unsqueeze(1)
|
||||
arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
|
||||
alibi = slopes * arange_tensor.expand(n_head, -1, -1)
|
||||
|
||||
alibi = alibi.to(dtype)
|
||||
|
||||
return alibi
|
||||
lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
|
||||
return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
|
||||
|
||||
|
||||
def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
|
||||
def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
Pre-process the alibi tensor for padding.
|
||||
alibi: ([`torch.tensor`], *required*):
|
||||
alibi tensor to pre-process
|
||||
attention_mask: ([`torch.tensor`], *required*):
|
||||
attention mask to pre-process"""
|
||||
|
||||
# Sanity check if we are not inferring less tokens than the total sequence length
|
||||
# This usually happens when the inference is done with past_key_values
|
||||
# In this case we re-create the alibi tensor with the correct sequence length
|
||||
if attention_mask.shape[-1] != alibi.shape[-1]:
|
||||
alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat(
|
||||
attention_mask.shape[0], 1, 1
|
||||
)
|
||||
# Get the indexes of the padding tokens
|
||||
index_x0, index_y0 = torch.where(attention_mask == 0.0)
|
||||
index_x1, index_y1 = torch.where(attention_mask == 1.0)
|
||||
|
||||
# Clone the embeddings - we can detach because the embeddings are not learned
|
||||
# Get a refence tensor
|
||||
slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype)
|
||||
|
||||
# Loop over the batch where the padding is and replace the alibi tensor by the reference tensor
|
||||
# Only where you do not have padding. Replace padding tokens by zeros
|
||||
# This operation can be seen as a shifting operation.
|
||||
for i, index in enumerate(torch.unique(index_x0)):
|
||||
slice_to_modify = torch.zeros_like(slice_reference_alibi)
|
||||
index_shift = index_y1[index_x1 == index]
|
||||
shift_value = len(index_shift)
|
||||
slice_to_modify[:, :, index_shift] = slice_reference_alibi[:, :, :shift_value]
|
||||
alibi[index * num_heads : (index + 1) * num_heads] = slice_to_modify
|
||||
return alibi
|
||||
|
||||
attention mask to pre-process
|
||||
"""
|
||||
assert attention_mask.shape.ndim == 2, "mask should be [batch_size, seq_length]"
|
||||
unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
|
||||
# ^-- [batch, max_len], values correspond to element indices after removing padding
|
||||
# We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
|
||||
alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
|
||||
return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)
|
||||
|
||||
def dropout_add(x, residual, prob, training):
|
||||
"""
|
||||
@ -251,17 +227,17 @@ class BloomScaledSoftmax(nn.Module):
|
||||
if self.scale is not None:
|
||||
input = input * self.scale
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.to(input.device)
|
||||
causal_mask = (
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
|
||||
.view(1, 1, max_positions, max_positions)
|
||||
.to(input.device)
|
||||
)
|
||||
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
|
||||
probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
|
||||
else:
|
||||
probs = nn.functional.softmax(input, dim=-1, dtype=softmax_dtype)
|
||||
if mask is None:
|
||||
mask = torch.ones(input.shape[:2], dtype=torch.bool, device=input.device)
|
||||
|
||||
mask = mask.to(input.device)
|
||||
causal_mask = (
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
|
||||
.view(1, 1, max_positions, max_positions)
|
||||
.to(input.device)
|
||||
)
|
||||
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
|
||||
probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
|
||||
|
||||
if input_in_16bit and self.softmax_in_fp32:
|
||||
probs = probs.to(dtype=input_dtype)
|
||||
|
@ -35,7 +35,7 @@ class TransformerBackend(ModuleBackend):
|
||||
print("METADATA:", cache_metadata)
|
||||
assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
|
||||
layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
|
||||
print(past_k.shape, past_v.shape)
|
||||
print('PAST', past_k.shape, past_v.shape)
|
||||
hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
|
||||
|
||||
# todo remove these asserts once we pass all tests
|
||||
|
Loading…
Reference in New Issue
Block a user