@@ -16,7 +16,9 @@ from transformers.models.falcon.modeling_falcon import (
     FalconDecoderLayer,
     FalconLinear,
     FalconMLP,
+    FalconModel,
     LayerNorm,
+    build_alibi_tensor,
     dropout_add,
     rotate_half,
 )
@@ -180,7 +182,6 @@ class OptimizedFalconAttention(FalconAttention):
         use_cache: bool = False,
         output_attentions: bool = False,
     ):
-        assert alibi is None
         assert not output_attentions
 
         if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
@@ -212,6 +213,7 @@ class OptimizedFalconAttention(FalconAttention):
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)
 
+        _, kv_length, _ = key_layer.shape
         if use_cache:
             present = (key_layer, value_layer)
         else:
@@ -221,17 +223,59 @@ class OptimizedFalconAttention(FalconAttention):
         key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
         value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
 
-        attn_output = F.scaled_dot_product_attention(
-            query_layer_, key_layer_, value_layer_, attn_mask=None, dropout_p=0.0, is_causal=True
-        )
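+        # convert the boolean attention mask into an additive float bias: 0.0 where attention is allowed, -1e9 where it is masked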
+        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
+
+        if alibi is None:
+            attn_output = F.scaled_dot_product_attention(
+                query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False
+            )
 
-        attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
-        attn_output = attn_output.permute(0, 2, 1, 3)
-        attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
+            attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
+            attn_output = attn_output.permute(0, 2, 1, 3)
+            attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
 
-        output_tensor = self.dense(attn_output)
+            output_tensor = self.dense(attn_output)
 
-        return output_tensor, present
+            return output_tensor, present
+        else:
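+            # ALiBi path: compute attention explicitly so the per-head ALiBi biases can be added to the logits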
+            matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
+
+            # change view to [batch_size, num_heads, q_length, kv_length]
+            attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
+
+            # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
+            input_dtype = attention_scores.dtype
+            # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+            if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
+                attention_scores = attention_scores.to(torch.float32)
+            # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
+            # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
+            # equivalent and more performant, but there might be a numerical difference. If you're reading this
+            # and you'd like to experiment and maybe file a PR, feel free!
+            attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
+            attention_logits *= self.inv_norm_factor
+            attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
+            # [batch_size, num_heads, q_length, kv_length]
+            attention_probs = self.attention_dropout(attention_probs)
+
+            if head_mask is not None:
+                attention_probs = attention_probs * head_mask
+
+            # change view [batch_size, num_heads, q_length, kv_length]
+            attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
+
+            # matmul: [batch_size * num_heads, q_length, head_dim]
+            context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
+
+            # change view [batch_size, q_length, num_heads * head_dim]
+            context_layer = self._merge_heads(context_layer)
+
+            output_tensor = self.dense(context_layer)
+
+            if output_attentions:
+                return output_tensor, present, attention_probs
+            else:
+                return output_tensor, present
 
 
 class OptimizedFalconDecoderLayer(FalconDecoderLayer):
@@ -352,20 +396,29 @@ class WrappedFalconBlock(OptimizedFalconDecoderLayer):
         *args,
         attention_mask: Optional[torch.Tensor] = None,
         alibi: Optional[torch.Tensor] = None,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        layer_past: Optional[KVCache] = None,
         use_cache: bool = False,
         **kwargs,
     ):
+        assert attention_mask is None
+
        batch_size, seq_length = hidden_states.shape[:2]
 
         if layer_past is not None:
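+            # convert the cached keys/values from the Bloom-style layout into the layout expected by Falcon attention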
             layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
+        past_length = 0 if layer_past is None else layer_past[0].shape[1]
+        seq_length_with_past = seq_length + past_length
+
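+        # rebuild the attention mask over the full (past + current) sequence and derive ALiBi from it when the model config enables ALiBi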
+        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+        if alibi is None and self.config.alibi:
+            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
+        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
 
         outputs = super().forward(
             hidden_states,
             *args,
-            attention_mask=None,
-            alibi=None,
+            attention_mask=attention_mask,
+            alibi=alibi,
             layer_past=layer_past,
             use_cache=use_cache,
             **kwargs,