Fix formatting

pull/500/head
Max Ryabinin 9 months ago
parent 1f2ef79da3
commit b941df5d2f

@ -17,10 +17,10 @@ from transformers.models.falcon.modeling_falcon import (
FalconLinear,
FalconMLP,
FalconModel,
FalconRotaryEmbedding,
LayerNorm,
build_alibi_tensor,
dropout_add,
FalconRotaryEmbedding,
)
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -109,9 +109,10 @@ class OptimizedFalconRotaryEmbedding(nn.Module):
# else:
return apply_rotary(query, key, cos, sin)
# @torch.jit.script
def split_heads(
fused_qkv: torch.Tensor, num_heads:int, num_kv_heads:int, head_dim:int
fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch, seq_len, _ = fused_qkv.shape
qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)

Loading…
Cancel
Save