WIP Triton+QKV merge

wip_triton
Max Ryabinin 9 months ago
parent b4d822afb2
commit fa464dfc99

@@ -6,10 +6,124 @@ See commit history for authorship.
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
repeat_kv,
apply_rotary_pos_emb,
)
from petals.triton import attention_triton_wrapper, rbe_triton_wrapper, rmsnorm_triton_wrapper
class OptimizedLlamaRMSNorm(LlamaRMSNorm):
def forward(self, hidden_states):
if torch.is_inference_mode_enabled():
return rmsnorm_triton_wrapper(hidden_states, self.weight)
return super().forward(hidden_states)
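The Triton path above is gated on torch.is_inference_mode_enabled(), so it only replaces the stock LlamaRMSNorm math during inference. A minimal sketch of exercising both paths (assumes a CUDA device with Triton installed; the tolerances are loose to absorb the fp16 downcast in the kernel):
import torch
norm = OptimizedLlamaRMSNorm(4096, eps=1e-6).to("cuda", torch.float16)
x = torch.randn(1, 8, 4096, device="cuda", dtype=torch.float16)
y_eager = norm(x)  # inference mode is off -> falls back to LlamaRMSNorm.forward
with torch.inference_mode():
    y_triton = norm(x)  # inference mode is on -> rmsnorm_triton_wrapper
torch.testing.assert_close(y_eager, y_triton, rtol=1e-2, atol=1e-2)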
class OptimizedLlamaAttention(LlamaAttention):
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.qkv_proj = nn.Linear(
self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False
)
self.qkv_sizes = [
self.num_heads * self.head_dim,
self.num_key_value_heads * self.head_dim,
self.num_key_value_heads * self.head_dim,
]
self.attn_norm_constant = math.sqrt(self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
assert (
self.config.pretraining_tp == 1
), "OptimizedLlamaAttention assumes that config.pretraining_tp is equal to 1"
query_states, key_states, value_states = torch.split(self.qkv_proj(hidden_states), self.qkv_sizes, dim=2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / self.attn_norm_constant
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
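For a grouped-query layout, the fused projection packs num_heads query heads plus 2 * num_key_value_heads key/value heads into one matmul, and self.qkv_sizes splits them back apart. A small shape check with hypothetical LLaMA-2-70B-like dimensions (8192 hidden, 64 query heads, 8 KV heads, head_dim 128):
import torch
hidden_size, num_heads, num_kv_heads, head_dim = 8192, 64, 8, 128
qkv_out = (num_heads + 2 * num_kv_heads) * head_dim  # 10240 output features
qkv_sizes = [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim]
x = torch.randn(2, 5, qkv_out)  # stand-in for qkv_proj(hidden_states), bsz=2, q_len=5
q, k, v = torch.split(x, qkv_sizes, dim=2)
assert q.shape == (2, 5, 8192) and k.shape == v.shape == (2, 5, 1024)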
class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LlamaConfig):
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
self.self_attn = OptimizedLlamaAttention(config=config)
self.mlp = LlamaMLP(config)
self.input_layernorm = OptimizedLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = OptimizedLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,

@@ -1,6 +1,8 @@
from __future__ import annotations
import argparse
import fcntl
import json
import math
import multiprocessing as mp
import os
import time
@@ -8,14 +10,19 @@ from collections import Counter
from pathlib import Path
from typing import Dict, Optional, Sequence, Union
import configargparse
import torch
import torch.mps
from hivemind.utils.logging import get_logger
from transformers import PretrainedConfig
from petals.constants import DTYPE_MAP
from petals.server.block_utils import resolve_block_dtype
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
from petals.utils.version import get_compatible_model_repo
logger = get_logger(__name__)
@@ -114,6 +121,7 @@ def measure_throughput_info(
*,
quant_type: QuantType,
tensor_parallel_devices: Sequence[torch.device],
measure_network: bool = True,
) -> Dict[str, float]:
logger.info(
"Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
@@ -139,14 +147,16 @@ def measure_throughput_info(
n_steps=10,
inference=False,
),
"network_rps": measure_network_rps(config, use_default=not measure_network),
}
def measure_network_rps(
config: PretrainedConfig, *, use_default=False, timeout: float = 60, default_speed: float = 100e6  # 100 Mbit/s
) -> Optional[float]:
bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
if use_default:
return default_speed / bits_per_request
try:
pipe_recv, pipe_send = mp.Pipe(duplex=False)
process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
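With use_default=True the speed test is skipped and the fallback bandwidth is converted straight into requests per second. A back-of-the-envelope check, using a made-up hidden_size of 8192:
default_speed = 100e6         # 100 Mbit/s, same default as the signature above
bits_per_request = 8192 * 16  # config.hidden_size * 16
print(default_speed / bits_per_request)  # ~763 requests per second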
@@ -207,13 +217,23 @@ def measure_compute_rps(
cache = None
elapsed = 0
dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
# with torch.profiler.profile(
# schedule=torch.profiler.schedule(wait=1, warmup=4, active=n_steps, repeat=1),
# on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/profbf16_70b_qkv'),
# record_shapes=True,
# profile_memory=True,
# with_stack=True
# ) as prof:
_, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
synchronize(device)
# prof.step()
start_time = time.perf_counter()
for _ in range(n_steps):
_, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
synchronize(device)
# prof.step()
elapsed = time.perf_counter() - start_time
device_rps = n_steps * n_tokens / elapsed
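A quick sanity check of the tokens-per-second formula, with illustrative numbers only:
n_steps, n_tokens, elapsed = 10, 2048, 4.0
device_rps = n_steps * n_tokens / elapsed  # 5120.0 tokens/s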
@@ -245,3 +265,93 @@ def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
if quant_type != QuantType.NONE:
name += f", quantized to {quant_type.name.lower()}"
return name
def parse_args():
# fmt:off
parser = configargparse.ArgParser(default_config_files=["config.yml"],
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--converted_model_name_or_path', type=str, default=None,
help="path or name of a pretrained model, converted with cli/convert_model.py")
group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()")
group.add_argument("--use_auth_token", action="store_true", dest="token",
help="Read token saved by `huggingface-cli login")
parser.add_argument('--device', type=str, default=None, required=False,
help='all blocks will use this device in torch notation; default: cuda if available else cpu')
parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
parser.add_argument('--revision', type=str, default=None,
help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],
help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or "
"4-bit (nf4 from the QLoRA paper) formats to save GPU memory. "
"Default: 'int8' if GPU is available, 'none' otherwise")
parser.add_argument("--tensor_parallel_devices", nargs='+', default=None,
help=
"Split each block between the specified GPUs such that each device holds a portion of every "
"weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism")
# fmt:on
args = parser.parse_args()
args.converted_model_name_or_path = args.model or args.converted_model_name_or_path
return args
if __name__ == "__main__":
args = parse_args()
converted_model_name_or_path = get_compatible_model_repo(args.converted_model_name_or_path)
config = AutoDistributedConfig.from_pretrained(
converted_model_name_or_path,
use_auth_token=args.token,
revision=args.revision,
)
device = args.device
if device is None:
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
device = torch.device(device)
if device.type == "cuda" and device.index is None:
device = torch.device(device.type, index=0)
torch_dtype = resolve_block_dtype(config, DTYPE_MAP[args.torch_dtype])
if device.type == "cpu" and torch_dtype == torch.float16:
raise ValueError(
f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16"
)
if device.type == "mps" and torch_dtype == torch.bfloat16:
logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead")
torch_dtype = torch.float16
quant_type = args.quant_type
if quant_type is None:
if device.type == "cuda":
quant_type = QuantType.NF4 if config.model_type == "llama" else QuantType.INT8
else:
quant_type = QuantType.NONE
if args.tensor_parallel_devices is None:
args.tensor_parallel_devices = (device,)
measure_throughput_info(
config,
device,
torch_dtype,
quant_type=quant_type,
tensor_parallel_devices=args.tensor_parallel_devices,
measure_network=False,
)

@@ -0,0 +1,3 @@
from petals.triton.rmsnorm import rmsnorm_triton_wrapper
from petals.triton.attention import attention_triton_wrapper
from petals.triton.rotary import rbe_triton_wrapper

@@ -0,0 +1,178 @@
import math
import torch
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
N_HEAD,
H,
N_CTX,
start_position, # <- ADDED
IS_CAUSAL: tl.constexpr, # <- ADDED
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
USE_FP8: tl.constexpr,
):
start_m = tl.program_id(0)
head_idx = tl.program_id(1)
batch_id = head_idx // N_HEAD
off_hz = head_idx % N_HEAD
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = (
batch_id * stride_qz + off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
) # <- stride fixed
off_k = (
batch_id * stride_kz + off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
) # <- stride fixed
off_v = (
batch_id * stride_vz + off_hz * stride_vh + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn
) # <- stride fixed
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs, offs_m[:, None] < H, other=0.0)
# loop over k, v and update accumulator
block_n_end = N_CTX # <- ADDED (including the IF)
if IS_CAUSAL:
# in causal mode, we expect that BLOCK_M_SIZE == BLOCK_N_SIZE
# autotune will prune shapes not matching this rule
block_n_end = (start_m + 1) * BLOCK_N + start_position
for start_n in range(0, block_n_end, BLOCK_N):
block_n_offs = start_n + offs_n # <- ADDED
# -- compute qk ----
k = tl.load(k_ptrs, block_n_offs[:, None] < N_CTX, 0.0)
if USE_FP8:
k = k.to(tl.float8e5, bitcast=True)
k = k.to(tl.float16)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk = tl.where(offs_n[None, :] < N_CTX, qk, float("-inf")) # <- ADDED
qk *= sm_scale
if IS_CAUSAL: # <- ADDED
qk = tl.where(offs_m[:, None] >= (block_n_offs[None, :] + start_position), qk, float("-inf"))
# compute new m
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
# correct old l
l_prev *= tl.exp(m_prev - m_curr)
# attention weights
p = tl.exp(qk - m_curr[:, None])
l_curr = tl.sum(p, 1) + l_prev
# rescale operands of matmuls
l_rcp = 1.0 / l_curr
p *= l_rcp[:, None]
acc *= (l_prev * l_rcp)[:, None]
# update acc
p = p.to(Q.dtype.element_ty)
v = tl.load(v_ptrs, block_n_offs[:, None] < N_CTX, 0.0)
if USE_FP8:
v = v.to(tl.float8e5, bitcast=True)
v = v.to(tl.float16)
acc += tl.dot(p, v)
# update m_i and l_i
l_prev = l_curr
m_prev = m_curr
# update pointers
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# initialize pointers to output
offs_d = tl.arange(0, BLOCK_DMODEL)
off_o = batch_id * stride_oz + off_hz * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, offs_m[:, None] < H)
def triton_fa(q, k, v, sm_scale, is_causal, start_position):
assert q.dtype == torch.float16
assert k.dtype == v.dtype and k.dtype in [torch.float16, torch.int8]
BLOCK = 64
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
num_warps = 4 if Lk <= 64 else 8
batch, head_size, m_size, dhead = q.size()
grid = (triton.cdiv(m_size, BLOCK), head_size * batch)
n_size = k.size(2)
_fwd_kernel[grid](
q,
k,
v,
sm_scale,
o,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
o.stride(0),
o.stride(1),
o.stride(2),
o.stride(3),
head_size,
m_size,
n_size,
start_position=start_position,
IS_CAUSAL=is_causal,
BLOCK_M=BLOCK,
BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk,
USE_FP8=k.dtype == torch.int8, # USE_FP8
num_warps=num_warps,
num_stages=2,
)
return o
def attention_triton_wrapper(q, k, v, head_dim):
return triton_fa(q, k, v, sm_scale=1 / math.sqrt(head_dim), is_causal=True, start_position=0)
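A sketch of a correctness check for the fp16 path against PyTorch's reference attention (assumes a CUDA device, Triton installed, and PyTorch >= 2.0; seq_len is kept a multiple of the 64-token block since this WIP kernel masks at block granularity, and there is no KV cache, i.e. start_position=0):
import torch
import torch.nn.functional as F
bsz, n_heads, seq_len, head_dim = 1, 8, 256, 128
q, k, v = (torch.randn(bsz, n_heads, seq_len, head_dim, device="cuda", dtype=torch.float16) for _ in range(3))
out_triton = attention_triton_wrapper(q, k, v, head_dim)           # causal, start_position=0
out_ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # same 1/sqrt(head_dim) scaling
torch.testing.assert_close(out_triton, out_ref, rtol=2e-2, atol=2e-2)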

@@ -0,0 +1,49 @@
import triton
import triton.language as tl
import torch
@triton.jit
def rmsnorm_triton(x_ptr, rms_w_ptr, output_ptr,
stride_x_batch, stride_x_m, stride_x_k,
stride_rms_w,
stride_out_batch, stride_out_m, stride_out_k,
N_SIZE: tl.constexpr, eps: tl.constexpr, BLOCK_N_SIZE: tl.constexpr):
pid_batch = tl.program_id(0)
pid_m = tl.program_id(1)
offs_m = pid_batch * stride_x_batch + pid_m * stride_x_m
block_N = tl.arange(0, BLOCK_N_SIZE)
var = tl.zeros((BLOCK_N_SIZE,), tl.float32)
for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):
offs_n = block_n_start_idx + block_N
x_ptr_mask = offs_n < N_SIZE
x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0)
var += tl.math.pow(x.to(tl.float32), 2)
var = tl.sum(var, axis=0) / N_SIZE
rstd = tl.math.rsqrt(var + eps)
# multiply by weight and add bias
for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):
offs_n = block_n_start_idx + block_N
x_ptr_mask = offs_n < N_SIZE
rms_w = tl.load(rms_w_ptr + offs_n * stride_rms_w, mask=x_ptr_mask)
x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0).to(tl.float32)
x_hat = x * rstd
out = x_hat * rms_w
out_off = pid_batch * stride_out_batch + pid_m * stride_out_m + offs_n * stride_out_k
tl.store(output_ptr + out_off, out, mask=x_ptr_mask)
def rmsnorm_triton_wrapper(x, rms_w, eps=1e-6):
batch_size, seq_length, hid_dim = x.shape
assert rms_w.shape[-1] == hid_dim
out = torch.empty_like(x)
rmsnorm_triton[(batch_size, seq_length,)](x, rms_w, out,
*x.stride(),
*rms_w.stride(),
*out.stride(),
N_SIZE=hid_dim, eps=eps, BLOCK_N_SIZE=1024,
)
return out
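A reference check for this kernel can compute RMSNorm directly in fp32 and compare with a loose fp16 tolerance (a sketch, assuming a CUDA device with Triton installed):
import torch
x = torch.randn(2, 16, 4096, device="cuda", dtype=torch.float16)
w = torch.randn(4096, device="cuda", dtype=torch.float16)
ref = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6) * w.float()).to(x.dtype)
out = rmsnorm_triton_wrapper(x, w, eps=1e-6)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)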

@@ -0,0 +1,81 @@
import triton
import triton.language as tl
import torch
@triton.jit
def get_freq_multi_tokens(offs_cn, starting_idx, theta: tl.constexpr, NB_TOKENS: tl.constexpr):
DIM: tl.constexpr = 128 # in model, dim = self.params.dim // self.params.n_heads
freqs = offs_cn % DIM
freqs = freqs.to(tl.float32) / DIM
freqs = tl.math.pow(theta, freqs)
freqs = (tl.arange(0, NB_TOKENS) + starting_idx)[:, None] / freqs[None, :]
return tl.cos(freqs), tl.sin(freqs)
@triton.jit
def rbe_triton(
x_ptr,
out_ptr,
M,
K,
stride_x_batch,
stride_x_m,
stride_x_n,
stride_out_batch,
stride_out_m,
stride_out_n,
start_token_position,
THETA: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_batch = tl.program_id(axis=0)
pid = tl.program_id(axis=1)
pid_m = pid // tl.cdiv(K, BLOCK_SIZE_K)
pid_n = pid % tl.cdiv(K, BLOCK_SIZE_K)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K // 2) * 2 # take only even numbers
x_ptrs = x_ptr + (pid_batch * stride_x_batch + stride_x_m * offs_m[:, None] + stride_x_n * offs_n[None, :])
x_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)
real = tl.load(x_ptrs, mask=x_real_mask, other=0.0)
x_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)
imag = tl.load(x_ptrs + 1, mask=x_imag_mask, other=0.0)
tl.debug_barrier()
start_block = start_token_position + pid_m * BLOCK_SIZE_M
cos, sin = get_freq_multi_tokens(offs_cn=offs_n, starting_idx=start_block, theta=THETA, NB_TOKENS=BLOCK_SIZE_M)
out_real = real * cos - imag * sin
out_imag = real * sin + imag * cos
tl.debug_barrier()
out_ptrs = out_ptr + (
pid_batch * stride_out_batch + stride_out_m * offs_m[:, None] + stride_out_n * offs_n[None, :]
)
out_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)
tl.store(out_ptrs, out_real, mask=out_real_mask)
out_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)
tl.store(out_ptrs + 1, out_imag, mask=out_imag_mask)
def rbe_triton_wrapper(x: torch.Tensor, pos: int) -> torch.Tensor:
batch, M, K = x.shape
out = torch.empty_like(x)
grid = lambda META: (
batch,
triton.cdiv(META["M"], META["BLOCK_SIZE_M"]) * triton.cdiv(META["K"], META["BLOCK_SIZE_K"]),
)
rbe_triton[grid](
x,
out,
M,
K,
*x.stride(),
*out.stride(),
start_token_position=pos,
THETA=10000.0,
BLOCK_SIZE_M=2,
BLOCK_SIZE_K=1024
)
return out
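Note that this kernel rotates adjacent (even, odd) channel pairs with head_dim hard-coded to 128, i.e. the interleaved RoPE convention rather than the rotate-half layout used elsewhere in transformers. A sketch of a reference check under those assumptions (CUDA device, Triton installed, packed input of shape [batch, seq, n_heads * 128]):
import torch
batch, seq_len, n_heads, head_dim = 1, 16, 4, 128
x = torch.randn(batch, seq_len, n_heads * head_dim, device="cuda", dtype=torch.float16)
pos = 0  # start_token_position
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
angles = torch.outer(torch.arange(pos, pos + seq_len, device="cuda").float(), inv_freq)
rot = torch.polar(torch.ones_like(angles), angles)  # e^{i * angle}, shape (seq, head_dim // 2)
xc = torch.view_as_complex(x.float().reshape(batch, seq_len, n_heads, head_dim // 2, 2))
ref = torch.view_as_real(xc * rot[None, :, None, :]).reshape(batch, seq_len, n_heads * head_dim).to(x.dtype)
out = rbe_triton_wrapper(x, pos)
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)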

@@ -50,6 +50,19 @@ def convert_block(
if freeze:
block.requires_grad_(False)
if hasattr(block, "self_attn") and hasattr(block.self_attn, "qkv_proj"):
offset = 0
for data in [
block.self_attn.q_proj.weight.data,
block.self_attn.k_proj.weight.data,
block.self_attn.v_proj.weight.data,
]:
block.self_attn.qkv_proj.weight.data[offset : offset + data.size(0)].copy_(data)
offset += data.size(0)
del block.self_attn.q_proj
del block.self_attn.k_proj
del block.self_attn.v_proj
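A minimal sketch for verifying the fused weights before the del statements run (assumes the block's parameters are materialized on a real device): split the fused projection with attn.qkv_sizes and compare against the original per-projection outputs.
import torch
def check_qkv_merge(attn):  # attn: a block.self_attn that still has q_proj/k_proj/v_proj
    w = attn.qkv_proj.weight
    x = torch.randn(1, 3, w.shape[1], device=w.device, dtype=w.dtype)
    q, k, v = torch.split(attn.qkv_proj(x), attn.qkv_sizes, dim=2)
    torch.testing.assert_close(q, attn.q_proj(x))
    torch.testing.assert_close(k, attn.k_proj(x))
    torch.testing.assert_close(v, attn.v_proj(x))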
block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
if quant_type != QuantType.NONE:
