@@ -363,18 +363,17 @@ class WrappedFalconBlock(OptimizedFalconDecoderLayer):
 
         batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
         batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
 
-        state = state.view(batch_size, 1, self.config.num_kv_heads, seq_len, head_dim)
-        # Here, .expand() doesn't allocate new memory, instead uses stride=0 along dim=1
-        state = state.expand(-1, self.config.num_key_value_groups, -1, -1, -1)
-        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)
+        state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
+        state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1)  # No copy
+        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)  # Involves a copy
         return state
 
     def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
         batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
         batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
 
-        state = state.view(batch_size, self.config.num_key_value_groups, self.config.num_kv_heads, seq_len, head_dim)
-        state = state[:, 0]
+        state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
+        state = state[:, :, 0]
         state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
         return state
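For reference, the new axis order lays the expanded states out kv-head-major (all query heads that share a KV head stay adjacent), and _collapse_states now inverts _expand_states exactly. Below is a standalone sketch with toy sizes, using illustrative names and numbers rather than Petals code:

import torch

# Toy configuration (illustrative): 2 sequences, 2 KV heads, 2 query heads per KV head.
batch_size, num_kv_heads, num_key_value_groups, seq_len, head_dim = 2, 2, 2, 3, 4
num_attention_heads = num_kv_heads * num_key_value_groups

state = torch.randn(batch_size * num_kv_heads, seq_len, head_dim)

# Expand: add a group axis *after* the KV-head axis, broadcast it for free,
# then flatten so query heads sharing a KV head end up next to each other.
expanded = (
    state.view(batch_size, num_kv_heads, 1, seq_len, head_dim)
    .expand(-1, -1, num_key_value_groups, -1, -1)  # stride 0 along the group dim, no copy
    .reshape(batch_size * num_attention_heads, seq_len, head_dim)  # the copy happens here
)

# Collapse: keep one representative per KV head to undo the expansion.
collapsed = (
    expanded.view(batch_size, num_kv_heads, num_key_value_groups, seq_len, head_dim)[:, :, 0]
    .reshape(batch_size * num_kv_heads, seq_len, head_dim)
)

assert torch.equal(collapsed, state)  # the round-trip is lossless with this layout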
@@ -387,7 +386,6 @@ class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
         attention_mask: Optional[torch.Tensor] = None,
         alibi: Optional[torch.Tensor] = None,
-        layer_past: Optional[KVCache] = None,
         layer_past: Optional[KVCache] = None,
         use_cache: bool = False,
         **kwargs,
     ):
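For context, Python rejects a repeated parameter name in a def at compile time, which is why the duplicated layer_past line has to go. A standalone illustration (not Petals code):

# A repeated parameter name is a SyntaxError, so such a signature cannot even be parsed.
snippet = (
    "def forward(self, hidden_states, *, layer_past=None, layer_past=None, use_cache=False):\n"
    "    return hidden_states\n"
)
try:
    compile(snippet, "<example>", "exec")
except SyntaxError as err:
    print(err)  # duplicate argument 'layer_past' in function definition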