|
|
|
@@ -76,20 +76,20 @@ class WrappedFalconBlock(FalconDecoderLayer):
|
|
|
|
|
return (key_states, value_states)
|
|
|
|
|
|
|
|
|
|
def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
# Shape: [batch_size * num_kv_heads, seq_len, head_dim] -> [batch_size * num_attn_heads, seq_len, head_dim]
|
|
|
|
|
batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
|
|
|
|
|
batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
|
|
|
|
|
|
|
|
|
|
_, seq_len, head_dim = state.shape
|
|
|
|
|
state = state.view(-1, 1, self.config.num_kv_heads, seq_len, head_dim)
|
|
|
|
|
state = state.view(batch_size, 1, self.config.num_kv_heads, seq_len, head_dim)
|
|
|
|
|
# Here, .expand() doesn't allocate new memory, instead uses stride=0 along dim=1
|
|
|
|
|
state = state.expand(-1, self.config.num_key_value_groups, self.config.num_kv_heads, seq_len, head_dim)
|
|
|
|
|
state = state.reshape(-1, seq_len, head_dim)
|
|
|
|
|
state = state.expand(-1, self.config.num_key_value_groups, -1, -1, -1)
|
|
|
|
|
state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)
|
|
|
|
|
return state
|
|
|
|
|
|
|
|
|
|
def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
# Shape: [batch_size * num_attn_heads, seq_len, head_dim] -> [batch_size * num_kv_heads, seq_len, head_dim]
|
|
|
|
|
batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
|
|
|
|
|
batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
|
|
|
|
|
|
|
|
|
|
_, seq_len, head_dim = state.shape
|
|
|
|
|
state = state.view(-1, self.config.num_key_value_groups, self.config.num_kv_heads, seq_len, head_dim)
|
|
|
|
|
state = state.view(batch_size, self.config.num_key_value_groups, self.config.num_kv_heads, seq_len, head_dim)
|
|
|
|
|
state = state[:, 0]
|
|
|
|
|
state = state.view(-1, seq_len, head_dim)
|
|
|
|
|
state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
|
|
|
|
|
return state
|
|
|
|
|