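"""Rotary positional embeddings (RoPE) as a Triton kernel.

Rotates interleaved (real, imag) pairs along the last dimension of a
(batch, seq_len, dim) tensor by position-dependent angles, using the
frequency schedule theta ** (-(n % 128) / 128) over even column offsets n.
"""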
import torch

import triton
import triton.language as tl


@triton.jit
def get_freq_multi_tokens(offs_cn, starting_idx, theta: tl.constexpr, NB_TOKENS: tl.constexpr):
    # Rotation angles for NB_TOKENS consecutive positions starting at
    # starting_idx; offs_cn holds the even (real-part) column offsets.
    DIM: tl.constexpr = 128  # in model, dim = self.params.dim // self.params.n_heads
    freqs = offs_cn % DIM
    freqs = freqs.to(tl.float32) / DIM
    freqs = tl.math.pow(theta, freqs)
    freqs = (tl.arange(0, NB_TOKENS) + starting_idx)[:, None] / freqs[None, :]
    return tl.cos(freqs), tl.sin(freqs)
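

# A plain-PyTorch mirror of get_freq_multi_tokens, added here as an
# illustrative sketch (it is not part of the kernel and is only used by the
# sanity check at the bottom of the file). It assumes the hard-coded head
# dimension of 128 above and reproduces the same angle table on the host.
def reference_freqs(offs_cn: torch.Tensor, starting_idx: int, theta: float, nb_tokens: int):
    dim = 128  # must match DIM in get_freq_multi_tokens
    freqs = (offs_cn % dim).to(torch.float32) / dim
    freqs = theta ** freqs
    angles = (torch.arange(nb_tokens, dtype=torch.float32) + starting_idx)[:, None] / freqs[None, :]
    return torch.cos(angles), torch.sin(angles)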


@triton.jit
def rbe_triton(
    x_ptr,
    out_ptr,
    M,
    K,
    stride_x_batch,
    stride_x_m,
    stride_x_n,
    stride_out_batch,
    stride_out_m,
    stride_out_n,
    start_token_position,
    THETA: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Each program rotates one (BLOCK_SIZE_M, BLOCK_SIZE_K) tile of one batch
    # element: axis 0 indexes the batch, axis 1 is flattened over (M, K) tiles.
    pid_batch = tl.program_id(axis=0)
    pid = tl.program_id(axis=1)
    pid_m = pid // tl.cdiv(K, BLOCK_SIZE_K)
    pid_n = pid % tl.cdiv(K, BLOCK_SIZE_K)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K // 2) * 2  # take only even numbers
    x_ptrs = x_ptr + (pid_batch * stride_x_batch + stride_x_m * offs_m[:, None] + stride_x_n * offs_n[None, :])
    # Interleaved layout: even columns hold the real part of each pair,
    # the adjacent odd columns the imaginary part.
    x_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)
    real = tl.load(x_ptrs, mask=x_real_mask, other=0.0)
    x_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)
    imag = tl.load(x_ptrs + 1, mask=x_imag_mask, other=0.0)
    tl.debug_barrier()
    start_block = start_token_position + pid_m * BLOCK_SIZE_M
    cos, sin = get_freq_multi_tokens(offs_cn=offs_n, starting_idx=start_block, theta=THETA, NB_TOKENS=BLOCK_SIZE_M)

    # Complex multiplication (real + i*imag) * (cos + i*sin).
    out_real = real * cos - imag * sin
    out_imag = real * sin + imag * cos
    tl.debug_barrier()
    out_ptrs = out_ptr + (
        pid_batch * stride_out_batch + stride_out_m * offs_m[:, None] + stride_out_n * offs_n[None, :]
    )
    out_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)
    tl.store(out_ptrs, out_real, mask=out_real_mask)
    out_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)
    tl.store(out_ptrs + 1, out_imag, mask=out_imag_mask)

def rbe_triton_wrapper(x: torch.Tensor, pos: int) -> torch.Tensor:
    batch, M, K = x.shape
    out = torch.empty_like(x)
    # M and K are captured from the closure rather than read from META: they
    # are passed to the kernel positionally, so they are not guaranteed to
    # appear in the meta-parameter dict on every Triton version.
    grid = lambda META: (
        batch,
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(K, META["BLOCK_SIZE_K"]),
    )

    rbe_triton[grid](
        x,
        out,
        M,
        K,
        *x.stride(),
        *out.stride(),
        start_token_position=pos,
        THETA=10000.0,
        BLOCK_SIZE_M=2,
        BLOCK_SIZE_K=1024,
    )
    return out
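

# Minimal sanity check, added as an illustrative sketch (not a test from the
# original source): it compares the Triton kernel against a naive PyTorch
# rotation built from reference_freqs above, assuming a CUDA device and the
# interleaved (even=real, odd=imag) layout described in the docstring.
if __name__ == "__main__" and torch.cuda.is_available():
    torch.manual_seed(0)
    batch, seq_len, dim = 2, 16, 128
    pos = 3
    x = torch.randn(batch, seq_len, dim, device="cuda")

    # Naive reference: rotate each (even, odd) pair by its per-position angle.
    offs_n = torch.arange(0, dim, 2)  # even (real-part) column offsets
    cos, sin = reference_freqs(offs_n, starting_idx=pos, theta=10000.0, nb_tokens=seq_len)
    cos, sin = cos.cuda(), sin.cuda()
    real, imag = x[..., 0::2], x[..., 1::2]
    expected = torch.empty_like(x)
    expected[..., 0::2] = real * cos - imag * sin
    expected[..., 1::2] = real * sin + imag * cos

    out = rbe_triton_wrapper(x, pos)
    assert torch.allclose(out, expected, atol=1e-3), (out - expected).abs().max().item()
    print("rbe_triton_wrapper matches the naive PyTorch reference")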