You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
179 lines
5.2 KiB
Python
179 lines
5.2 KiB
Python
import math
|
|
|
|
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
@triton.jit
|
|
def _fwd_kernel(
|
|
Q,
|
|
K,
|
|
V,
|
|
sm_scale,
|
|
Out,
|
|
stride_qz,
|
|
stride_qh,
|
|
stride_qm,
|
|
stride_qk,
|
|
stride_kz,
|
|
stride_kh,
|
|
stride_kn,
|
|
stride_kk,
|
|
stride_vz,
|
|
stride_vh,
|
|
stride_vk,
|
|
stride_vn,
|
|
stride_oz,
|
|
stride_oh,
|
|
stride_om,
|
|
stride_on,
|
|
N_HEAD,
|
|
H,
|
|
N_CTX,
|
|
start_position, # <- ADDED
|
|
IS_CAUSAL: tl.constexpr, # <- ADDED
|
|
BLOCK_M: tl.constexpr,
|
|
BLOCK_N: tl.constexpr,
|
|
BLOCK_DMODEL: tl.constexpr,
|
|
USE_FP8: tl.constexpr,
|
|
):
|
|
start_m = tl.program_id(0)
|
|
|
|
head_idx = tl.program_id(1)
|
|
batch_id = head_idx // N_HEAD
|
|
off_hz = head_idx % N_HEAD
|
|
|
|
# initialize offsets
|
|
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
offs_n = tl.arange(0, BLOCK_N)
|
|
offs_d = tl.arange(0, BLOCK_DMODEL)
|
|
off_q = (
|
|
batch_id * stride_qz + off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
|
) # <- stride fixed
|
|
off_k = (
|
|
batch_id * stride_kz + off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
|
|
) # <- stride fixed
|
|
off_v = (
|
|
batch_id * stride_vz + off_hz * stride_vh + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn
|
|
) # <- stride fixed
|
|
# Initialize pointers to Q, K, V
|
|
q_ptrs = Q + off_q
|
|
k_ptrs = K + off_k
|
|
v_ptrs = V + off_v
|
|
# initialize pointer to m and l
|
|
m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
|
l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
|
|
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
|
# load q: it will stay in SRAM throughout
|
|
q = tl.load(q_ptrs, offs_m[:, None] < H, other=0.0)
|
|
# loop over k, v and update accumulator
|
|
block_n_end = N_CTX # <- ADDED (including the IF)
|
|
if IS_CAUSAL:
|
|
# in causal mode, we expect that BLOCK_M_SIZE == BLOCK_N_SIZE
|
|
# autotune will prune shapes not matching this rule
|
|
block_n_end = (start_m + 1) * BLOCK_N + start_position
|
|
for start_n in range(0, block_n_end, BLOCK_N):
|
|
block_n_offs = start_n + offs_n # <- ADDED
|
|
# -- compute qk ----
|
|
k = tl.load(k_ptrs, block_n_offs[:, None] < N_CTX, 0.0)
|
|
if USE_FP8:
|
|
k = k.to(tl.float8e5, bitcast=True)
|
|
k = k.to(tl.float16)
|
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
|
qk += tl.dot(q, tl.trans(k))
|
|
qk = tl.where(offs_n[None, :] < N_CTX, qk, float("-inf")) # <- ADDED
|
|
qk *= sm_scale
|
|
if IS_CAUSAL: # <- ADDED
|
|
qk = tl.where(offs_m[:, None] >= (block_n_offs[None, :] + start_position), qk, float("-inf"))
|
|
|
|
# compute new m
|
|
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
|
|
# correct old l
|
|
l_prev *= tl.exp(m_prev - m_curr)
|
|
# attention weights
|
|
p = tl.exp(qk - m_curr[:, None])
|
|
l_curr = tl.sum(p, 1) + l_prev
|
|
# rescale operands of matmuls
|
|
l_rcp = 1.0 / l_curr
|
|
p *= l_rcp[:, None]
|
|
acc *= (l_prev * l_rcp)[:, None]
|
|
# update acc
|
|
p = p.to(Q.dtype.element_ty)
|
|
v = tl.load(v_ptrs, block_n_offs[:, None] < N_CTX, 0.0)
|
|
if USE_FP8:
|
|
v = v.to(tl.float8e5, bitcast=True)
|
|
v = v.to(tl.float16)
|
|
acc += tl.dot(p, v)
|
|
# update m_i and l_i
|
|
l_prev = l_curr
|
|
m_prev = m_curr
|
|
# update pointers
|
|
k_ptrs += BLOCK_N * stride_kn
|
|
v_ptrs += BLOCK_N * stride_vk
|
|
# rematerialize offsets to save registers
|
|
start_m = tl.program_id(0)
|
|
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
|
|
# initialize pointers to output
|
|
offs_d = tl.arange(0, BLOCK_DMODEL)
|
|
off_o = batch_id * stride_oz + off_hz * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on
|
|
out_ptrs = Out + off_o
|
|
tl.store(out_ptrs, acc, offs_m[:, None] < H)
|
|
|
|
|
|
def triton_fa(q, k, v, sm_scale, is_causal, start_position):
|
|
assert q.dtype == torch.float16
|
|
assert k.dtype == v.dtype and k.dtype in [torch.float16, torch.int8]
|
|
|
|
BLOCK = 64
|
|
# shape constraints
|
|
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
|
assert Lq == Lk and Lk == Lv
|
|
assert Lk in {16, 32, 64, 128}
|
|
o = torch.empty_like(q)
|
|
num_warps = 4 if Lk <= 64 else 8
|
|
batch, head_size, m_size, dhead = q.size()
|
|
grid = (triton.cdiv(m_size, BLOCK), head_size * batch)
|
|
n_size = k.size(2)
|
|
_fwd_kernel[grid](
|
|
q,
|
|
k,
|
|
v,
|
|
sm_scale,
|
|
o,
|
|
q.stride(0),
|
|
q.stride(1),
|
|
q.stride(2),
|
|
q.stride(3),
|
|
k.stride(0),
|
|
k.stride(1),
|
|
k.stride(2),
|
|
k.stride(3),
|
|
v.stride(0),
|
|
v.stride(1),
|
|
v.stride(2),
|
|
v.stride(3),
|
|
o.stride(0),
|
|
o.stride(1),
|
|
o.stride(2),
|
|
o.stride(3),
|
|
head_size,
|
|
m_size,
|
|
n_size,
|
|
start_position=start_position,
|
|
IS_CAUSAL=is_causal,
|
|
BLOCK_M=BLOCK,
|
|
BLOCK_N=BLOCK,
|
|
BLOCK_DMODEL=Lk,
|
|
USE_FP8=k.dtype == torch.int8, # USE_FP8
|
|
num_warps=num_warps,
|
|
num_stages=2,
|
|
)
|
|
|
|
return o
|
|
|
|
|
|
def attention_triton_wrapper(q, k, v, head_dim):
|
|
return triton_fa(q, k, v, sm_scale=1 / math.sqrt(head_dim), is_causal=True, start_position=0)
|