WIP Triton+QKV merge

wip_triton
Max Ryabinin 8 months ago
parent b4d822afb2
commit fa464dfc99

@ -6,10 +6,124 @@ See commit history for authorship.
from typing import Optional, Tuple
import torch
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
repeat_kv,
apply_rotary_pos_emb,
)
from petals.triton import attention_triton_wrapper, rbe_triton_wrapper, rmsnorm_triton_wrapper
class WrappedLlamaBlock(LlamaDecoderLayer):
class OptimizedLlamaRMSNorm(LlamaRMSNorm):
def forward(self, hidden_states):
if torch.is_inference_mode_enabled():
return rmsnorm_triton_wrapper(hidden_states, self.weight)
return super().forward(hidden_states)
class OptimizedLlamaAttention(LlamaAttention):
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.qkv_proj = nn.Linear(
self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False
)
self.qkv_sizes = [
self.num_heads * self.head_dim,
self.num_key_value_heads * self.head_dim,
self.num_key_value_heads * self.head_dim,
]
self.attn_norm_constant = math.sqrt(self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
assert (
self.config.pretraining_tp == 1
), "OptimizedLlamaAttention assumes that config.pretraining_tp is equal to 1"
query_states, key_states, value_states = torch.split(self.qkv_proj(hidden_states), self.qkv_sizes, dim=2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / self.attn_norm_constant
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LlamaConfig):
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
self.self_attn = OptimizedLlamaAttention(config=config)
self.mlp = LlamaMLP(config)
self.input_layernorm = OptimizedLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = OptimizedLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,

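As a side note on the fused projection above: splitting the output of qkv_proj with torch.split must reproduce the three separate q/k/v projections it replaces. A minimal sanity-check sketch (not part of the commit), using toy sizes rather than a real Llama config:

# Illustrative only: verify that a fused QKV projection matches separate q/k/v projections.
import torch
import torch.nn as nn

hidden_size, num_heads, num_kv_heads, head_dim = 64, 4, 2, 16  # toy sizes (assumed)

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)

# Row-wise concatenation of the three weight matrices, mirroring the copy loop added to convert_block below.
qkv_proj = nn.Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_dim, bias=False)
with torch.no_grad():
    qkv_proj.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

x = torch.randn(1, 3, hidden_size)
qkv_sizes = [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim]
q, k, v = torch.split(qkv_proj(x), qkv_sizes, dim=2)
assert torch.allclose(q, q_proj(x), atol=1e-5)
assert torch.allclose(k, k_proj(x), atol=1e-5)
assert torch.allclose(v, v_proj(x), atol=1e-5)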
@ -1,6 +1,8 @@
from __future__ import annotations
import argparse
import fcntl
import json
import math
import multiprocessing as mp
import os
import time
@ -8,14 +10,19 @@ from collections import Counter
from pathlib import Path
from typing import Dict, Optional, Sequence, Union
import configargparse
import torch
import torch.mps
from hivemind.utils.logging import get_logger
from transformers import PretrainedConfig
from petals.constants import DTYPE_MAP
from petals.server.block_utils import resolve_block_dtype
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
from petals.utils.version import get_compatible_model_repo
logger = get_logger(__name__)
@ -114,6 +121,7 @@ def measure_throughput_info(
*,
quant_type: QuantType,
tensor_parallel_devices: Sequence[torch.device],
measure_network: bool = True,
) -> Dict[str, float]:
logger.info(
"Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
@ -139,14 +147,16 @@ def measure_throughput_info(
n_steps=10,
inference=False,
),
"network_rps": measure_network_rps(config),
"network_rps": measure_network_rps(config, use_default=not measure_network),
}
def measure_network_rps(
config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 100e6 # 100 Mbit/s
config: PretrainedConfig, *, use_default=False, timeout: float = 60, default_speed: float = 100e6 # 100 Mbit/s
) -> Optional[float]:
bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward
if use_default:
return default_speed / bits_per_request
try:
pipe_recv, pipe_send = mp.Pipe(duplex=False)
process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
@ -207,13 +217,23 @@ def measure_compute_rps(
cache = None
elapsed = 0
dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
# with torch.profiler.profile(
# schedule=torch.profiler.schedule(wait=1, warmup=4, active=n_steps, repeat=1),
# on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/profbf16_70b_qkv'),
# record_shapes=True,
# profile_memory=True,
# with_stack=True
# ) as prof:
_, cache = block.forward(dummy_input, use_cache=True) # Skip the 1st step to exclude the initialization time
synchronize(device)
# prof.step()
start_time = time.perf_counter()
for _ in range(n_steps):
_, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
synchronize(device)
synchronize(device)
# prof.step()
elapsed = time.perf_counter() - start_time
device_rps = n_steps * n_tokens / elapsed
@ -245,3 +265,93 @@ def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
if quant_type != QuantType.NONE:
name += f", quantized to {quant_type.name.lower()}"
return name
def parse_args():
# fmt:off
parser = configargparse.ArgParser(default_config_files=["config.yml"],
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--converted_model_name_or_path', type=str, default=None,
help="path or name of a pretrained model, converted with cli/convert_model.py")
group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()")
group.add_argument("--use_auth_token", action="store_true", dest="token",
help="Read token saved by `huggingface-cli login")
parser.add_argument('--device', type=str, default=None, required=False,
help='all blocks will use this device in torch notation; default: cuda if available else cpu')
parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
parser.add_argument('--revision', type=str, default=None,
help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],
help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or "
"4-bit (nf4 from the QLoRA paper) formats to save GPU memory. "
"Default: 'int8' if GPU is available, 'none' otherwise")
parser.add_argument("--tensor_parallel_devices", nargs='+', default=None,
help=
"Split each block between the specified GPUs such that each device holds a portion of every "
"weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism")
# fmt:on
args = parser.parse_args()
args.converted_model_name_or_path = args.model or args.converted_model_name_or_path  # keep the flag value if the positional form was not used
return args
if __name__ == "__main__":
args = parse_args()
converted_model_name_or_path = get_compatible_model_repo(args.converted_model_name_or_path)
config = AutoDistributedConfig.from_pretrained(
converted_model_name_or_path,
use_auth_token=args.token,
revision=args.revision,
)
device = args.device
if device is None:
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
device = torch.device(device)
if device.type == "cuda" and device.index is None:
device = torch.device(device.type, index=0)
torch_dtype = resolve_block_dtype(config, DTYPE_MAP[args.torch_dtype])
if device.type == "cpu" and torch_dtype == torch.float16:
raise ValueError(
f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16"
)
if device.type == "mps" and torch_dtype == torch.bfloat16:
logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead")
torch_dtype = torch.float16
quant_type = args.quant_type
if quant_type is None:
if device.type == "cuda":
quant_type = QuantType.NF4 if config.model_type == "llama" else QuantType.INT8
else:
quant_type = QuantType.NONE
if args.tensor_parallel_devices is None:
args.tensor_parallel_devices = (device,)
measure_throughput_info(
config,
device,
torch_dtype,
quant_type=quant_type,
tensor_parallel_devices=args.tensor_parallel_devices,
measure_network=False,
)
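One note on the measure_network=False path used above: measure_network_rps then skips the actual speed test and derives throughput from the assumed 100 Mbit/s default. A back-of-the-envelope sketch of that arithmetic, with hidden_size = 8192 chosen purely as an example:

# Illustrative only: default network RPS when the speed test is skipped.
hidden_size = 8192                       # example model width, not read from a real config
default_speed = 100e6                    # 100 Mbit/s, same default as measure_network_rps
bits_per_request = hidden_size * 16      # clients send 16-bit tensors for forward/backward
print(default_speed / bits_per_request)  # ~763 requests per second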

@ -0,0 +1,3 @@
from petals.triton.rmsnorm import rmsnorm_triton_wrapper
from petals.triton.attention import attention_triton_wrapper
from petals.triton.rotary import rbe_triton_wrapper

@ -0,0 +1,178 @@
import math
import torch
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
N_HEAD,
H,
N_CTX,
start_position, # <- ADDED
IS_CAUSAL: tl.constexpr, # <- ADDED
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
USE_FP8: tl.constexpr,
):
start_m = tl.program_id(0)
head_idx = tl.program_id(1)
batch_id = head_idx // N_HEAD
off_hz = head_idx % N_HEAD
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = (
batch_id * stride_qz + off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
) # <- stride fixed
off_k = (
batch_id * stride_kz + off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
) # <- stride fixed
off_v = (
batch_id * stride_vz + off_hz * stride_vh + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn
) # <- stride fixed
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs, offs_m[:, None] < H, other=0.0)
# loop over k, v and update accumulator
block_n_end = N_CTX # <- ADDED (including the IF)
if IS_CAUSAL:
# in causal mode, we expect that BLOCK_M == BLOCK_N
# autotune will prune shapes not matching this rule
block_n_end = (start_m + 1) * BLOCK_N + start_position
for start_n in range(0, block_n_end, BLOCK_N):
block_n_offs = start_n + offs_n # <- ADDED
# -- compute qk ----
k = tl.load(k_ptrs, block_n_offs[:, None] < N_CTX, 0.0)
if USE_FP8:
k = k.to(tl.float8e5, bitcast=True)
k = k.to(tl.float16)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk = tl.where(block_n_offs[None, :] < N_CTX, qk, float("-inf")) # <- ADDED: mask key positions past N_CTX
qk *= sm_scale
if IS_CAUSAL: # <- ADDED
qk = tl.where(offs_m[:, None] >= (block_n_offs[None, :] + start_position), qk, float("-inf"))
# compute new m
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
# correct old l
l_prev *= tl.exp(m_prev - m_curr)
# attention weights
p = tl.exp(qk - m_curr[:, None])
l_curr = tl.sum(p, 1) + l_prev
# rescale operands of matmuls
l_rcp = 1.0 / l_curr
p *= l_rcp[:, None]
acc *= (l_prev * l_rcp)[:, None]
# update acc
p = p.to(Q.dtype.element_ty)
v = tl.load(v_ptrs, block_n_offs[:, None] < N_CTX, 0.0)
if USE_FP8:
v = v.to(tl.float8e5, bitcast=True)
v = v.to(tl.float16)
acc += tl.dot(p, v)
# update m_i and l_i
l_prev = l_curr
m_prev = m_curr
# update pointers
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# initialize pointers to output
offs_d = tl.arange(0, BLOCK_DMODEL)
off_o = batch_id * stride_oz + off_hz * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, offs_m[:, None] < H)
def triton_fa(q, k, v, sm_scale, is_causal, start_position):
assert q.dtype == torch.float16
assert k.dtype == v.dtype and k.dtype in [torch.float16, torch.int8]
BLOCK = 64
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
num_warps = 4 if Lk <= 64 else 8
batch, head_size, m_size, dhead = q.size()
grid = (triton.cdiv(m_size, BLOCK), head_size * batch)
n_size = k.size(2)
_fwd_kernel[grid](
q,
k,
v,
sm_scale,
o,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
o.stride(0),
o.stride(1),
o.stride(2),
o.stride(3),
head_size,
m_size,
n_size,
start_position=start_position,
IS_CAUSAL=is_causal,
BLOCK_M=BLOCK,
BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk,
USE_FP8=k.dtype == torch.int8, # USE_FP8
num_warps=num_warps,
num_stages=2,
)
return o
def attention_triton_wrapper(q, k, v, head_dim):
return triton_fa(q, k, v, sm_scale=1 / math.sqrt(head_dim), is_causal=True, start_position=0)
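A minimal usage sketch for the wrapper above, with illustrative shapes (batch, heads, seq_len, head_dim), fp16 inputs, and a CUDA device assumed; the comparison against PyTorch SDPA is only a rough check, not part of the commit:

# Illustrative only: call the Triton attention kernel and compare with PyTorch SDPA (CUDA required).
import torch
from petals.triton import attention_triton_wrapper

batch, n_heads, seq_len, head_dim = 1, 8, 256, 128           # assumed toy shapes
q = torch.randn(batch, n_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = attention_triton_wrapper(q, k, v, head_dim)            # causal attention, start_position=0
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print((out - ref).abs().max())                               # expected to be small (fp16 tolerance)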

@ -0,0 +1,49 @@
import triton
import triton.language as tl
import torch
@triton.jit
def rmsnorm_triton(x_ptr, rms_w_ptr, output_ptr,
stride_x_batch, stride_x_m, stride_x_k,
stride_rms_w,
stride_out_batch, stride_out_m, stride_out_k,
N_SIZE: tl.constexpr, eps: tl.constexpr, BLOCK_N_SIZE: tl.constexpr):
pid_batch = tl.program_id(0)
pid_m = tl.program_id(1)
offs_m = pid_batch * stride_x_batch + pid_m * stride_x_m
block_N = tl.arange(0, BLOCK_N_SIZE)
var = tl.zeros((BLOCK_N_SIZE,), tl.float32)
for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):
offs_n = block_n_start_idx + block_N
x_ptr_mask = offs_n < N_SIZE
x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0)
var += tl.math.pow(x.to(tl.float32), 2)
var = tl.sum(var, axis=0) / N_SIZE
rstd = tl.math.rsqrt(var + eps)
# normalize and multiply by the RMSNorm weight (there is no bias)
for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):
offs_n = block_n_start_idx + block_N
x_ptr_mask = offs_n < N_SIZE
rms_w = tl.load(rms_w_ptr + offs_n * stride_rms_w, mask=x_ptr_mask)
x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0).to(tl.float32)
x_hat = x * rstd
out = x_hat * rms_w
out_off = pid_batch * stride_out_batch + pid_m * stride_out_m + offs_n * stride_out_k
tl.store(output_ptr + out_off, out, mask=x_ptr_mask)
def rmsnorm_triton_wrapper(x, rms_w, eps=1e-6):
batch_size, seq_length, hid_dim = x.shape
assert rms_w.shape[-1] == hid_dim
out = torch.empty_like(x)
rmsnorm_triton[(batch_size, seq_length,)](x, rms_w, out,
*x.stride(),
*rms_w.stride(),
*out.stride(),
N_SIZE=hid_dim, eps=eps, BLOCK_N_SIZE=1024,
)
return out
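A minimal usage sketch for the wrapper above, comparing it against the stock LlamaRMSNorm; the hidden size is an arbitrary example and a CUDA device is assumed:

# Illustrative only: compare the Triton RMSNorm against the Hugging Face reference (CUDA required).
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from petals.triton import rmsnorm_triton_wrapper

hidden = 4096                                                # example hidden size
x = torch.randn(2, 5, hidden, device="cuda", dtype=torch.float16)
ref_norm = LlamaRMSNorm(hidden, eps=1e-6).to(device="cuda", dtype=torch.float16)

out = rmsnorm_triton_wrapper(x, ref_norm.weight, eps=1e-6)
print((out - ref_norm(x)).abs().max())                       # expected to be small (fp16 tolerance)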

@ -0,0 +1,81 @@
import triton
import triton.language as tl
import torch
@triton.jit
def get_freq_multi_tokens(offs_cn, starting_idx, theta: tl.constexpr, NB_TOKENS: tl.constexpr):
DIM: tl.constexpr = 128 # in model, dim = self.params.dim // self.params.n_heads
freqs = offs_cn % DIM
freqs = freqs.to(tl.float32) / DIM
freqs = tl.math.pow(theta, freqs)
freqs = (tl.arange(0, NB_TOKENS) + starting_idx)[:, None] / freqs[None, :]
return tl.cos(freqs), tl.sin(freqs)
@triton.jit
def rbe_triton(
x_ptr,
out_ptr,
M,
K,
stride_x_batch,
stride_x_m,
stride_x_n,
stride_out_batch,
stride_out_m,
stride_out_n,
start_token_position,
THETA: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_batch = tl.program_id(axis=0)
pid = tl.program_id(axis=1)
pid_m = pid // tl.cdiv(K, BLOCK_SIZE_K)
pid_n = pid % tl.cdiv(K, BLOCK_SIZE_K)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K // 2) * 2 # take only even numbers
x_ptrs = x_ptr + (pid_batch * stride_x_batch + stride_x_m * offs_m[:, None] + stride_x_n * offs_n[None, :])
x_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)
real = tl.load(x_ptrs, mask=x_real_mask, other=0.0)
x_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)
imag = tl.load(x_ptrs + 1, mask=x_imag_mask, other=0.0)
tl.debug_barrier()
start_block = start_token_position + pid_m * BLOCK_SIZE_M
cos, sin = get_freq_multi_tokens(offs_cn=offs_n, starting_idx=start_block, theta=THETA, NB_TOKENS=BLOCK_SIZE_M)
out_real = real * cos - imag * sin
out_imag = real * sin + imag * cos
tl.debug_barrier()
out_ptrs = out_ptr + (
pid_batch * stride_out_batch + stride_out_m * offs_m[:, None] + stride_out_n * offs_n[None, :]
)
out_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)
tl.store(out_ptrs, out_real, mask=out_real_mask)
out_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)
tl.store(out_ptrs + 1, out_imag, mask=out_imag_mask)
def rbe_triton_wrapper(x: torch.Tensor, pos: int) -> torch.Tensor:
batch, M, K = x.shape
out = torch.empty_like(x)
grid = lambda META: (
batch,
triton.cdiv(META["M"], META["BLOCK_SIZE_M"]) * triton.cdiv(META["K"], META["BLOCK_SIZE_K"]),
)
rbe_triton[grid](
x,
out,
M,
K,
*x.stride(),
*out.stride(),
start_token_position=pos,
THETA=10000.0,
BLOCK_SIZE_M=2,
BLOCK_SIZE_K=1024
)
return out
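A minimal usage sketch for the wrapper above; note that the kernel hard-codes DIM = 128, so the example assumes 128-dimensional heads flattened along the last axis, with the other shapes chosen purely for illustration:

# Illustrative only: apply the Triton rotary embedding starting from position 0 (CUDA required).
import torch
from petals.triton import rbe_triton_wrapper

batch, seq_len, hidden = 1, 16, 1024                         # assumed: 8 heads of 128 dims, flattened
x = torch.randn(batch, seq_len, hidden, device="cuda", dtype=torch.float16)

rotated = rbe_triton_wrapper(x, pos=0)
print(rotated.shape)                                         # same shape as the input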

@ -50,6 +50,19 @@ def convert_block(
if freeze:
block.requires_grad_(False)
if hasattr(block, "self_attn") and hasattr(block.self_attn, "qkv_proj"):
offset = 0
for data in [
block.self_attn.q_proj.weight.data,
block.self_attn.k_proj.weight.data,
block.self_attn.v_proj.weight.data,
]:
block.self_attn.qkv_proj.weight.data[offset : offset + data.size(0)].copy_(data)
offset += data.size(0)
del block.self_attn.q_proj
del block.self_attn.k_proj
del block.self_attn.v_proj
block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
if quant_type != QuantType.NONE:
