Fix generation

pull/553/head
Artem Chumachenko 4 months ago
parent d275d79b72
commit 4cdd57cf49

@@ -26,7 +26,6 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
         use_cache: bool = False,
         **kwargs
     ):
-        print(self.layer_idx)
         batch_size, seq_length, _ = hidden_states.shape
 
         seq_length_with_past = seq_length
@@ -37,7 +36,6 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
             past_key_values_length = past_key_value[0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
             _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
-            print(_past_key_value)
             # TODO: remove DynamicCache
             past_key_value = DynamicCache()
             for idx in range(self.layer_idx):
@@ -73,8 +71,19 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
             sliding_window=self.sliding_window,
         )
 
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=hidden_states.device
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+
         outputs = super().forward(
-            hidden_states, *args, attention_mask=attention_mask, past_key_value=past_key_value, use_cache=use_cache, **kwargs
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            **kwargs
         )
 
         if use_cache:

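Note on the hunk above: during incremental decoding the block only receives the new tokens, so their position ids must start at past_key_values_length rather than at 0. A minimal sketch of the same arithmetic in isolation (the concrete lengths below are made-up illustration values, not taken from the commit):

import torch

# Made-up illustration values: 10 tokens already in the attention cache, 1 new token this step.
past_key_values_length = 10
seq_length = 1

# Same computation as the added lines above: positions of the new tokens
# continue right after the cached prefix instead of restarting at 0.
position_ids = torch.arange(
    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
print(position_ids)  # tensor([[10]])
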
@@ -94,6 +94,10 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
         hidden_states = inputs_embeds
         output_shape = input_shape + (hidden_states.size(-1),)
 
+        if past_key_values is None:
+            past_key_values = RemotePastKeyValues()
+        past_key_values.update_seen(hidden_states.size(1))
+
         hidden_states = self.layers(
             hidden_states,
             prompts=intermediate_prompts,
@@ -109,7 +113,7 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
         hidden_states = hidden_states.view(output_shape)
         return MoeModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=RemotePastKeyValues(),
+            past_key_values=past_key_values,
             hidden_states=None,
             attentions=None,
         )
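Note on the last two hunks: they work together so that the count of already-processed tokens survives across forward calls. The model now reuses (or lazily creates) a single RemotePastKeyValues object, calls update_seen on it, and returns that same object instead of a fresh, empty one. A toy stand-in class, only to illustrate the intent of the change (the real RemotePastKeyValues implementation is not part of this commit and may differ):

# Hypothetical stand-in used purely for illustration; assumes update_seen()
# accumulates a running count of processed tokens.
class ToyPastKeyValues:
    def __init__(self):
        self.seen_tokens = 0

    def update_seen(self, new_seen: int) -> None:
        self.seen_tokens += new_seen

cache = ToyPastKeyValues()
cache.update_seen(8)  # prefill with 8 prompt tokens (made-up number)
cache.update_seen(1)  # one generated token
assert cache.seen_tokens == 9
# Returning a new, empty object on every call (as before the fix) would report
# zero seen tokens to the caller each step; returning the updated object keeps
# the count consistent across generation steps.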
