# petals/src/petals/triton/rmsnorm.py
# (49 lines, 1.9 KiB, Python — header retained from the original page capture)
import triton
import triton.language as tl
import torch
@triton.jit
def rmsnorm_triton(x_ptr, rms_w_ptr, output_ptr,
                   stride_x_batch, stride_x_m, stride_x_k,
                   stride_rms_w,
                   stride_out_batch, stride_out_m, stride_out_k,
                   N_SIZE: tl.constexpr, eps: tl.constexpr, BLOCK_N_SIZE: tl.constexpr):
    """RMSNorm kernel: one program instance normalizes one (batch, row) slice
    of length N_SIZE along the last axis, scaling by the rms_w weight vector.

    Grid is (batch, rows); the hidden dimension is walked in BLOCK_N_SIZE
    chunks, so N_SIZE may exceed the block size.
    """
    batch_id = tl.program_id(0)
    row_id = tl.program_id(1)
    row_offset = batch_id * stride_x_batch + row_id * stride_x_m
    col_block = tl.arange(0, BLOCK_N_SIZE)

    # Pass 1: accumulate per-lane sums of squares over the hidden dimension.
    # Masked lanes load 0.0, so they contribute nothing to the accumulator.
    sq_acc = tl.zeros((BLOCK_N_SIZE,), tl.float32)
    for block_start in range(0, N_SIZE, BLOCK_N_SIZE):
        cols = block_start + col_block
        in_bounds = cols < N_SIZE
        vals = tl.load(x_ptr + row_offset + cols * stride_x_k, mask=in_bounds, other=0.0)
        sq_acc += tl.math.pow(vals.to(tl.float32), 2)

    # Reduce to the mean square, then take the reciprocal root (in fp32).
    mean_sq = tl.sum(sq_acc, axis=0) / N_SIZE
    rstd = tl.math.rsqrt(mean_sq + eps)

    # Pass 2: reload the row, normalize, scale by the weight, and store.
    for block_start in range(0, N_SIZE, BLOCK_N_SIZE):
        cols = block_start + col_block
        in_bounds = cols < N_SIZE
        w = tl.load(rms_w_ptr + cols * stride_rms_w, mask=in_bounds)
        vals = tl.load(x_ptr + row_offset + cols * stride_x_k, mask=in_bounds, other=0.0).to(tl.float32)
        normalized = vals * rstd
        scaled = normalized * w
        out_offset = batch_id * stride_out_batch + row_id * stride_out_m + cols * stride_out_k
        tl.store(output_ptr + out_offset, scaled, mask=in_bounds)
def rmsnorm_triton_wrapper(x, rms_w, eps=1e-6):
    """Apply RMS normalization over the last dimension of ``x`` via Triton.

    Args:
        x: input tensor of shape ``(batch, seq_len, hidden_dim)``.
        rms_w: learned scale weights; last dimension must equal ``hidden_dim``.
        eps: small constant added to the mean square for numerical stability.

    Returns:
        A new tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if ``rms_w``'s last dimension does not match ``x``'s
            hidden dimension.
    """
    batch_size, seq_length, hid_dim = x.shape
    # Explicit raise instead of assert: asserts are stripped under `python -O`,
    # which would turn a shape mismatch into silent memory corruption.
    if rms_w.shape[-1] != hid_dim:
        raise ValueError(
            f"rms_w last dim ({rms_w.shape[-1]}) must match x hidden dim ({hid_dim})"
        )
    out = torch.empty_like(x)
    # One kernel program per (batch, seq) row; the kernel itself loops over
    # the hidden dimension in blocks of BLOCK_N_SIZE.
    rmsnorm_triton[(batch_size, seq_length,)](
        x, rms_w, out,
        *x.stride(),
        *rms_w.stride(),
        *out.stride(),
        N_SIZE=hid_dim, eps=eps, BLOCK_N_SIZE=1024,
    )
    return out