@@ -6,6 +6,7 @@ See commit history for authorship.
 from typing import Optional, Tuple
 
 import torch
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
 
 
@@ -26,7 +27,13 @@ class WrappedBloomBlock(BloomBlock):
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
+        fake_inputs_embeds = torch.tensor([42], dtype=torch.float32)
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=fake_inputs_embeds,
+            past_key_values_length=past_length,
+        )
         return super().forward(
             hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )
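Note: these hunks swap the private helper BloomModel._prepare_attn_mask for the shared _prepare_4d_causal_attention_mask utility imported above. A minimal standalone sketch of what the new call produces, with made-up sizes and assuming a transformers version that ships transformers.modeling_attn_mask_utils (the behavior notes in the comments are assumptions, not part of the diff):

    import torch
    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

    batch_size, seq_length, past_length = 2, 5, 3
    seq_length_with_past = seq_length + past_length

    # 2D padding mask over past + current positions, as built in WrappedBloomBlock.forward
    attention_mask = torch.ones((batch_size, seq_length_with_past))

    # Placeholder mirroring fake_inputs_embeds in the diff; assumption: when a 2D mask is
    # supplied, the helper only reads this tensor's dtype
    fake_inputs_embeds = torch.tensor([42], dtype=torch.float32)

    mask_4d = _prepare_4d_causal_attention_mask(
        attention_mask=attention_mask,
        input_shape=(batch_size, seq_length),
        inputs_embeds=fake_inputs_embeds,
        past_key_values_length=past_length,
    )
    print(mask_4d.shape)  # expected: torch.Size([2, 1, 5, 8]) -> (batch, 1, query_len, key_len)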