From 5d8f7be5466b7a40fb777bde973ca773b378e83a Mon Sep 17 00:00:00 2001
From: justheuristic
Date: Mon, 20 Jun 2022 16:24:29 +0300
Subject: [PATCH] causal mask by default

---
 src/bloom/block.py    |  28 ++++++------
 src/bloom/ops.py      | 102 ++++++++++++++++--------------------------
 src/server/backend.py |   2 +-
 3 files changed, 53 insertions(+), 79 deletions(-)

diff --git a/src/bloom/block.py b/src/bloom/block.py
index a6a5ddd..a87d2c3 100644
--- a/src/bloom/block.py
+++ b/src/bloom/block.py
@@ -74,18 +74,19 @@ class BloomAttention(nn.Module):
         use_cache=False,
         output_attentions=False,
     ):
-        if alibi is None:  # TODO OPTIMIZE ALIBI CREATION
-            current_sequence_length = hidden_states.shape[1]
-            if layer_past is not None:
-                current_sequence_length += layer_past[0].shape[1]
-            alibi = build_alibi_tensor(hidden_states.shape[1], n_head=self.num_heads, dtype=hidden_states.dtype)
-        # hidden_states: [batch_size, seq_length, hidden_size]
-        # repeat alibi tensor with the batch size
-        alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device)  # TODO eliminate cpu-gpu transfer!
+        if alibi is None:
+            current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
+            alibi = build_alibi_tensor(
+                current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
+            )
+        # hidden_states: [batch_size, seq_length, hidden_size]

         # apply preprocessing if the input is padded
-        if attention_mask is not None and 0 in attention_mask:  # TODO REMOVE CUDA SYNC
-            alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads)
+        if attention_mask is not None:
+            alibi = pre_process_alibi_for_pad(alibi, attention_mask)
+        # otherwise repeat alibi tensor with the batch size
+        else:
+            alibi = alibi.repeat(hidden_states.shape[0], 1, 1)

         mixed_x_layer = self.query_key_value(hidden_states)

@@ -115,19 +116,16 @@ class BloomAttention(nn.Module):
         # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
         key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)

-        # slice alibi tensor until the query length
-        sliced_alibi = alibi[: output_size[0] * output_size[1], :, : output_size[3]]
-
         # Raw attention scores. [batch_size * num_heads, q_length, k_length]
         beta = 1.0 / self.layer_number

         matmul_result = torch.baddbmm(
-            sliced_alibi,
+            alibi,
             query_layer.transpose(1, 0),
             key_layer.transpose(1, 0).transpose(1, 2),
             beta=beta,
             alpha=(1.0 / self.norm_factor),
-        )  # TODO if end up creating alibi inside forward, consider setting out=sliced_alibi for memory efficiency
+        )

         # change view to [batch_size, num_heads, q_length, k_length]
         attention_scores = matmul_result.view(*output_size)
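A minimal sketch (not part of the patch) of the baddbmm pattern the block.py hunks above rely on: the precomputed ALiBi tensor is passed directly as the additive bias, scaled by beta, while alpha scales the raw QK^T scores. All tensors here are random stand-ins whose shapes follow the comments in the hunk, not the repo's real activations:

import torch

batch, num_heads, head_dim, q_len, k_len = 2, 4, 8, 5, 5
layer_number, norm_factor = 1, head_dim ** 0.5

query = torch.randn(batch * num_heads, q_len, head_dim)   # [batch * num_heads, q_length, head_dim]
key_t = torch.randn(batch * num_heads, head_dim, k_len)   # [batch * num_heads, head_dim, k_length]
alibi = torch.zeros(batch * num_heads, 1, k_len)          # stand-in bias, broadcast over q_length

# beta scales the alibi bias, alpha scales the raw QK^T scores, all in one fused kernel call
scores = torch.baddbmm(alibi, query, key_t, beta=1.0 / layer_number, alpha=1.0 / norm_factor)
print(scores.shape)  # torch.Size([8, 5, 5])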
diff --git a/src/bloom/ops.py b/src/bloom/ops.py
index 882e960..db0f0ed 100644
--- a/src/bloom/ops.py
+++ b/src/bloom/ops.py
@@ -8,6 +8,7 @@ import math
 import torch
 import torch.autograd
 from torch import nn
+import torch.nn.functional as F


 def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
@@ -55,13 +56,14 @@ def attention_mask_func(attention_scores, attention_mask, causal_mask):
     )


-def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
+def build_alibi_tensor(
+    max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
+) -> torch.Tensor:
     """
     Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
     relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
     `softmax(l+a) = softmax(l)`. Based on
     https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
-
     Args:
     Returns tensor shaped (n_head, 1, max_seq_len)
         max_seq_len: (`int`, *required*):
@@ -70,67 +72,41 @@ def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
             number of heads
         dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
             dtype of the output tensor
+        device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
+            device of the output alibi tensor
     """
+    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
+    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
+    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
+    slopes = torch.pow(base, powers)

-    def get_slopes(n):
-        def get_slopes_power_of_2(n):
-            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
-            ratio = start
-            return [start * ratio**i for i in range(n)]
+    if closest_power_of_2 != n_head:
+        extra_base = torch.tensor(
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
+        )
+        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

-        if math.log2(n).is_integer():
-            return get_slopes_power_of_2(n)
-        else:
-            closest_power_of_2 = 2 ** math.floor(math.log2(n))
-            return (
-                get_slopes_power_of_2(closest_power_of_2)
-                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
-            )
-
-    slopes = torch.Tensor(get_slopes(n_head)).unsqueeze(1).unsqueeze(1)
-    arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
-    alibi = slopes * arange_tensor.expand(n_head, -1, -1)
-
-    alibi = alibi.to(dtype)
-
-    return alibi
+    lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
+    return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)


-def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
+def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
     """
     Args:
     Pre-process the alibi tensor for padding.
         alibi: ([`torch.tensor`], *required*):
             alibi tensor to pre-process
         attention_mask: ([`torch.tensor`], *required*):
-            attention mask to pre-process"""
-
-    # Sanity check if we are not inferring less tokens than the total sequence length
-    # This usually happens when the inference is done with past_key_values
-    # In this case we re-create the alibi tensor with the correct sequence length
-    if attention_mask.shape[-1] != alibi.shape[-1]:
-        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat(
-            attention_mask.shape[0], 1, 1
-        )
-    # Get the indexes of the padding tokens
-    index_x0, index_y0 = torch.where(attention_mask == 0.0)
-    index_x1, index_y1 = torch.where(attention_mask == 1.0)
-
-    # Clone the embeddings - we can detach because the embeddings are not learned
-    # Get a refence tensor
-    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype)
-
-    # Loop over the batch where the padding is and replace the alibi tensor by the reference tensor
-    # Only where you do not have padding. Replace padding tokens by zeros
-    # This operation can be seen as a shifting operation.
-    for i, index in enumerate(torch.unique(index_x0)):
-        slice_to_modify = torch.zeros_like(slice_reference_alibi)
-        index_shift = index_y1[index_x1 == index]
-        shift_value = len(index_shift)
-        slice_to_modify[:, :, index_shift] = slice_reference_alibi[:, :, :shift_value]
-        alibi[index * num_heads : (index + 1) * num_heads] = slice_to_modify
-    return alibi
-
+            attention mask to pre-process
+    """
+    assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"
+    unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1).long()  # take_along_dim expects int64 indices
+    # ^-- [batch, max_len], values correspond to element indices after removing padding
+    # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
+    alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
+    return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)

 def dropout_add(x, residual, prob, training):
     """
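A standalone rerun (not part of the patch) of the two core lines of the new pre_process_alibi_for_pad on a toy left-padded batch, showing what the cumsum + take_along_dim shift does. The alibi values below are a made-up stand-in (just the position index per head); nothing is imported from the repo:

import torch

n_head, batch, seq = 2, 2, 4
# toy stand-in for build_alibi_tensor's output, shaped [n_head, 1, seq]
alibi = torch.arange(seq, dtype=torch.float32).view(1, 1, seq).repeat(n_head, 1, 1)
# second sample has two padding tokens on the left
attention_mask = torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]], dtype=torch.float32)

unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1).long()
shifted = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
print(shifted[0])
# tensor([[0., 1., 2., 3.],
#         [0., 0., 0., 1.]])  <- padded positions zeroed, real tokens restart from position 0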
@@ -251,17 +227,17 @@ class BloomScaledSoftmax(nn.Module):
         if self.scale is not None:
             input = input * self.scale

-        if mask is not None:
-            mask = mask.to(input.device)
-            causal_mask = (
-                torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
-                .view(1, 1, max_positions, max_positions)
-                .to(input.device)
-            )
-            mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
-            probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
-        else:
-            probs = nn.functional.softmax(input, dim=-1, dtype=softmax_dtype)
+        if mask is None:
+            mask = torch.ones(input.shape[:2], dtype=torch.bool, device=input.device)
+
+        mask = mask.to(input.device)
+        causal_mask = (
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
+            .view(1, 1, max_positions, max_positions)
+            .to(input.device)
+        )
+        mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
+        probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)

         if input_in_16bit and self.softmax_in_fp32:
             probs = probs.to(dtype=input_dtype)
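A small sketch (not part of the patch) of what applying the causal mask by default means for the softmax output. BloomScaledSoftmax and attention_mask_func are stripped away here; masking with the dtype minimum is a simplification of the repo's mask_func, not its exact code:

import torch
import torch.nn.functional as F

max_positions = 4
scores = torch.randn(1, 1, max_positions, max_positions)  # [batch, num_heads, q_length, k_length]

causal_mask = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
    1, 1, max_positions, max_positions
)
masked_scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)
probs = F.softmax(masked_scores, dim=-1) * causal_mask  # future positions get exactly zero weight
print(probs[0, 0])  # upper triangle is all zeros, each row still sums to 1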
diff --git a/src/server/backend.py b/src/server/backend.py
index 55b6b82..bdcbf1c 100644
--- a/src/server/backend.py
+++ b/src/server/backend.py
@@ -35,7 +35,7 @@ class TransformerBackend(ModuleBackend):
             print("METADATA:", cache_metadata)
             assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
             layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
-            print(past_k.shape, past_v.shape)
+            print('PAST', past_k.shape, past_v.shape)
             hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)

             # todo remove these asserts once we pass all tests
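A quick standalone check (not part of the patch) that the vectorized slope computation introduced in build_alibi_tensor reproduces the recursive get_slopes() helper it removes, including non-power-of-two head counts. Both functions below are local re-implementations of the respective patch versions (the device argument is omitted), so the script runs on its own:

import math

import torch


def reference_slopes(n):
    # the recursive slope computation removed by this patch, renamed locally
    def power_of_2(k):
        start = 2 ** (-(2 ** -(math.log2(k) - 3)))
        return [start * start**i for i in range(k)]

    if math.log2(n).is_integer():
        return power_of_2(n)
    closest = 2 ** math.floor(math.log2(n))
    return power_of_2(closest) + reference_slopes(2 * closest)[0::2][: n - closest]


def vectorized_slopes(n_head):
    # the vectorized slope computation added by this patch
    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32)
    slopes = torch.pow(base, torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32))
    if closest_power_of_2 != n_head:
        extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


for n_head in (8, 12, 16, 112):
    reference = torch.tensor(reference_slopes(n_head), dtype=torch.float32)
    assert torch.allclose(vectorized_slopes(n_head), reference, atol=1e-6), n_head
print("vectorized slopes match the reference for all tested head counts")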