Make the block compatible with other architectures

pull/500/head
Max Ryabinin 9 months ago
parent 2c1452de5c
commit ae30427276

@@ -110,8 +110,6 @@ def split_heads(
 class OptimizedFalconAttention(FalconAttention):
     def __init__(self, config: FalconConfig):
         nn.Module.__init__(self)
-        assert config.new_decoder_architecture
-        assert config.attention_dropout == 0.0
 
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
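Dropping these two asserts is what opens the class to the other Falcon variants: the block must now handle all three attention layouts rather than only the new (GQA) decoder architecture. A minimal sketch of the flag combinations involved, assuming `transformers` is installed; the values mirror the public checkpoints and are illustrative only:

```python
from transformers import FalconConfig

# The three Falcon attention layouts the block now has to accept.
variants = {
    "falcon-40b-style (GQA)": FalconConfig(new_decoder_architecture=True),
    "falcon-7b-style (MQA)": FalconConfig(new_decoder_architecture=False, multi_query=True),
    "falcon-rw-1b-style (MHA)": FalconConfig(
        new_decoder_architecture=False, multi_query=False, alibi=True, parallel_attn=False
    ),
}
for name, cfg in variants.items():
    print(name, cfg.new_decoder_architecture, cfg.multi_query, cfg.alibi)
```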
@@ -130,21 +128,26 @@ class OptimizedFalconAttention(FalconAttention):
         # Layer-wise attention scaling
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
         self.beta = self.inv_norm_factor
-        qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
+        if config.new_decoder_architecture:
+            qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
+        elif config.multi_query:
+            qkv_out_dim = self.hidden_size + 2 * self.head_dim
+        else:
+            qkv_out_dim = 3 * self.hidden_size
         self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
         self.new_decoder_architecture = config.new_decoder_architecture
         self.multi_query = config.multi_query
         self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
-        self.num_kv_heads = config.num_kv_heads
         self.attention_dropout = nn.Dropout(config.attention_dropout)
+        self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
 
-        self._split_heads = partial(
-            split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
-        )
-        self.qkv_graph = None
-        self.input_surface = None
-        self.static_outputs = None
+        if self.new_decoder_architecture:
+            self._split_heads = partial(
+                split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
+            )
+            self.qkv_graph = None
+            self.input_surface = None
+            self.static_outputs = None
 
     def _optimized_apply_qkv(self, hidden_states):
         if self.qkv_graph is None:
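The new branching reproduces how upstream `FalconAttention` sizes its fused QKV projection: GQA packs `num_kv_heads` key/value groups next to all query heads, MQA shares a single key head and a single value head, and the classic layout is a full 3x projection. Note also that `self.num_kv_heads` collapses to 1 on the MQA path, which lets the shared `split_heads` helper stay generic. A quick sanity check of the arithmetic (sizes below are illustrative, roughly falcon-7b-shaped):

```python
# Illustrative sizes, roughly falcon-7b-shaped
hidden_size, num_heads = 4544, 71
head_dim = hidden_size // num_heads  # 64

num_kv_heads = 8  # only used on the GQA path

# new_decoder_architecture: num_kv_heads K/V groups + all query heads
gqa_dim = (num_kv_heads * 2 + num_heads) * head_dim

# multi_query: full-width Q plus one shared K head and one shared V head
mqa_dim = hidden_size + 2 * head_dim

# classic multi-head attention: full-width Q, K, and V
mha_dim = 3 * hidden_size

print(gqa_dim, mqa_dim, mha_dim)  # 5568 4672 13632
```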
@@ -180,7 +183,7 @@ class OptimizedFalconAttention(FalconAttention):
         assert alibi is None
         assert not output_attentions
 
-        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
+        if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
             query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states)
         else:
             fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
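The extra `self.new_decoder_architecture` condition keeps the CUDA-graph fast path exclusive to the GQA layout, whose graph buffers are the only ones allocated in `__init__` above; captured graphs also require the fixed single-token shapes of inference-mode decoding. A minimal sketch of the capture-and-replay pattern behind `_optimized_apply_qkv` (illustrative, not the PR's exact code; the real method also splits the fused output into Q, K, and V):

```python
import torch

# Requires a CUDA device; shapes here are illustrative.
proj = torch.nn.Linear(64, 192).cuda()
static_input = torch.zeros(1, 1, 64, device="cuda")

# Warm-up pass outside the graph (recommended before capture)
with torch.no_grad():
    proj(static_input)
torch.cuda.synchronize()

# Capture a fixed-shape forward pass once...
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = proj(static_input)

# ...then replay it for each decoded token.
def replay(hidden_states):
    static_input.copy_(hidden_states)  # refresh the captured buffer in place
    graph.replay()                     # rerun the recorded kernels
    return static_output
```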
@@ -236,18 +239,30 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
         nn.Module.__init__(self)
         hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.self_attention = OptimizedFalconAttention(config)
         self.mlp = FalconMLP(config)
         self.hidden_dropout = config.hidden_dropout
         self.config = config
-        assert not self.config.alibi
-        assert config.new_decoder_architecture
-        self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.self_attention = OptimizedFalconAttention(config)
+
+        if self.config.alibi or not config.new_decoder_architecture:
+            if config.new_decoder_architecture:
+                # The layer norm before self-attention
+                self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+                # The layer norm before the MLP
+                self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+            else:
+                self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+                if not config.parallel_attn:
+                    self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        else:
+            self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
-        self.ln_graph = None
-        self.static_input = None
+            self.ln_graph = None
+            self.static_input = None
 
     def _optimized_apply_ln(self, hidden_states):
         if self.ln_graph is None:
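After this change the constructor mirrors the upstream `FalconDecoderLayer` layer-norm choices, and the graph buffers (`ln_graph`, `static_input`) exist only on the optimized no-alibi GQA path. Which layer-norm modules a given variant ends up with, as a summary of the branches above (not new logic):

```python
def expected_layernorms(new_decoder_architecture: bool, parallel_attn: bool) -> list:
    # Mirrors the __init__ branches above; names are the module attributes.
    if new_decoder_architecture:
        return ["ln_attn", "ln_mlp"]  # one norm before attention, one before the MLP
    norms = ["input_layernorm"]       # single shared pre-attention norm
    if not parallel_attn:
        # sequential variants re-normalize after the attention residual
        norms.append("post_attention_layernorm")
    return norms
```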
@@ -283,11 +298,14 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
     ):
         residual = hidden_states
 
-        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
-            attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)
+        if self.config.new_decoder_architecture:
+            if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
+                attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)
+            else:
+                attention_layernorm_out = self.ln_attn(hidden_states)
+                mlp_layernorm_out = self.ln_mlp(hidden_states)
         else:
-            attention_layernorm_out = self.ln_attn(hidden_states)
-            mlp_layernorm_out = self.ln_mlp(hidden_states)
+            attention_layernorm_out = self.input_layernorm(hidden_states)
 
         attn_outputs = self.self_attention(
             attention_layernorm_out,
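`_optimized_apply_ln` works because both the pre-attention and pre-MLP norms read the same hidden state, so one captured graph can emit both outputs; the fallback simply calls the norms eagerly, and non-GQA variants use their single `input_layernorm`. The same capture-and-replay pattern as the attention sketch earlier, but with two outputs from one graph (illustrative):

```python
import torch

hidden = 64
ln_attn = torch.nn.LayerNorm(hidden).cuda()
ln_mlp = torch.nn.LayerNorm(hidden).cuda()
static_input = torch.zeros(1, 1, hidden, device="cuda")

# Warm-up, then capture both norms into a single graph
ln_attn(static_input), ln_mlp(static_input)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_attn_out = ln_attn(static_input)
    static_mlp_out = ln_mlp(static_input)

def apply_ln(hidden_states):
    static_input.copy_(hidden_states)
    graph.replay()
    return static_attn_out, static_mlp_out
```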
@@ -300,10 +318,22 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
         )
         attention_output = attn_outputs[0]
 
+        if not self.config.new_decoder_architecture:
+            if self.config.parallel_attn:
+                mlp_layernorm_out = attention_layernorm_out
+            else:
+                residual = dropout_add(
+                    attention_output, residual, self.config.attention_dropout, training=self.training
+                )
+                mlp_layernorm_out = self.post_attention_layernorm(residual)
+
         outputs = attn_outputs[1:]
 
         mlp_output = self.mlp(mlp_layernorm_out)
-        mlp_output += attention_output
+
+        if self.config.new_decoder_architecture or self.config.parallel_attn:
+            mlp_output += attention_output
 
         output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
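These two additions restore the non-GQA dataflows: with `parallel_attn` the MLP reuses the attention's normalized input and both branch outputs join one residual, while the sequential variant (e.g. falcon-rw models) folds the attention output into the residual before a second norm. Stripped of dropout, the two dataflows look like this (pure-tensor pseudocode of the branches above, not the PR's code):

```python
def parallel_block(x, attn, mlp, ln):
    h = ln(x)
    # attention and MLP read the same normalized input,
    # and their outputs share a single residual connection
    return x + attn(h) + mlp(h)

def sequential_block(x, attn, mlp, ln1, ln2):
    x = x + attn(ln1(x))    # attention residual first...
    return x + mlp(ln2(x))  # ...then a second norm and the MLP residual
```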

@@ -6,6 +6,7 @@ from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, Falco
 
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, convert_block
+from test_utils import MODEL_NAME
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -93,11 +94,10 @@ class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
         return state
 
 
+@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models")
 @pytest.mark.forked
 def test_falcon():
-    config = AutoDistributedConfig.from_pretrained("tiiuae/falcon-rw-1b")
-    config.alibi = False
-    config.new_decoder_architecture = True
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
     device = "cpu"
     tensor_parallel_devices = (device,)
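With the hard-coded `falcon-rw-1b` config (and its forced flags) gone, the test now exercises whichever checkpoint the suite is pointed at, and the new `skipif` marker skips it for non-Falcon models. A hypothetical usage sketch, assuming `test_utils.MODEL_NAME` is driven by the `MODEL_NAME` environment variable:

```python
import os

# Assumption: test_utils reads MODEL_NAME from the environment, so the same
# test can target any Falcon checkpoint, for example:
#
#   MODEL_NAME=tiiuae/falcon-rw-1b pytest tests/ -k falcon
#
MODEL_NAME = os.environ["MODEL_NAME"]
assert "falcon" in MODEL_NAME, "non-Falcon models skip this test via the marker above"
```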
