Improve test and compatibility

pull/500/head
Max Ryabinin 9 months ago
parent ae30427276
commit 841a0d5262

@ -16,7 +16,9 @@ from transformers.models.falcon.modeling_falcon import (
FalconDecoderLayer,
FalconLinear,
FalconMLP,
FalconModel,
LayerNorm,
build_alibi_tensor,
dropout_add,
rotate_half,
)
@ -180,7 +182,6 @@ class OptimizedFalconAttention(FalconAttention):
use_cache: bool = False,
output_attentions: bool = False,
):
assert alibi is None
assert not output_attentions
if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
@ -212,6 +213,7 @@ class OptimizedFalconAttention(FalconAttention):
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
_, kv_length, _ = key_layer.shape
if use_cache:
present = (key_layer, value_layer)
else:
@ -221,17 +223,59 @@ class OptimizedFalconAttention(FalconAttention):
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
attn_output = F.scaled_dot_product_attention(
query_layer_, key_layer_, value_layer_, attn_mask=None, dropout_p=0.0, is_causal=True
)
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
if alibi is None:
attn_output = F.scaled_dot_product_attention(
query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False
)
attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
attn_output = attn_output.permute(0, 2, 1, 3)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
attn_output = attn_output.permute(0, 2, 1, 3)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
output_tensor = self.dense(attn_output)
output_tensor = self.dense(attn_output)
return output_tensor, present
return output_tensor, present
else:
matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
attention_scores = attention_scores.to(torch.float32)
# Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
# adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
# equivalent and more performant, but there might be a numerical difference. If you're reading this
# and you'd like to experiment and maybe file a PR, feel free!
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
attention_logits *= self.inv_norm_factor
attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# change view [batch_size, num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
# change view [batch_size, q_length, num_heads * head_dim]
context_layer = self._merge_heads(context_layer)
output_tensor = self.dense(context_layer)
if output_attentions:
return output_tensor, present, attention_probs
else:
return output_tensor, present
class OptimizedFalconDecoderLayer(FalconDecoderLayer):
@ -352,20 +396,29 @@ class WrappedFalconBlock(OptimizedFalconDecoderLayer):
*args,
attention_mask: Optional[torch.Tensor] = None,
alibi: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
layer_past: Optional[KVCache] = None,
use_cache: bool = False,
**kwargs,
):
assert attention_mask is None
batch_size, seq_length = hidden_states.shape[:2]
if layer_past is not None:
layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
past_length = 0 if layer_past is None else layer_past[0].shape[1]
seq_length_with_past = seq_length + past_length
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
if alibi is None and self.config.alibi:
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
outputs = super().forward(
hidden_states,
*args,
attention_mask=None,
alibi=None,
attention_mask=attention_mask,
alibi=alibi,
layer_past=layer_past,
use_cache=use_cache,
**kwargs,

@ -113,9 +113,12 @@ def test_falcon():
)
unopt_block.load_state_dict(block.state_dict())
cache = unopt_cache = None
for _ in range(3):
for l in range(3):
dummy_input = torch.randn(1, 1, config.hidden_size, device=device, dtype=dtype)
block_output = block(dummy_input)
unopt_block_output = unopt_block(dummy_input)
assert torch.allclose(block_output[0], unopt_block_output[0], atol=1e-6, rtol=0)
block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)
unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), l
assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), l
assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), l

Loading…
Cancel
Save