```diff
@@ -15,7 +15,7 @@ from src.ops import (
     attention_mask_func,
     dropout_add,
     pre_process_alibi_for_pad,
-    split_tensor_along_last_dim,
+    split_tensor_along_last_dim, build_alibi_tensor,
 )
@@ -73,6 +73,8 @@ class BloomAttention(nn.Module):
         use_cache=False,
         output_attentions=False,
     ):
+        if alibi is None:
+            alibi = build_alibi_tensor(hidden_states.shape[1], n_head=self.num_heads, dtype=hidden_states.dtype)
         # hidden_states: [batch_size, seq_length, hidden_size]
         # repeat alibi tensor with the batch size
         alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device)
```
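For reference, here is a minimal sketch of what a `build_alibi_tensor` helper with the signature used above could look like. This is an illustrative assumption, not the actual implementation from `src.ops`: it follows the ALiBi scheme (Press et al., 2021), where each attention head gets a fixed linear position bias whose slope is drawn from a geometric sequence.

```python
import math

import torch


def build_alibi_tensor(max_seq_len: int, n_head: int, dtype=torch.float32) -> torch.Tensor:
    """Sketch (not the src.ops implementation): returns a [n_head, 1, max_seq_len]
    tensor of linear position biases (ALiBi)."""
    # For a power-of-two head count n, the slopes form the geometric sequence
    # 2^(-8/n), 2^(-16/n), ... (Press et al., "Train Short, Test Long").
    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = torch.tensor([base ** (i + 1) for i in range(closest_power_of_2)])
    if closest_power_of_2 != n_head:
        # Non-power-of-two head counts borrow every other slope from the
        # sequence for 2 * closest_power_of_2 heads.
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_extra = n_head - closest_power_of_2
        extra = torch.tensor([extra_base ** (2 * i + 1) for i in range(num_extra)])
        slopes = torch.cat([slopes, extra])
    # Each head's bias grows linearly with the key position.
    positions = torch.arange(max_seq_len)[None, None, :]  # [1, 1, max_seq_len]
    alibi = slopes[:, None, None] * positions             # [n_head, 1, max_seq_len]
    return alibi.to(dtype)
```

With this shape, the `alibi.repeat(hidden_states.shape[0], 1, 1)` call in the diff produces a `[batch_size * n_head, 1, seq_length]` tensor that broadcasts over every query position when added to the attention scores.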