WIP Triton+QKV merge
parent
b4d822afb2
commit
fa464dfc99
@ -0,0 +1,3 @@
|
||||
from petals.triton.rmsnorm import rmsnorm_triton_wrapper
|
||||
from petals.triton.attention import attention_triton_wrapper
|
||||
from petals.triton.rotary import rbe_triton_wrapper
|
@ -0,0 +1,178 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
Out,
|
||||
stride_qz,
|
||||
stride_qh,
|
||||
stride_qm,
|
||||
stride_qk,
|
||||
stride_kz,
|
||||
stride_kh,
|
||||
stride_kn,
|
||||
stride_kk,
|
||||
stride_vz,
|
||||
stride_vh,
|
||||
stride_vk,
|
||||
stride_vn,
|
||||
stride_oz,
|
||||
stride_oh,
|
||||
stride_om,
|
||||
stride_on,
|
||||
N_HEAD,
|
||||
H,
|
||||
N_CTX,
|
||||
start_position, # <- ADDED
|
||||
IS_CAUSAL: tl.constexpr, # <- ADDED
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
USE_FP8: tl.constexpr,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
|
||||
head_idx = tl.program_id(1)
|
||||
batch_id = head_idx // N_HEAD
|
||||
off_hz = head_idx % N_HEAD
|
||||
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
off_q = (
|
||||
batch_id * stride_qz + off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
) # <- stride fixed
|
||||
off_k = (
|
||||
batch_id * stride_kz + off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
|
||||
) # <- stride fixed
|
||||
off_v = (
|
||||
batch_id * stride_vz + off_hz * stride_vh + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn
|
||||
) # <- stride fixed
|
||||
# Initialize pointers to Q, K, V
|
||||
q_ptrs = Q + off_q
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
# initialize pointer to m and l
|
||||
m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load(q_ptrs, offs_m[:, None] < H, other=0.0)
|
||||
# loop over k, v and update accumulator
|
||||
block_n_end = N_CTX # <- ADDED (including the IF)
|
||||
if IS_CAUSAL:
|
||||
# in causal mode, we expect that BLOCK_M_SIZE == BLOCK_N_SIZE
|
||||
# autotune will prune shapes not matching this rule
|
||||
block_n_end = (start_m + 1) * BLOCK_N + start_position
|
||||
for start_n in range(0, block_n_end, BLOCK_N):
|
||||
block_n_offs = start_n + offs_n # <- ADDED
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs, block_n_offs[:, None] < N_CTX, 0.0)
|
||||
if USE_FP8:
|
||||
k = k.to(tl.float8e5, bitcast=True)
|
||||
k = k.to(tl.float16)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
qk = tl.where(offs_n[None, :] < N_CTX, qk, float("-inf")) # <- ADDED
|
||||
qk *= sm_scale
|
||||
if IS_CAUSAL: # <- ADDED
|
||||
qk = tl.where(offs_m[:, None] >= (block_n_offs[None, :] + start_position), qk, float("-inf"))
|
||||
|
||||
# compute new m
|
||||
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
|
||||
# correct old l
|
||||
l_prev *= tl.exp(m_prev - m_curr)
|
||||
# attention weights
|
||||
p = tl.exp(qk - m_curr[:, None])
|
||||
l_curr = tl.sum(p, 1) + l_prev
|
||||
# rescale operands of matmuls
|
||||
l_rcp = 1.0 / l_curr
|
||||
p *= l_rcp[:, None]
|
||||
acc *= (l_prev * l_rcp)[:, None]
|
||||
# update acc
|
||||
p = p.to(Q.dtype.element_ty)
|
||||
v = tl.load(v_ptrs, block_n_offs[:, None] < N_CTX, 0.0)
|
||||
if USE_FP8:
|
||||
v = v.to(tl.float8e5, bitcast=True)
|
||||
v = v.to(tl.float16)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_prev = l_curr
|
||||
m_prev = m_curr
|
||||
# update pointers
|
||||
k_ptrs += BLOCK_N * stride_kn
|
||||
v_ptrs += BLOCK_N * stride_vk
|
||||
# rematerialize offsets to save registers
|
||||
start_m = tl.program_id(0)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
|
||||
# initialize pointers to output
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
off_o = batch_id * stride_oz + off_hz * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs, acc, offs_m[:, None] < H)
|
||||
|
||||
|
||||
def triton_fa(q, k, v, sm_scale, is_causal, start_position):
|
||||
assert q.dtype == torch.float16
|
||||
assert k.dtype == v.dtype and k.dtype in [torch.float16, torch.int8]
|
||||
|
||||
BLOCK = 64
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
batch, head_size, m_size, dhead = q.size()
|
||||
grid = (triton.cdiv(m_size, BLOCK), head_size * batch)
|
||||
n_size = k.size(2)
|
||||
_fwd_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
sm_scale,
|
||||
o,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
q.stride(3),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
k.stride(3),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
v.stride(3),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
o.stride(3),
|
||||
head_size,
|
||||
m_size,
|
||||
n_size,
|
||||
start_position=start_position,
|
||||
IS_CAUSAL=is_causal,
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
USE_FP8=k.dtype == torch.int8, # USE_FP8
|
||||
num_warps=num_warps,
|
||||
num_stages=2,
|
||||
)
|
||||
|
||||
return o
|
||||
|
||||
|
||||
def attention_triton_wrapper(q, k, v, head_dim):
|
||||
return triton_fa(q, k, v, sm_scale=1 / math.sqrt(head_dim), is_causal=True, start_position=0)
|
@ -0,0 +1,49 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
|
||||
@triton.jit
|
||||
def rmsnorm_triton(x_ptr, rms_w_ptr, output_ptr,
|
||||
stride_x_batch, stride_x_m, stride_x_k,
|
||||
stride_rms_w,
|
||||
stride_out_batch, stride_out_m, stride_out_k,
|
||||
N_SIZE: tl.constexpr, eps: tl.constexpr, BLOCK_N_SIZE: tl.constexpr):
|
||||
pid_batch = tl.program_id(0)
|
||||
pid_m = tl.program_id(1)
|
||||
|
||||
offs_m = pid_batch * stride_x_batch + pid_m * stride_x_m
|
||||
block_N = tl.arange(0, BLOCK_N_SIZE)
|
||||
var = tl.zeros((BLOCK_N_SIZE,), tl.float32)
|
||||
for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):
|
||||
offs_n = block_n_start_idx + block_N
|
||||
x_ptr_mask = offs_n < N_SIZE
|
||||
x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0)
|
||||
var += tl.math.pow(x.to(tl.float32), 2)
|
||||
|
||||
var = tl.sum(var, axis=0) / N_SIZE
|
||||
rstd = tl.math.rsqrt(var + eps)
|
||||
|
||||
# multiply by weight and add bias
|
||||
for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):
|
||||
offs_n = block_n_start_idx + block_N
|
||||
x_ptr_mask = offs_n < N_SIZE
|
||||
rms_w = tl.load(rms_w_ptr + offs_n * stride_rms_w, mask=x_ptr_mask)
|
||||
|
||||
x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0).to(tl.float32)
|
||||
x_hat = x * rstd
|
||||
out = x_hat * rms_w
|
||||
out_off = pid_batch * stride_out_batch + pid_m * stride_out_m + offs_n * stride_out_k
|
||||
tl.store(output_ptr + out_off, out, mask=x_ptr_mask)
|
||||
|
||||
|
||||
def rmsnorm_triton_wrapper(x, rms_w, eps=1e-6):
|
||||
batch_size, seq_length, hid_dim = x.shape
|
||||
assert rms_w.shape[-1] == hid_dim
|
||||
out = torch.empty_like(x)
|
||||
rmsnorm_triton[(batch_size, seq_length,)](x, rms_w, out,
|
||||
*x.stride(),
|
||||
*rms_w.stride(),
|
||||
*out.stride(),
|
||||
N_SIZE=hid_dim, eps=eps, BLOCK_N_SIZE=1024,
|
||||
)
|
||||
return out
|
@ -0,0 +1,81 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
|
||||
|
||||
@triton.jit
|
||||
def get_freq_multi_tokens(offs_cn, starting_idx, theta: tl.constexpr, NB_TOKENS: tl.constexpr):
|
||||
DIM: tl.constexpr = 128 # in model, dim = self.params.dim // self.params.n_heads
|
||||
freqs = offs_cn % DIM
|
||||
freqs = freqs.to(tl.float32) / DIM
|
||||
freqs = tl.math.pow(theta, freqs)
|
||||
freqs = (tl.arange(0, NB_TOKENS) + starting_idx)[:, None] / freqs[None, :]
|
||||
return tl.cos(freqs), tl.sin(freqs)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def rbe_triton(
|
||||
x_ptr,
|
||||
out_ptr,
|
||||
M,
|
||||
K,
|
||||
stride_x_batch,
|
||||
stride_x_m,
|
||||
stride_x_n,
|
||||
stride_out_batch,
|
||||
stride_out_m,
|
||||
stride_out_n,
|
||||
start_token_position,
|
||||
THETA: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid_batch = tl.program_id(axis=0)
|
||||
pid = tl.program_id(axis=1)
|
||||
pid_m = pid // tl.cdiv(K, BLOCK_SIZE_K)
|
||||
pid_n = pid % tl.cdiv(K, BLOCK_SIZE_K)
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K // 2) * 2 # take only even numbers
|
||||
x_ptrs = x_ptr + (pid_batch * stride_x_batch + stride_x_m * offs_m[:, None] + stride_x_n * offs_n[None, :])
|
||||
x_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)
|
||||
real = tl.load(x_ptrs, mask=x_real_mask, other=0.0)
|
||||
x_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)
|
||||
imag = tl.load(x_ptrs + 1, mask=x_imag_mask, other=0.0)
|
||||
tl.debug_barrier()
|
||||
start_block = start_token_position + pid_m * BLOCK_SIZE_M
|
||||
cos, sin = get_freq_multi_tokens(offs_cn=offs_n, starting_idx=start_block, theta=THETA, NB_TOKENS=BLOCK_SIZE_M)
|
||||
|
||||
out_real = real * cos - imag * sin
|
||||
out_imag = real * sin + imag * cos
|
||||
tl.debug_barrier()
|
||||
out_ptrs = out_ptr + (
|
||||
pid_batch * stride_out_batch + stride_out_m * offs_m[:, None] + stride_out_n * offs_n[None, :]
|
||||
)
|
||||
out_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)
|
||||
tl.store(out_ptrs, out_real, mask=out_real_mask)
|
||||
out_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)
|
||||
tl.store(out_ptrs + 1, out_imag, mask=out_imag_mask)
|
||||
|
||||
|
||||
def rbe_triton_wrapper(x: torch.Tensor, pos: int) -> torch.Tensor:
|
||||
batch, M, K = x.shape
|
||||
out = torch.empty_like(x)
|
||||
grid = lambda META: (
|
||||
batch,
|
||||
triton.cdiv(META["M"], META["BLOCK_SIZE_M"]) * triton.cdiv(META["K"], META["BLOCK_SIZE_K"]),
|
||||
)
|
||||
|
||||
rbe_triton[grid](
|
||||
x,
|
||||
out,
|
||||
M,
|
||||
K,
|
||||
*x.stride(),
|
||||
*out.stride(),
|
||||
start_token_position=pos,
|
||||
THETA=10000.0,
|
||||
BLOCK_SIZE_M=2,
|
||||
BLOCK_SIZE_K=1024
|
||||
)
|
||||
return out
|
Loading…
Reference in New Issue