replaced call to _prepare_4d_causal_attention_mask

Branch: pull/545/head
Author: younesbelkada, 7 months ago
Parent: d59c15c578
Commit: 76479fdc43

@@ -17,13 +17,7 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
-__version__ = "2.3.0.dev1"
-if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
-    assert (
-        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
+__version__ = "2.3.0.dev2"
 
 def _override_bfloat16_mode_default():
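The hunk above (in the package __init__, judging by the __version__ assignment) bumps the dev version and appears to drop the pinned transformers upper bound, which fits the rest of the commit: _prepare_4d_causal_attention_mask only ships in transformers 4.35 and later. Below is a minimal sketch of how such a packaging.version guard could be relaxed; the exact bound is an assumption for illustration and is not visible in the hunk.

    import os

    import transformers
    from packaging import version

    # Respect the same opt-out environment variable Petals uses.
    if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
        # Assumed lower bound: transformers 4.35.0 is the first release that ships
        # transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask.
        assert version.parse(transformers.__version__) >= version.parse("4.35.0"), (
            "Please install a newer transformers version: pip install transformers>=4.35.0"
        )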

@@ -9,6 +9,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,
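The new import pulls the mask helper from transformers.modeling_attn_mask_utils, a module that older transformers releases do not have, and the commit imports it unconditionally. A guarded import like the hypothetical sketch below is one alternative way to keep the module loadable on older versions; it is not something this PR does.

    # Hypothetical compatibility shim, not part of this commit:
    try:
        from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
    except ImportError:
        # transformers < 4.35 does not ship modeling_attn_mask_utils.
        _prepare_4d_causal_attention_mask = None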
@@ -244,8 +245,8 @@ class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = LlamaModel._prepare_decoder_attention_mask(
-            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
         )
 
         outputs = super().forward(
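To make the replacement concrete, here is a small self-contained sketch of the new call, assuming transformers >= 4.35 is installed; the toy batch, sequence, and hidden sizes are made up for illustration. The helper expands a 2D padding mask into the 4D additive causal mask the decoder layer expects, reading the hidden states only for their dtype and device.

    import torch
    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

    batch_size, seq_length, past_key_values_length, hidden_size = 1, 4, 2, 8
    seq_length_with_past = seq_length + past_key_values_length

    # 2D padding mask over past + current tokens, as built in the block above.
    attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool)
    # Dummy hidden states; only their dtype and device matter to the helper.
    hidden_states = torch.zeros((batch_size, seq_length, hidden_size))

    mask_4d = _prepare_4d_causal_attention_mask(
        attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
    )
    print(mask_4d.shape)  # torch.Size([1, 1, 4, 6]), i.e. (batch, 1, query_len, key_value_len)

Unlike LlamaModel._prepare_decoder_attention_mask, which was a method (hence the None passed in place of self in the old call), the new helper is a module-level function, so the 2D mask is simply its first argument.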
