You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
49 lines
1.9 KiB
Python
49 lines
1.9 KiB
Python
import triton
|
|
import triton.language as tl
|
|
import torch
|
|
|
|
@triton.jit
def rmsnorm_triton(x_ptr, rms_w_ptr, output_ptr,
                   stride_x_batch, stride_x_m, stride_x_k,
                   stride_rms_w,
                   stride_out_batch, stride_out_m, stride_out_k,
                   N_SIZE: tl.constexpr, eps: tl.constexpr, BLOCK_N_SIZE: tl.constexpr):
    """RMS-normalize one row of the input along its last axis and scale it.

    Each program instance handles the row selected by
    ``(tl.program_id(0), tl.program_id(1))`` and, for that row, writes
    ``x * rsqrt(mean(x ** 2) + eps) * rms_w`` to the output. No bias term is
    applied.

    Args:
        x_ptr: Pointer to the input, addressed with three strides.
        rms_w_ptr: Pointer to the per-feature weight vector.
        output_ptr: Pointer to the output, addressed with three strides.
        stride_x_batch, stride_x_m, stride_x_k: Element strides of the input
            for the batch, row and feature axes.
        stride_rms_w: Element stride of the weight vector.
        stride_out_batch, stride_out_m, stride_out_k: Element strides of the
            output for the batch, row and feature axes.
        N_SIZE: Number of features per row (compile-time constant).
        eps: Value added to the mean square before the reciprocal square
            root (compile-time constant).
        BLOCK_N_SIZE: Number of features processed per loop iteration
            (compile-time constant; used as the extent of ``tl.arange``).
    """
    pid_batch = tl.program_id(0)
    pid_m = tl.program_id(1)

    # Row base addresses are loop-invariant, so compute them once.
    x_row_ptr = x_ptr + pid_batch * stride_x_batch + pid_m * stride_x_m
    out_row_ptr = output_ptr + pid_batch * stride_out_batch + pid_m * stride_out_m
    block_N = tl.arange(0, BLOCK_N_SIZE)

    # Pass 1: accumulate the squares of the row in float32, one block at a time.
    sq_acc = tl.zeros((BLOCK_N_SIZE,), tl.float32)
    for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):
        offs_n = block_n_start_idx + block_N
        x_ptr_mask = offs_n < N_SIZE
        # Masked-out lanes load 0.0 so they contribute nothing to the sum.
        x = tl.load(x_row_ptr + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0).to(tl.float32)
        # Plain multiply instead of tl.math.pow: same value, and it does not
        # depend on a libdevice `pow` binding being exposed under tl.math.
        sq_acc += x * x

    mean_sq = tl.sum(sq_acc, axis=0) / N_SIZE
    rstd = tl.math.rsqrt(mean_sq + eps)

    # Pass 2: normalize and multiply by the weight.
    for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):
        offs_n = block_n_start_idx + block_N
        x_ptr_mask = offs_n < N_SIZE
        # Masked-out weight lanes are never stored, so no `other` is needed.
        rms_w = tl.load(rms_w_ptr + offs_n * stride_rms_w, mask=x_ptr_mask)

        x = tl.load(x_row_ptr + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0).to(tl.float32)
        x_hat = x * rstd
        out = x_hat * rms_w
        tl.store(out_row_ptr + offs_n * stride_out_k, out, mask=x_ptr_mask)
|
|
|
|
|
|
def rmsnorm_triton_wrapper(x: torch.Tensor, rms_w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Apply RMS normalization over the last axis of a 3-D tensor.

    Launches ``rmsnorm_triton`` with one program per ``(batch, sequence)``
    position and returns a new tensor allocated with ``torch.empty_like(x)``.
    No device or dtype checks are performed here.

    Args:
        x: Tensor with exactly three dimensions; the last one is normalized.
        rms_w: Weight holding exactly ``x.shape[-1]`` elements, laid out
            along its last axis (e.g. shape ``(hid,)`` or ``(1, hid)``).
        eps: Value added to the mean square before the reciprocal square
            root.

    Returns:
        The normalized and scaled tensor.

    Raises:
        ValueError: If ``x`` does not have exactly three dimensions (raised
            by the shape unpacking).
        AssertionError: If ``rms_w`` does not hold exactly ``x.shape[-1]``
            elements along its last axis.
    """
    batch_size, seq_length, hid_dim = x.shape

    # Raised explicitly rather than via `assert` so the check survives
    # `python -O`; a weight shorter than hid_dim would otherwise let the
    # kernel read past the end of it.
    if rms_w.shape[-1] != hid_dim or rms_w.numel() != hid_dim:
        raise AssertionError(
            f"rms_w must hold exactly {hid_dim} elements along its last axis, "
            f"got shape {tuple(rms_w.shape)}"
        )

    out = torch.empty_like(x)
    grid = (batch_size, seq_length)
    rmsnorm_triton[grid](
        x, rms_w, out,
        *x.stride(),
        # The kernel takes a single weight stride; unpacking every stride of
        # a multi-dimensional weight would shift the arguments after it.
        rms_w.stride(-1),
        *out.stride(),
        N_SIZE=hid_dim, eps=eps, BLOCK_N_SIZE=1024,
    )
    return out