Account for layer_past when building the alibi tensor

pull/9/head
justheuristic 2 years ago
parent fb3bfbb78f
commit 3ccd0b5e2d

@ -74,6 +74,9 @@ class BloomAttention(nn.Module):
output_attentions=False,
):
if alibi is None: # TODO OPTIMIZE ALIBI CREATION
current_sequence_length = hidden_states.shape[1]
if layer_past is not None:
current_sequence_length += layer_past[0].shape[1]
alibi = build_alibi_tensor(hidden_states.shape[1], n_head=self.num_heads, dtype=hidden_states.dtype)
# hidden_states: [batch_size, seq_length, hidden_size]
# repeat alibi tensor with the batch size

Loading…
Cancel
Save