|
|
|
@ -74,6 +74,9 @@ class BloomAttention(nn.Module):
|
|
|
|
|
output_attentions=False,
|
|
|
|
|
):
|
|
|
|
|
if alibi is None: # TODO OPTIMIZE ALIBI CREATION
|
|
|
|
|
current_sequence_length = hidden_states.shape[1]
|
|
|
|
|
if layer_past is not None:
|
|
|
|
|
current_sequence_length += layer_past[0].shape[1]
|
|
|
|
|
alibi = build_alibi_tensor(hidden_states.shape[1], n_head=self.num_heads, dtype=hidden_states.dtype)
|
|
|
|
|
# hidden_states: [batch_size, seq_length, hidden_size]
|
|
|
|
|
# repeat alibi tensor with the batch size
|
|
|
|
|