@@ -9,6 +9,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,
@@ -244,8 +245,8 @@ class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
         attention_mask = torch.ones(
             (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
         )
-        attention_mask = LlamaModel._prepare_decoder_attention_mask(
-            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
         )

         outputs = super().forward(
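Note: `LlamaModel._prepare_decoder_attention_mask` was removed upstream; the standalone helper imported above is its replacement. Below is a minimal sketch, not part of this diff, of how the new call behaves, assuming transformers >= 4.35 (where `transformers.modeling_attn_mask_utils` exists, as the new import implies) and using small illustrative shapes:

# Sketch only: exercises _prepare_4d_causal_attention_mask the same way the
# updated WrappedLlamaBlock.forward does. Assumes transformers >= 4.35.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, past_key_values_length, hidden_size = 2, 5, 3, 16
seq_length_with_past = seq_length + past_key_values_length

# 2D padding mask over past + current tokens, mirroring the torch.ones(...) above
attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool)
hidden_states = torch.randn(batch_size, seq_length, hidden_size)

# Expands the 2D mask into the additive 4D causal mask LlamaAttention expects,
# of shape (batch_size, 1, seq_length, seq_length_with_past); dtype and device
# are taken from hidden_states
mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
assert mask_4d.shape == (batch_size, 1, seq_length, seq_length_with_past)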