|
|
@ -1,3 +1,4 @@
|
|
|
|
|
|
|
|
import json
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
@ -33,16 +34,15 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
|
|
|
|
past_key_values_length = 0
|
|
|
|
past_key_values_length = 0
|
|
|
|
|
|
|
|
|
|
|
|
past_key_value = layer_past
|
|
|
|
past_key_value = layer_past
|
|
|
|
|
|
|
|
|
|
|
|
if past_key_value is not None:
|
|
|
|
if past_key_value is not None:
|
|
|
|
past_key_values_length = past_key_value[0].shape[2]
|
|
|
|
past_key_values_length = past_key_value[0].shape[2]
|
|
|
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
|
|
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
|
|
|
_past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
|
|
|
|
_past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
|
|
|
|
past_key_value = DynamicCache()
|
|
|
|
past_key_value = DynamicCache()
|
|
|
|
for idx in range(self.layer_idx):
|
|
|
|
past_key_value.key_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[0]]
|
|
|
|
past_key_value.update(
|
|
|
|
past_key_value.value_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[1]]
|
|
|
|
torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx
|
|
|
|
past_key_value._seen_tokens = past_key_values_length
|
|
|
|
)
|
|
|
|
|
|
|
|
past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self._attn_implementation == "flash_attention_2":
|
|
|
|
if self._attn_implementation == "flash_attention_2":
|
|
|
|
# 2d mask is passed through the layers
|
|
|
|
# 2d mask is passed through the layers
|
|
|
@ -83,7 +83,7 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
|
|
|
|
|
|
|
|
|
|
|
|
if use_cache:
|
|
|
|
if use_cache:
|
|
|
|
present_key_value = outputs[-1]
|
|
|
|
present_key_value = outputs[-1]
|
|
|
|
present_key_value = present_key_value.to_legacy_cache()[self.layer_idx]
|
|
|
|
present_key_value = present_key_value[self.layer_idx]
|
|
|
|
present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
|
|
|
|
present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
|
|
|
|
outputs = outputs[:-1] + (present_key_value,)
|
|
|
|
outputs = outputs[:-1] + (present_key_value,)
|
|
|
|
|
|
|
|
|
|
|
|