petals/src/petals/triton/attention.py

import math
import torch
import triton
import triton.language as tl


@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    Out,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vk,
    stride_vn,
    stride_oz,
    stride_oh,
    stride_om,
    stride_on,
    N_HEAD,  # number of attention heads
    H,  # query sequence length (m_size)
    N_CTX,  # key/value sequence length (n_size)
    start_position,  # <- ADDED: offset added to query positions for causal masking
    IS_CAUSAL: tl.constexpr,  # <- ADDED
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    USE_FP8: tl.constexpr,
):
    start_m = tl.program_id(0)
    head_idx = tl.program_id(1)
    batch_id = head_idx // N_HEAD
    off_hz = head_idx % N_HEAD
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = (
        batch_id * stride_qz + off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    )  # <- stride fixed
    off_k = (
        batch_id * stride_kz + off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
    )  # <- stride fixed
    off_v = (
        batch_id * stride_vz + off_hz * stride_vh + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn
    )  # <- stride fixed
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l
    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs, offs_m[:, None] < H, other=0.0)
    # loop over k, v and update accumulator
    block_n_end = N_CTX  # <- ADDED (including the IF)
    if IS_CAUSAL:
        # in causal mode, we expect that BLOCK_M_SIZE == BLOCK_N_SIZE
        # autotune will prune shapes not matching this rule
        block_n_end = (start_m + 1) * BLOCK_N + start_position
    for start_n in range(0, block_n_end, BLOCK_N):
        block_n_offs = start_n + offs_n  # <- ADDED
        # -- compute qk ----
        k = tl.load(k_ptrs, block_n_offs[:, None] < N_CTX, 0.0)
        if USE_FP8:
            k = k.to(tl.float8e5, bitcast=True)
            k = k.to(tl.float16)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        # mask out-of-range key columns by absolute position (they were loaded as 0)
        qk = tl.where(block_n_offs[None, :] < N_CTX, qk, float("-inf"))  # <- ADDED
        qk *= sm_scale
        if IS_CAUSAL:  # <- ADDED
            # query row m sits at absolute position start_position + m and may attend
            # to keys at absolute positions 0 .. start_position + m
            qk = tl.where((offs_m[:, None] + start_position) >= block_n_offs[None, :], qk, float("-inf"))
        # compute new m
        m_curr = tl.maximum(tl.max(qk, 1), m_prev)
        # correct old l
        l_prev *= tl.exp(m_prev - m_curr)
        # attention weights
        p = tl.exp(qk - m_curr[:, None])
        l_curr = tl.sum(p, 1) + l_prev
        # rescale operands of matmuls
        l_rcp = 1.0 / l_curr
        p *= l_rcp[:, None]
        acc *= (l_prev * l_rcp)[:, None]
        # update acc
        p = p.to(Q.dtype.element_ty)
        v = tl.load(v_ptrs, block_n_offs[:, None] < N_CTX, 0.0)
        if USE_FP8:
            v = v.to(tl.float8e5, bitcast=True)
            v = v.to(tl.float16)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_prev = l_curr
        m_prev = m_curr
        # update pointers
        k_ptrs += BLOCK_N * stride_kn
        v_ptrs += BLOCK_N * stride_vk
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # initialize pointers to output
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_o = batch_id * stride_oz + off_hz * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, offs_m[:, None] < H)
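

# Plain PyTorch reference of what `_fwd_kernel` computes, handy as a sanity check when
# modifying the kernel. Illustrative sketch only: it assumes fp16 q/k/v of shape
# (batch, n_heads, seq_len, head_dim) and skips the FP8 path.
def _attention_reference(q, k, v, sm_scale, is_causal, start_position):
    scores = torch.matmul(q.float(), k.float().transpose(-1, -2)) * sm_scale
    if is_causal:
        # query row m sits at absolute position start_position + m and may attend to
        # keys at absolute positions 0 .. start_position + m
        q_pos = torch.arange(q.shape[-2], device=q.device)[:, None] + start_position
        k_pos = torch.arange(k.shape[-2], device=k.device)[None, :]
        scores = scores.masked_fill(q_pos < k_pos, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v.float()).to(q.dtype)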


def triton_fa(q, k, v, sm_scale, is_causal, start_position):
    # q, k, v: (batch, n_heads, seq_len, head_dim)
    assert q.dtype == torch.float16
    assert k.dtype == v.dtype and k.dtype in [torch.float16, torch.int8]
    BLOCK = 64
    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}
    o = torch.empty_like(q)
    num_warps = 4 if Lk <= 64 else 8
    batch, head_size, m_size, dhead = q.size()
    grid = (triton.cdiv(m_size, BLOCK), head_size * batch)
    n_size = k.size(2)
    _fwd_kernel[grid](
        q,
        k,
        v,
        sm_scale,
        o,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        q.stride(3),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        k.stride(3),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        v.stride(3),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        o.stride(3),
        head_size,
        m_size,
        n_size,
        start_position=start_position,
        IS_CAUSAL=is_causal,
        BLOCK_M=BLOCK,
        BLOCK_N=BLOCK,
        BLOCK_DMODEL=Lk,
        USE_FP8=k.dtype == torch.int8,  # USE_FP8
        num_warps=num_warps,
        num_stages=2,
    )
    return o
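

# The USE_FP8 path expects k/v passed as int8 tensors whose bytes hold float8 e5m2
# encodings; the kernel bitcasts them back to float8e5 and upcasts to float16 before the
# matmuls. Helper sketch for producing such tensors (assumes a PyTorch build that provides
# torch.float8_e5m2, i.e. 2.1+; the helper name is illustrative, not from the original code):
def _to_fp8_storage(t):
    # encode fp16 values as float8 e5m2, then reinterpret the bytes as int8 for the kernel
    return t.to(torch.float8_e5m2).view(torch.int8)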


def attention_triton_wrapper(q, k, v, head_dim):
    return triton_fa(q, k, v, sm_scale=1 / math.sqrt(head_dim), is_causal=True, start_position=0)
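

# Example usage (illustrative; shapes are arbitrary). The kernel is GPU-only and, per the
# asserts above, expects fp16 inputs of shape (batch, n_heads, seq_len, head_dim) with
# head_dim in {16, 32, 64, 128}.
if __name__ == "__main__":
    batch, n_heads, seq_len, head_dim = 1, 8, 128, 64
    q = torch.randn(batch, n_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = attention_triton_wrapper(q, k, v, head_dim)
    print(out.shape)  # torch.Size([1, 8, 128, 64])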