Fix cache reordering with seq_len = 0

pull/499/head
Aleksandr Borzunov 9 months ago
parent 4537c77004
commit f6553ad4cb

@@ -76,20 +76,20 @@ class WrappedFalconBlock(FalconDecoderLayer):
         return (key_states, value_states)

     def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
         # Shape: [batch_size * num_kv_heads, seq_len, head_dim] -> [batch_size * num_attn_heads, seq_len, head_dim]
-        _, seq_len, head_dim = state.shape
+        batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
+        batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads

-        state = state.view(-1, 1, self.config.num_kv_heads, seq_len, head_dim)
+        state = state.view(batch_size, 1, self.config.num_kv_heads, seq_len, head_dim)
         # Here, .expand() doesn't allocate new memory, instead uses stride=0 along dim=1
-        state = state.expand(-1, self.config.num_key_value_groups, self.config.num_kv_heads, seq_len, head_dim)
-        state = state.reshape(-1, seq_len, head_dim)
+        state = state.expand(-1, self.config.num_key_value_groups, -1, -1, -1)
+        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)
         return state

     def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
         # Shape: [batch_size * num_attn_heads, seq_len, head_dim] -> [batch_size * num_kv_heads, seq_len, head_dim]
-        _, seq_len, head_dim = state.shape
+        batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
+        batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads

-        state = state.view(-1, self.config.num_key_value_groups, self.config.num_kv_heads, seq_len, head_dim)
+        state = state.view(batch_size, self.config.num_key_value_groups, self.config.num_kv_heads, seq_len, head_dim)
         state = state[:, 0]
-        state = state.view(-1, seq_len, head_dim)
+        state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
         return state
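For context on why this change is needed: torch.Tensor.view() cannot infer a -1 dimension when the tensor has zero elements, so the old code raised a RuntimeError whenever the KV cache was reordered while still empty (seq_len = 0). Below is a minimal sketch, not part of this commit, that reproduces the failure and mirrors the fixed shape handling; the config numbers (num_kv_heads = 8, num_key_value_groups = 4, etc.) are made up for illustration.

```python
import torch

# Illustrative values only (not taken from any real Falcon config)
batch_size, num_kv_heads, num_key_value_groups, head_dim = 2, 8, 4, 64
num_attention_heads = num_kv_heads * num_key_value_groups
seq_len = 0  # empty cache, e.g. reordering caches before the first token is stored

state = torch.empty(batch_size * num_kv_heads, seq_len, head_dim)

# Old behavior: with 0 elements, the size of the -1 dimension is ambiguous,
# so .view() raises "cannot reshape tensor of 0 elements ... is ambiguous".
try:
    state.view(-1, 1, num_kv_heads, seq_len, head_dim)
except RuntimeError as e:
    print("old code fails:", e)

# Fixed behavior: derive batch_size from the leading dimension so every
# .view() size is explicit and zero-length caches pass through unchanged.
expanded = state.view(batch_size, 1, num_kv_heads, seq_len, head_dim)
expanded = expanded.expand(-1, num_key_value_groups, -1, -1, -1)  # stride=0, no copy
expanded = expanded.reshape(batch_size * num_attention_heads, seq_len, head_dim)
print(expanded.shape)  # torch.Size([64, 0, 64])
```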
