petals/src/petals/triton/rotary.py


import triton
import triton.language as tl
import torch


@triton.jit
def get_freq_multi_tokens(offs_cn, starting_idx, theta: tl.constexpr, NB_TOKENS: tl.constexpr):
    """Compute cos/sin rotary angles for NB_TOKENS consecutive positions starting at starting_idx."""
    DIM: tl.constexpr = 128  # in model, dim = self.params.dim // self.params.n_heads
    freqs = offs_cn % DIM  # map hidden-dim column offsets to per-head offsets
    freqs = freqs.to(tl.float32) / DIM
    freqs = tl.math.pow(theta, freqs)
    # angle[t, d] = (t + starting_idx) / theta^(d / DIM), the standard RoPE angle per position/pair
    freqs = (tl.arange(0, NB_TOKENS) + starting_idx)[:, None] / freqs[None, :]
    return tl.cos(freqs), tl.sin(freqs)


@triton.jit
def rbe_triton(
    x_ptr,
    out_ptr,
    M,
    K,
    stride_x_batch,
    stride_x_m,
    stride_x_n,
    stride_out_batch,
    stride_out_m,
    stride_out_n,
    start_token_position,
    THETA: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Apply rotary position embeddings to a (batch, M, K) tensor, one (BLOCK_SIZE_M, BLOCK_SIZE_K) tile per program."""
    pid_batch = tl.program_id(axis=0)
    pid = tl.program_id(axis=1)
    pid_m = pid // tl.cdiv(K, BLOCK_SIZE_K)
    pid_n = pid % tl.cdiv(K, BLOCK_SIZE_K)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K // 2) * 2  # take only even column indices
    # Even columns hold the "real" half of each pair; the adjacent odd columns (pointer + 1) hold the
    # "imaginary" half, which assumes the last dimension has stride 1 (contiguous input).
    x_ptrs = x_ptr + (pid_batch * stride_x_batch + stride_x_m * offs_m[:, None] + stride_x_n * offs_n[None, :])
    x_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)
    real = tl.load(x_ptrs, mask=x_real_mask, other=0.0)
    x_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)
    imag = tl.load(x_ptrs + 1, mask=x_imag_mask, other=0.0)
    tl.debug_barrier()
    start_block = start_token_position + pid_m * BLOCK_SIZE_M
    cos, sin = get_freq_multi_tokens(offs_cn=offs_n, starting_idx=start_block, theta=THETA, NB_TOKENS=BLOCK_SIZE_M)
    # Complex rotation: (real + i*imag) * (cos + i*sin)
    out_real = real * cos - imag * sin
    out_imag = real * sin + imag * cos
    tl.debug_barrier()
    out_ptrs = out_ptr + (
        pid_batch * stride_out_batch + stride_out_m * offs_m[:, None] + stride_out_n * offs_n[None, :]
    )
    out_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)
    tl.store(out_ptrs, out_real, mask=out_real_mask)
    out_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)
    tl.store(out_ptrs + 1, out_imag, mask=out_imag_mask)


def rbe_triton_wrapper(x: torch.Tensor, pos: int) -> torch.Tensor:
    """Rotary embedding entry point: x has shape (batch, seq_len, hidden), pos is the position of the first token."""
    batch, M, K = x.shape
    out = torch.empty_like(x)
    grid = lambda META: (
        batch,
        triton.cdiv(META["M"], META["BLOCK_SIZE_M"]) * triton.cdiv(META["K"], META["BLOCK_SIZE_K"]),
    )
    rbe_triton[grid](
        x,
        out,
        M,
        K,
        *x.stride(),
        *out.stride(),
        start_token_position=pos,
        THETA=10000.0,
        BLOCK_SIZE_M=2,
        BLOCK_SIZE_K=1024,
    )
    return out
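

if __name__ == "__main__":
    # Minimal usage sketch (illustrative assumptions, not values fixed by this file): the kernel
    # hard-codes a 128-dim head and expects a contiguous (batch, seq_len, hidden) tensor on a CUDA
    # device, so the shape and dtype below are chosen only to satisfy those constraints.
    x = torch.randn(1, 16, 4096, dtype=torch.float16, device="cuda")
    rotated = rbe_triton_wrapper(x, pos=0)  # pos = absolute position of the first token in the sequence
    print(rotated.shape)  # torch.Size([1, 16, 4096])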