pull/553/head
Artem Chumachenko 3 months ago
parent 7b6224d0cf
commit 08bbbd38f0

@@ -2,12 +2,12 @@ from typing import Optional, Tuple
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
from transformers.cache_utils import DynamicCache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.cache_utils import DynamicCache
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
class WrappedMixtralBlock(MixtralDecoderLayer):
@@ -38,7 +40,9 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
_past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
past_key_value = DynamicCache()
for idx in range(self.layer_idx):
past_key_value.update(torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx)
past_key_value.update(
torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx
)
past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)
if self._attn_implementation == "flash_attention_2":
@@ -81,9 +83,7 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
if use_cache:
present_key_value = outputs[-1]
present_key_value = present_key_value.to_legacy_cache()[self.layer_idx]
present_key_value = self._reorder_cache_to_bloom(
present_key_value, batch_size, seq_length_with_past
)
present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
outputs = outputs[:-1] + (present_key_value,)
return outputs

@@ -135,7 +135,9 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
return self.norm
class DistributedMixtralForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):
class DistributedMixtralForCausalLM(
DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
):
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
@@ -160,7 +162,6 @@ class DistributedMixtralForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, R
class DistributedMixtralForSequenceClassification(
DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
):
def __init__(self, config: DistributedMixtralConfig):
MixtralPreTrainedModel.__init__(self, config)
self.num_labels = config.num_labels

Loading…
Cancel
Save