@@ -17,10 +17,10 @@ from transformers.models.falcon.modeling_falcon import (
    FalconLinear,
    FalconMLP,
    FalconModel,
    FalconRotaryEmbedding,
    LayerNorm,
    build_alibi_tensor,
    dropout_add,
    FalconRotaryEmbedding,
)

KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -109,9 +109,10 @@ class OptimizedFalconRotaryEmbedding(nn.Module):
        # else:
        return apply_rotary(query, key, cos, sin)


# @torch.jit.script
def split_heads(
    fused_qkv: torch.Tensor, num_heads:int, num_kv_heads:int, head_dim:int
    fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    batch, seq_len, _ = fused_qkv.shape
    qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)
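Not part of the diff, just a sketch of the grouped-query layout that `split_heads` assumes: the fused QKV projection packs, for each key/value head, `num_heads // num_kv_heads` query heads plus one key head and one value head, which is where the `+ 2` in the view comes from. The sizes and the final `torch.split` below are illustrative assumptions, not lines from the patched file.

```python
import torch

# Hypothetical sizes for illustration: 8 query heads shared across 2 KV heads.
batch, seq_len = 2, 5
num_heads, num_kv_heads, head_dim = 8, 2, 64
fused_dim = (num_heads + 2 * num_kv_heads) * head_dim  # width of the fused QKV projection
fused_qkv = torch.randn(batch, seq_len, fused_dim)

# Same reshape as in the diff: one group = (num_heads // num_kv_heads) query heads + 1 key + 1 value.
qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)
query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3)

print(query.shape)  # torch.Size([2, 5, 2, 4, 64]) -> (batch, seq, kv_groups, q_heads_per_group, head_dim)
print(key.shape)    # torch.Size([2, 5, 2, 1, 64])
print(value.shape)  # torch.Size([2, 5, 2, 1, 64])
```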