@@ -2,12 +2,12 @@ from typing import Optional, Tuple

 import torch
 from transformers import MixtralConfig
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
-from transformers.cache_utils import DynamicCache
 from transformers.modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
+from transformers.cache_utils import DynamicCache
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel


 class WrappedMixtralBlock(MixtralDecoderLayer):
@@ -38,7 +38,9 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
             _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
             past_key_value = DynamicCache()
             for idx in range(self.layer_idx):
-                past_key_value.update(torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx)
+                past_key_value.update(
+                    torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx
+                )
             past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)

         if self._attn_implementation == "flash_attention_2":
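
Note for reviewers: the padding loop in this hunk exists because `DynamicCache.update()` appends to its per-layer key/value lists in layer order, so a server that hosts only layer `self.layer_idx` must first fill slots `0..layer_idx-1` with throwaway tensors before it can register the real past. A minimal standalone sketch of the pattern (hypothetical shapes; assumes a transformers version where `DynamicCache` exposes `update()` and `to_legacy_cache()`):

import torch
from transformers.cache_utils import DynamicCache

layer_idx = 2                 # pretend this server hosts only layer 2
k = torch.randn(1, 8, 5, 64)  # hypothetical (batch, kv_heads, seq_len, head_dim)
v = torch.randn(1, 8, 5, 64)

cache = DynamicCache()
for idx in range(layer_idx):
    # placeholder entries so the cache's layer slots line up with layer_idx
    cache.update(torch.empty(k.size()), torch.empty(v.size()), idx)
cache.update(k, v, layer_idx)  # the real past for the hosted layer

assert torch.equal(cache.to_legacy_cache()[layer_idx][0], k)
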
@@ -81,9 +83,7 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
         if use_cache:
             present_key_value = outputs[-1]
             present_key_value = present_key_value.to_legacy_cache()[self.layer_idx]
-            present_key_value = self._reorder_cache_to_bloom(
-                present_key_value, batch_size, seq_length_with_past
-            )
+            present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
             outputs = outputs[:-1] + (present_key_value,)

         return outputs
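
The `_reorder_cache_from_bloom` / `_reorder_cache_to_bloom` helpers referenced in both hunks are Petals-side methods defined elsewhere in this class, not part of this diff: they translate between the (batch, num_heads, seq_len, head_dim) layout the HF Mixtral layer produces and the flattened Bloom-style layout Petals uses for its cross-server cache. A rough, hypothetical sketch of that kind of conversion (the real helpers may differ in details):

import torch

def reorder_to_bloom_style(key, value, batch_size):
    # hypothetical: (batch, num_heads, seq_len, head_dim) -> Bloom-style layout
    b, h, s, d = key.shape
    key = key.reshape(b * h, s, d).permute(0, 2, 1)  # (batch*heads, head_dim, seq_len)
    value = value.reshape(b * h, s, d)               # (batch*heads, seq_len, head_dim)
    return key, value

def reorder_from_bloom_style(key, value, batch_size):
    # hypothetical inverse: Bloom-style -> (batch, num_heads, seq_len, head_dim)
    bh, d, s = key.shape
    h = bh // batch_size
    key = key.permute(0, 2, 1).reshape(batch_size, h, s, d)
    value = value.reshape(batch_size, h, s, d)
    return key, value

The `outputs[-1].to_legacy_cache()[self.layer_idx]` step in the last hunk is the counterpart of the padding trick above: it flattens the `DynamicCache` back into the legacy per-layer tuple format and picks out the single layer this server actually computed.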