@@ -110,8 +110,6 @@ def split_heads(
 class OptimizedFalconAttention(FalconAttention):
     def __init__(self, config: FalconConfig):
         nn.Module.__init__(self)
 
-        assert config.new_decoder_architecture
-        assert config.attention_dropout == 0.0
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
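
The two asserts removed above pinned OptimizedFalconAttention to Falcon-40B-style checkpoints (grouped-query attention, no attention dropout). Dropping them is what lets the rest of this diff generalize the class to all three Falcon attention layouts, selected by these FalconConfig flags (the values here are illustrative, not taken from this diff):

from transformers import FalconConfig

# Grouped-query attention, i.e. the "new decoder architecture" (Falcon-40B style).
gqa = FalconConfig(new_decoder_architecture=True, num_kv_heads=8)

# Multi-query attention (Falcon-7B style): one shared key/value head.
mqa = FalconConfig(new_decoder_architecture=False, multi_query=True)

# Classic multi-head attention.
mha = FalconConfig(new_decoder_architecture=False, multi_query=False)
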
@@ -130,21 +128,26 @@ class OptimizedFalconAttention(FalconAttention):
         # Layer-wise attention scaling
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
         self.beta = self.inv_norm_factor
 
-        qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
+        if config.new_decoder_architecture:
+            qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
+        elif config.multi_query:
+            qkv_out_dim = self.hidden_size + 2 * self.head_dim
+        else:
+            qkv_out_dim = 3 * self.hidden_size
         self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
         self.new_decoder_architecture = config.new_decoder_architecture
         self.multi_query = config.multi_query
         self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
-        self.num_kv_heads = config.num_kv_heads
         self.attention_dropout = nn.Dropout(config.attention_dropout)
+        self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
 
-        self._split_heads = partial(
-            split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
-        )
-        self.qkv_graph = None
-        self.input_surface = None
-        self.static_outputs = None
+        if self.new_decoder_architecture:
+            self._split_heads = partial(
+                split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
+            )
+            self.qkv_graph = None
+            self.input_surface = None
+            self.static_outputs = None
 
     def _optimized_apply_qkv(self, hidden_states):
         if self.qkv_graph is None:
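
The new branching reproduces the fused-QKV widths of the upstream transformers Falcon implementation. A quick sanity check of the three branches; the helper is illustrative, and the model dimensions (Falcon-40B-like, Falcon-7B-like) are quoted from memory:

def qkv_width(hidden_size, num_heads, num_kv_heads, new_decoder_architecture, multi_query):
    # Mirrors the qkv_out_dim branches above.
    head_dim = hidden_size // num_heads
    if new_decoder_architecture:      # grouped-query attention
        return (num_kv_heads * 2 + num_heads) * head_dim
    elif multi_query:                 # multi-query attention
        return hidden_size + 2 * head_dim
    else:                             # classic multi-head attention
        return 3 * hidden_size

assert qkv_width(8192, 128, 8, True, False) == 9216    # Falcon-40B-like GQA
assert qkv_width(4544, 71, 1, False, True) == 4672     # Falcon-7B-like MQA
assert qkv_width(4096, 64, 64, False, False) == 12288  # plain MHA
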
@@ -180,7 +183,7 @@ class OptimizedFalconAttention(FalconAttention):
         assert alibi is None
         assert not output_attentions
 
-        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
+        if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
             query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states)
         else:
             fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
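
The body of _optimized_apply_qkv is only partially visible in this diff (it begins at the `if self.qkv_graph is None:` check above). For context, here is a minimal sketch of the CUDA-graph capture-and-replay pattern it presumably implements, using the input_surface / qkv_graph / static_outputs attributes declared in __init__; the side-stream warmup and the iteration count are assumptions, not taken from the source:

    def _optimized_apply_qkv(self, hidden_states):
        if self.qkv_graph is None:
            # Warm up on a side stream so kernels are initialized before
            # capture (standard CUDA-graph boilerplate).
            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                for _ in range(3):
                    self._split_heads(self.query_key_value(hidden_states))
            torch.cuda.current_stream().wait_stream(s)

            # Record the projection + head split into a replayable graph.
            self.input_surface = hidden_states.clone()
            self.qkv_graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.qkv_graph):
                fused_qkv = self.query_key_value(self.input_surface)
                self.static_outputs = self._split_heads(fused_qkv)

        # Replay reuses the captured buffers: copy the new input in place.
        self.input_surface.copy_(hidden_states)
        self.qkv_graph.replay()
        return self.static_outputs
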
@@ -236,18 +239,30 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
         nn.Module.__init__(self)
         hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
+        self.self_attention = OptimizedFalconAttention(config)
         self.mlp = FalconMLP(config)
         self.hidden_dropout = config.hidden_dropout
         self.config = config
 
-        assert not self.config.alibi
-        assert config.new_decoder_architecture
-        self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.self_attention = OptimizedFalconAttention(config)
+        if self.config.alibi or not config.new_decoder_architecture:
+            if config.new_decoder_architecture:
+                # The layer norm before self-attention
+                self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+                # The layer norm before the MLP
+                self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+            else:
+                self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+                if not config.parallel_attn:
+                    self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        else:
+            self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
-        self.ln_graph = None
-        self.static_input = None
+            self.ln_graph = None
+            self.static_input = None
 
     def _optimized_apply_ln(self, hidden_states):
         if self.ln_graph is None:
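
_optimized_apply_ln, whose body the diff truncates after the `if self.ln_graph is None:` check, presumably follows the same capture-and-replay pattern, with both layer norms recorded in a single graph. A sketch under that assumption; static_ln_outputs is a hypothetical holder for the captured output pair, and warmup is omitted for brevity:

    def _optimized_apply_ln(self, hidden_states):
        if self.ln_graph is None:
            # Capture both norms in one graph; they read the same input.
            self.static_input = hidden_states.clone()
            self.ln_graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.ln_graph):
                self.static_ln_outputs = (
                    self.ln_attn(self.static_input),
                    self.ln_mlp(self.static_input),
                )
        self.static_input.copy_(hidden_states)
        self.ln_graph.replay()
        return self.static_ln_outputs
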
@@ -283,11 +298,14 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
     ):
         residual = hidden_states
 
-        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
-            attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)
-        else:
-            attention_layernorm_out = self.ln_attn(hidden_states)
-            mlp_layernorm_out = self.ln_mlp(hidden_states)
+        if self.config.new_decoder_architecture:
+            if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
+                attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)
+            else:
+                attention_layernorm_out = self.ln_attn(hidden_states)
+                mlp_layernorm_out = self.ln_mlp(hidden_states)
+        else:
+            attention_layernorm_out = self.input_layernorm(hidden_states)
 
         attn_outputs = self.self_attention(
             attention_layernorm_out,
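
Because a replayed graph writes into fixed output buffers, the fast path is easy to spot-check against the eager layer norms. A hypothetical harness, assuming layer is an OptimizedFalconDecoderLayer on GPU and h is a [1, 1, hidden_size] CUDA tensor (the single-token, inference-mode case the gate above selects):

with torch.inference_mode():
    ref = (layer.ln_attn(h), layer.ln_mlp(h))
    opt = layer._optimized_apply_ln(h)
    assert all(torch.allclose(r, o) for r, o in zip(ref, opt))
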
@@ -300,10 +318,22 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
         )
 
         attention_output = attn_outputs[0]
 
+        if not self.config.new_decoder_architecture:
+            if self.config.parallel_attn:
+                mlp_layernorm_out = attention_layernorm_out
+            else:
+                residual = dropout_add(
+                    attention_output, residual, self.config.attention_dropout, training=self.training
+                )
+                mlp_layernorm_out = self.post_attention_layernorm(residual)
+
         outputs = attn_outputs[1:]
 
         mlp_output = self.mlp(mlp_layernorm_out)
-        mlp_output += attention_output
+
+        if self.config.new_decoder_architecture or self.config.parallel_attn:
+            mlp_output += attention_output
 
         output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
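
For reference, dropout_add is the small residual helper from transformers' modeling_falcon.py; it is essentially the following (reproduced from memory, so treat it as a sketch):

import torch
import torch.nn.functional as F

def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    # Dropout on the branch output, then add it back onto the residual stream.
    out = F.dropout(x, p=prob, training=training)
    return residual + out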