@@ -26,7 +26,6 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
         use_cache: bool = False,
         **kwargs
     ):
-        print(self.layer_idx)
         batch_size, seq_length, _ = hidden_states.shape
 
         seq_length_with_past = seq_length
@@ -37,7 +36,6 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
             past_key_values_length = past_key_value[0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
             _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
-            print(_past_key_value)
             # TODO: remove DynamicCache
             past_key_value = DynamicCache()
             for idx in range(self.layer_idx):
@@ -73,8 +71,19 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
                 sliding_window=self.sliding_window,
             )
 
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=hidden_states.device
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+
         outputs = super().forward(
-            hidden_states, *args, attention_mask=attention_mask, past_key_value=past_key_value, use_cache=use_cache, **kwargs
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            **kwargs
         )
 
         if use_cache:
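For context on the `DynamicCache` workaround in the second hunk: the hunk cuts off at `for idx in range(self.layer_idx):`, so the loop body in the sketch below is an assumption about what it does. Since each server holds only a slice of the model, the per-layer cache has to be padded with empty placeholder entries for the layers served elsewhere before the real key/value tensors can be registered under `self.layer_idx`. A minimal standalone sketch, with all shapes invented:

```python
import torch
from transformers import DynamicCache

# Sketch of the cache-padding trick, assuming this block is layer 2 of the model.
layer_idx = 2
key = torch.randn(1, 8, 4, 16)    # (batch, kv_heads, past_len, head_dim); shapes made up
value = torch.randn(1, 8, 4, 16)

cache = DynamicCache()
for idx in range(layer_idx):
    # Placeholder entries so DynamicCache's per-layer indexing lines up even
    # though the earlier layers live on other machines (assumed loop body).
    cache.update(torch.empty(0), torch.empty(0), idx)
cache.update(key, value, layer_idx)

print(cache.get_seq_length(layer_idx))  # 4
```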
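And a quick illustration of what the `position_ids` lines added in the third hunk compute: positions continue counting from the end of the cached prefix instead of restarting at zero, which matters once the rebuilt `past_key_value` is passed to `MixtralDecoderLayer.forward` explicitly. The concrete lengths here are invented for the example:

```python
import torch

past_key_values_length = 4  # tokens already in the cache (invented for the example)
seq_length = 3              # new tokens in this forward pass

position_ids = torch.arange(
    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
print(position_ids)  # tensor([[4, 5, 6]])
```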