|
|
|
@ -8,6 +8,8 @@ from typing import Optional, Tuple
|
|
|
|
|
import torch
|
|
|
|
|
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
|
|
|
|
|
|
|
|
|
|
# One layer's attention cache entry: a (key_states, value_states) pair of tensors.
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WrappedFalconBlock(FalconDecoderLayer):
|
|
|
|
|
def forward(
|
|
|
|
@ -16,7 +18,7 @@ class WrappedFalconBlock(FalconDecoderLayer):
|
|
|
|
|
*args,
|
|
|
|
|
attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
|
alibi: Optional[torch.Tensor] = None,
|
|
|
|
|
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
|
|
|
layer_past: Optional[KVCache] = None,
|
|
|
|
|
use_cache: bool = False,
|
|
|
|
|
**kwargs
|
|
|
|
|
):
|
|
|
|
@ -44,15 +46,50 @@ class WrappedFalconBlock(FalconDecoderLayer):
|
|
|
|
|
|
|
|
|
|
if use_cache:
|
|
|
|
|
present_key_value = outputs[-1]
|
|
|
|
|
present_key_value = self._reorder_cache_from_bloom_to_falcon(present_key_value)
|
|
|
|
|
present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
|
|
|
|
|
outputs = outputs[:-1] + (present_key_value,)
|
|
|
|
|
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _reorder_cache_from_bloom_to_falcon(
|
|
|
|
|
key_value: Tuple[torch.Tensor, torch.Tensor]
|
|
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
|
|
|
|
|
key_states, value_states = key_value
|
|
|
|
|
|
|
|
|
|
key_states = key_states.permute(0, 2, 1)
|
|
|
|
|
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
|
|
|
|
|
|
|
|
|
|
if self.config.new_decoder_architecture:
|
|
|
|
|
key_states = self._expand_states(key_states)
|
|
|
|
|
value_states = self._expand_states(value_states)
|
|
|
|
|
|
|
|
|
|
return (key_states, value_states)
|
|
|
|
|
|
|
|
|
|
def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
|
|
|
|
|
key_states, value_states = key_value
|
|
|
|
|
|
|
|
|
|
if self.config.new_decoder_architecture:
|
|
|
|
|
key_states = self._collapse_states(key_states)
|
|
|
|
|
value_states = self._collapse_states(value_states)
|
|
|
|
|
|
|
|
|
|
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
|
|
|
|
|
key_states = key_states.permute(0, 2, 1)
|
|
|
|
|
|
|
|
|
|
return (key_states, value_states)
|
|
|
|
|
|
|
|
|
|
def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
# Shape: [batch_size * num_kv_heads, seq_len, head_dim] -> [batch_size * num_attn_heads, seq_len, head_dim]
|
|
|
|
|
|
|
|
|
|
_, seq_len, head_dim = state.shape
|
|
|
|
|
state = state.view(-1, 1, self.config.num_kv_heads, seq_len, head_dim)
|
|
|
|
|
# Here, .expand() doesn't allocate new memory, instead uses stride=0 along dim=1
|
|
|
|
|
state = state.expand(-1, self.config.num_key_value_groups, self.config.num_kv_heads, seq_len, head_dim)
|
|
|
|
|
state = state.reshape(-1, seq_len, head_dim)
|
|
|
|
|
return state
|
|
|
|
|
|
|
|
|
|
def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
# Shape: [batch_size * num_attn_heads, seq_len, head_dim] -> [batch_size * num_kv_heads, seq_len, head_dim]
|
|
|
|
|
|
|
|
|
|
_, seq_len, head_dim = state.shape
|
|
|
|
|
state = state.view(-1, self.config.num_key_value_groups, self.config.num_kv_heads, seq_len, head_dim)
|
|
|
|
|
state = state[:, 0]
|
|
|
|
|
state = state.view(-1, seq_len, head_dim)
|
|
|
|
|
return state
|
|
|
|
|