Post-rebase changes

pull/500/head
Max Ryabinin 9 months ago
parent 1f006c59a1
commit 1fc22bd69f

@@ -363,18 +363,17 @@ class WrappedFalconBlock(OptimizedFalconDecoderLayer):
         batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
         batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
 
-        state = state.view(batch_size, 1, self.config.num_kv_heads, seq_len, head_dim)
-        # Here, .expand() doesn't allocate new memory, instead uses stride=0 along dim=1
-        state = state.expand(-1, self.config.num_key_value_groups, -1, -1, -1)
-        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)
+        state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
+        state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1)  # No copy
+        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)  # Involves a copy
         return state
 
     def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
         batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
         batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
 
-        state = state.view(batch_size, self.config.num_key_value_groups, self.config.num_kv_heads, seq_len, head_dim)
-        state = state[:, 0]
+        state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
+        state = state[:, :, 0]
         state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
         return state
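
The dimension order matters for how the expanded KV states line up with the query heads. With the group dimension placed before the KV-head dimension (old layout), the final reshape produces block-repeated heads ([kv0, kv1, kv0, kv1, …]); with it placed after (new layout), each KV head's copies stay adjacent ([kv0, kv0, kv1, kv1, …]), matching the `repeat_kv` convention used for grouped-query attention in `transformers` and making `_collapse_states` a true inverse of `_expand_states`. A minimal sketch of the difference, not part of this commit, with illustrative sizes (2 KV heads, 2 query groups per head):

```python
import torch

# Illustrative sizes (assumptions, not from the PR): 2 KV heads,
# 2 query groups per KV head -> 4 attention heads in total.
batch_size, num_kv_heads, num_key_value_groups = 1, 2, 2
seq_len, head_dim = 1, 1
num_attention_heads = num_kv_heads * num_key_value_groups

# One distinct value per KV head: head 0 -> 0.0, head 1 -> 1.0
state = torch.arange(num_kv_heads, dtype=torch.float32).view(num_kv_heads, seq_len, head_dim)

# Old layout: group dim *before* the KV-head dim -> block-repeated heads
old = (
    state.view(batch_size, 1, num_kv_heads, seq_len, head_dim)
    .expand(-1, num_key_value_groups, -1, -1, -1)
    .reshape(batch_size * num_attention_heads, seq_len, head_dim)
)
print(old.flatten().tolist())  # [0.0, 1.0, 0.0, 1.0]

# New layout: group dim *after* the KV-head dim -> each KV head's copies
# stay adjacent, as grouped-query attention expects
new = (
    state.view(batch_size, num_kv_heads, 1, seq_len, head_dim)
    .expand(-1, -1, num_key_value_groups, -1, -1)
    .reshape(batch_size * num_attention_heads, seq_len, head_dim)
)
print(new.flatten().tolist())  # [0.0, 0.0, 1.0, 1.0]

# Collapse (mirroring the updated _collapse_states): taking group 0 of each
# KV head recovers the original states from the *new* layout
collapsed = new.view(batch_size, num_kv_heads, num_key_value_groups, seq_len, head_dim)[:, :, 0]
print(collapsed.reshape(batch_size * num_kv_heads, seq_len, head_dim).flatten().tolist())  # [0.0, 1.0]
```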
@@ -387,7 +386,6 @@ class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
         attention_mask: Optional[torch.Tensor] = None,
         alibi: Optional[torch.Tensor] = None,
         layer_past: Optional[KVCache] = None,
-        layer_past: Optional[KVCache] = None,
         use_cache: bool = False,
         **kwargs,
     ):
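
The removed line is a verbatim duplicate of the preceding `layer_past` parameter, presumably left over from the rebase. A quick check, not from the commit, showing why this had to go: Python rejects duplicate parameter names at compile time, so the duplicated signature would have made the module fail to import.

```python
# Hypothetical minimal reproduction of the duplicated signature
src = "def forward(layer_past=None, layer_past=None): pass"
try:
    compile(src, "<signature-check>", "exec")
except SyntaxError as err:
    print(err.msg)  # duplicate argument 'layer_past' in function definition
```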
