From fa464dfc99195840d0305df8a8e68ccebd8cd325 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sun, 3 Sep 2023 00:11:03 +0300 Subject: [PATCH] WIP Triton+QKV merge --- src/petals/models/llama/block.py | 118 +++++++++++++++++++- src/petals/server/throughput.py | 118 +++++++++++++++++++- src/petals/triton/__init__.py | 3 + src/petals/triton/attention.py | 178 ++++++++++++++++++++++++++++++ src/petals/triton/rmsnorm.py | 49 ++++++++ src/petals/triton/rotary.py | 81 ++++++++++++++ src/petals/utils/convert_block.py | 13 +++ 7 files changed, 554 insertions(+), 6 deletions(-) create mode 100644 src/petals/triton/__init__.py create mode 100644 src/petals/triton/attention.py create mode 100644 src/petals/triton/rmsnorm.py create mode 100644 src/petals/triton/rotary.py diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py index 55f659a..2dfaa20 100644 --- a/src/petals/models/llama/block.py +++ b/src/petals/models/llama/block.py @@ -6,10 +6,124 @@ See commit history for authorship. from typing import Optional, Tuple import torch -from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel +import torch.nn as nn +import torch.nn.functional as F +import math +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaConfig, + LlamaDecoderLayer, + LlamaMLP, + LlamaModel, + LlamaRMSNorm, + repeat_kv, + apply_rotary_pos_emb, +) +from petals.triton import attention_triton_wrapper, rbe_triton_wrapper, rmsnorm_triton_wrapper -class WrappedLlamaBlock(LlamaDecoderLayer): + +class OptimizedLlamaRMSNorm(LlamaRMSNorm): + def forward(self, hidden_states): + if torch.is_inference_mode_enabled(): + return rmsnorm_triton_wrapper(hidden_states, self.weight) + return super().forward(hidden_states) + + +class OptimizedLlamaAttention(LlamaAttention): + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.qkv_proj = nn.Linear( + self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False + ) + self.qkv_sizes = [ + self.num_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ] + self.attn_norm_constant = math.sqrt(self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + assert ( + self.config.pretraining_tp == 1 + ), "OptimizedLlamaAttention assumes that config.pretraining_tp is equal to 1" + + query_states, key_states, value_states = torch.split(self.qkv_proj(hidden_states), self.qkv_sizes, dim=2) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], 
dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / self.attn_norm_constant + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class OptimizedLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + self.self_attn = OptimizedLlamaAttention(config=config) + self.mlp = LlamaMLP(config) + self.input_layernorm = OptimizedLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = OptimizedLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class WrappedLlamaBlock(OptimizedLlamaDecoderLayer): def forward( self, hidden_states: torch.Tensor, diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index bf71f44..8ef3632 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -1,6 +1,8 @@ +from __future__ import annotations + +import argparse import fcntl import json -import math import multiprocessing as mp import os import time @@ -8,14 +10,19 @@ from collections import Counter from pathlib import Path from typing import Dict, Optional, Sequence, Union +import configargparse import torch + import torch.mps from hivemind.utils.logging import get_logger from transformers import PretrainedConfig +from petals.constants import DTYPE_MAP from petals.server.block_utils import resolve_block_dtype +from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR +from petals.utils.version import get_compatible_model_repo logger = get_logger(__name__) @@ -114,6 +121,7 @@ def measure_throughput_info( *, quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], + measure_network: bool = True, ) -> Dict[str, float]: logger.info( "Measuring network and compute throughput. 
This takes about a minute and will be cached for future runs" @@ -139,14 +147,16 @@ def measure_throughput_info( n_steps=10, inference=False, ), - "network_rps": measure_network_rps(config), + "network_rps": measure_network_rps(config, use_default=not measure_network), } def measure_network_rps( - config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 100e6 # 100 Mbit/s + config: PretrainedConfig, *, use_default=False, timeout: float = 60, default_speed: float = 100e6 # 100 Mbit/s ) -> Optional[float]: bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward + if use_default: + return default_speed / bits_per_request try: pipe_recv, pipe_send = mp.Pipe(duplex=False) process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,)) @@ -207,13 +217,23 @@ def measure_compute_rps( cache = None elapsed = 0 dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype) + # with torch.profiler.profile( + # schedule=torch.profiler.schedule(wait=1, warmup=4, active=n_steps, repeat=1), + # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/profbf16_70b_qkv'), + # record_shapes=True, + # profile_memory=True, + # with_stack=True + # ) as prof: _, cache = block.forward(dummy_input, use_cache=True) # Skip the 1st step to exclude the initialization time synchronize(device) + # prof.step() start_time = time.perf_counter() for _ in range(n_steps): _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None) - synchronize(device) + synchronize(device) + # prof.step() + elapsed = time.perf_counter() - start_time device_rps = n_steps * n_tokens / elapsed @@ -245,3 +265,93 @@ def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str: if quant_type != QuantType.NONE: name += f", quantized to {quant_type.name.lower()}" return name + + +def parse_args(): + # fmt:off + parser = configargparse.ArgParser(default_config_files=["config.yml"], + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add('-c', '--config', required=False, is_config_file=True, help='config file path') + + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument('--converted_model_name_or_path', type=str, default=None, + help="path or name of a pretrained model, converted with cli/convert_model.py") + group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path") + + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()") + group.add_argument("--use_auth_token", action="store_true", dest="token", + help="Read token saved by `huggingface-cli login") + + parser.add_argument('--device', type=str, default=None, required=False, + help='all blocks will use this device in torch notation; default: cuda if available else cpu') + parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto", + help="Use this dtype to store block weights and do computations. " + "By default, respect the dtypes in the pre-trained state dict.") + parser.add_argument('--revision', type=str, default=None, + help="The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models" + "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.") + + parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType], + help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or " + "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. " + "Default: 'int8' if GPU is available, 'none' otherwise") + parser.add_argument("--tensor_parallel_devices", nargs='+', default=None, + help= + "Split each block between the specified GPUs such that each device holds a portion of every " + "weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism") + + # fmt:on + args = parser.parse_args() + args.converted_model_name_or_path = args.model + return args + + +if __name__ == "__main__": + args = parse_args() + converted_model_name_or_path = get_compatible_model_repo(args.converted_model_name_or_path) + config = AutoDistributedConfig.from_pretrained( + converted_model_name_or_path, + use_auth_token=args.token, + revision=args.revision, + ) + + device = args.device + if device is None: + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + device = torch.device(device) + if device.type == "cuda" and device.index is None: + device = torch.device(device.type, index=0) + + torch_dtype = resolve_block_dtype(config, DTYPE_MAP[args.torch_dtype]) + if device.type == "cpu" and torch_dtype == torch.float16: + raise ValueError( + f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16" + ) + if device.type == "mps" and torch_dtype == torch.bfloat16: + logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead") + torch_dtype = torch.float16 + + quant_type = args.quant_type + if quant_type is None: + if device.type == "cuda": + quant_type = QuantType.NF4 if config.model_type == "llama" else QuantType.INT8 + else: + quant_type = QuantType.NONE + + if args.tensor_parallel_devices is None: + args.tensor_parallel_devices = (device,) + + measure_throughput_info( + config, + device, + torch_dtype, + quant_type=quant_type, + tensor_parallel_devices=args.tensor_parallel_devices, + measure_network=False, + ) diff --git a/src/petals/triton/__init__.py b/src/petals/triton/__init__.py new file mode 100644 index 0000000..c58ff5e --- /dev/null +++ b/src/petals/triton/__init__.py @@ -0,0 +1,3 @@ +from petals.triton.rmsnorm import rmsnorm_triton_wrapper +from petals.triton.attention import attention_triton_wrapper +from petals.triton.rotary import rbe_triton_wrapper \ No newline at end of file diff --git a/src/petals/triton/attention.py b/src/petals/triton/attention.py new file mode 100644 index 0000000..81f2696 --- /dev/null +++ b/src/petals/triton/attention.py @@ -0,0 +1,178 @@ +import math + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + N_HEAD, + H, + N_CTX, + start_position, # <- ADDED + IS_CAUSAL: tl.constexpr, # <- ADDED + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + USE_FP8: tl.constexpr, +): + start_m = 
tl.program_id(0) + + head_idx = tl.program_id(1) + batch_id = head_idx // N_HEAD + off_hz = head_idx % N_HEAD + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = ( + batch_id * stride_qz + off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + ) # <- stride fixed + off_k = ( + batch_id * stride_kz + off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + ) # <- stride fixed + off_v = ( + batch_id * stride_vz + off_hz * stride_vh + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + ) # <- stride fixed + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_prev = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs, offs_m[:, None] < H, other=0.0) + # loop over k, v and update accumulator + block_n_end = N_CTX # <- ADDED (including the IF) + if IS_CAUSAL: + # in causal mode, we expect that BLOCK_M_SIZE == BLOCK_N_SIZE + # autotune will prune shapes not matching this rule + block_n_end = (start_m + 1) * BLOCK_N + start_position + for start_n in range(0, block_n_end, BLOCK_N): + block_n_offs = start_n + offs_n # <- ADDED + # -- compute qk ---- + k = tl.load(k_ptrs, block_n_offs[:, None] < N_CTX, 0.0) + if USE_FP8: + k = k.to(tl.float8e5, bitcast=True) + k = k.to(tl.float16) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + qk = tl.where(offs_n[None, :] < N_CTX, qk, float("-inf")) # <- ADDED + qk *= sm_scale + if IS_CAUSAL: # <- ADDED + qk = tl.where(offs_m[:, None] >= (block_n_offs[None, :] + start_position), qk, float("-inf")) + + # compute new m + m_curr = tl.maximum(tl.max(qk, 1), m_prev) + # correct old l + l_prev *= tl.exp(m_prev - m_curr) + # attention weights + p = tl.exp(qk - m_curr[:, None]) + l_curr = tl.sum(p, 1) + l_prev + # rescale operands of matmuls + l_rcp = 1.0 / l_curr + p *= l_rcp[:, None] + acc *= (l_prev * l_rcp)[:, None] + # update acc + p = p.to(Q.dtype.element_ty) + v = tl.load(v_ptrs, block_n_offs[:, None] < N_CTX, 0.0) + if USE_FP8: + v = v.to(tl.float8e5, bitcast=True) + v = v.to(tl.float16) + acc += tl.dot(p, v) + # update m_i and l_i + l_prev = l_curr + m_prev = m_curr + # update pointers + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # initialize pointers to output + offs_d = tl.arange(0, BLOCK_DMODEL) + off_o = batch_id * stride_oz + off_hz * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, offs_m[:, None] < H) + + +def triton_fa(q, k, v, sm_scale, is_causal, start_position): + assert q.dtype == torch.float16 + assert k.dtype == v.dtype and k.dtype in [torch.float16, torch.int8] + + BLOCK = 64 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + num_warps = 4 if Lk <= 64 else 8 + batch, head_size, m_size, dhead = q.size() + grid = (triton.cdiv(m_size, BLOCK), head_size * batch) + n_size = k.size(2) + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + o, + q.stride(0), + q.stride(1), + 
q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + head_size, + m_size, + n_size, + start_position=start_position, + IS_CAUSAL=is_causal, + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + BLOCK_DMODEL=Lk, + USE_FP8=k.dtype == torch.int8, # USE_FP8 + num_warps=num_warps, + num_stages=2, + ) + + return o + + +def attention_triton_wrapper(q, k, v, head_dim): + return triton_fa(q, k, v, sm_scale=1 / math.sqrt(head_dim), is_causal=True, start_position=0) diff --git a/src/petals/triton/rmsnorm.py b/src/petals/triton/rmsnorm.py new file mode 100644 index 0000000..ddff624 --- /dev/null +++ b/src/petals/triton/rmsnorm.py @@ -0,0 +1,49 @@ +import triton +import triton.language as tl +import torch + +@triton.jit +def rmsnorm_triton(x_ptr, rms_w_ptr, output_ptr, + stride_x_batch, stride_x_m, stride_x_k, + stride_rms_w, + stride_out_batch, stride_out_m, stride_out_k, + N_SIZE: tl.constexpr, eps: tl.constexpr, BLOCK_N_SIZE: tl.constexpr): + pid_batch = tl.program_id(0) + pid_m = tl.program_id(1) + + offs_m = pid_batch * stride_x_batch + pid_m * stride_x_m + block_N = tl.arange(0, BLOCK_N_SIZE) + var = tl.zeros((BLOCK_N_SIZE,), tl.float32) + for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE): + offs_n = block_n_start_idx + block_N + x_ptr_mask = offs_n < N_SIZE + x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0) + var += tl.math.pow(x.to(tl.float32), 2) + + var = tl.sum(var, axis=0) / N_SIZE + rstd = tl.math.rsqrt(var + eps) + + # multiply by weight and add bias + for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE): + offs_n = block_n_start_idx + block_N + x_ptr_mask = offs_n < N_SIZE + rms_w = tl.load(rms_w_ptr + offs_n * stride_rms_w, mask=x_ptr_mask) + + x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0).to(tl.float32) + x_hat = x * rstd + out = x_hat * rms_w + out_off = pid_batch * stride_out_batch + pid_m * stride_out_m + offs_n * stride_out_k + tl.store(output_ptr + out_off, out, mask=x_ptr_mask) + + +def rmsnorm_triton_wrapper(x, rms_w, eps=1e-6): + batch_size, seq_length, hid_dim = x.shape + assert rms_w.shape[-1] == hid_dim + out = torch.empty_like(x) + rmsnorm_triton[(batch_size, seq_length,)](x, rms_w, out, + *x.stride(), + *rms_w.stride(), + *out.stride(), + N_SIZE=hid_dim, eps=eps, BLOCK_N_SIZE=1024, + ) + return out \ No newline at end of file diff --git a/src/petals/triton/rotary.py b/src/petals/triton/rotary.py new file mode 100644 index 0000000..53bfd28 --- /dev/null +++ b/src/petals/triton/rotary.py @@ -0,0 +1,81 @@ +import triton +import triton.language as tl +import torch + + +@triton.jit +def get_freq_multi_tokens(offs_cn, starting_idx, theta: tl.constexpr, NB_TOKENS: tl.constexpr): + DIM: tl.constexpr = 128 # in model, dim = self.params.dim // self.params.n_heads + freqs = offs_cn % DIM + freqs = freqs.to(tl.float32) / DIM + freqs = tl.math.pow(theta, freqs) + freqs = (tl.arange(0, NB_TOKENS) + starting_idx)[:, None] / freqs[None, :] + return tl.cos(freqs), tl.sin(freqs) + + +@triton.jit +def rbe_triton( + x_ptr, + out_ptr, + M, + K, + stride_x_batch, + stride_x_m, + stride_x_n, + stride_out_batch, + stride_out_m, + stride_out_n, + start_token_position, + THETA: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_batch = tl.program_id(axis=0) + pid = tl.program_id(axis=1) + pid_m = pid // tl.cdiv(K, BLOCK_SIZE_K) + pid_n = pid % tl.cdiv(K, 
BLOCK_SIZE_K) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K // 2) * 2 # take only even numbers + x_ptrs = x_ptr + (pid_batch * stride_x_batch + stride_x_m * offs_m[:, None] + stride_x_n * offs_n[None, :]) + x_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K) + real = tl.load(x_ptrs, mask=x_real_mask, other=0.0) + x_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K) + imag = tl.load(x_ptrs + 1, mask=x_imag_mask, other=0.0) + tl.debug_barrier() + start_block = start_token_position + pid_m * BLOCK_SIZE_M + cos, sin = get_freq_multi_tokens(offs_cn=offs_n, starting_idx=start_block, theta=THETA, NB_TOKENS=BLOCK_SIZE_M) + + out_real = real * cos - imag * sin + out_imag = real * sin + imag * cos + tl.debug_barrier() + out_ptrs = out_ptr + ( + pid_batch * stride_out_batch + stride_out_m * offs_m[:, None] + stride_out_n * offs_n[None, :] + ) + out_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K) + tl.store(out_ptrs, out_real, mask=out_real_mask) + out_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K) + tl.store(out_ptrs + 1, out_imag, mask=out_imag_mask) + + +def rbe_triton_wrapper(x: torch.Tensor, pos: int) -> torch.Tensor: + batch, M, K = x.shape + out = torch.empty_like(x) + grid = lambda META: ( + batch, + triton.cdiv(META["M"], META["BLOCK_SIZE_M"]) * triton.cdiv(META["K"], META["BLOCK_SIZE_K"]), + ) + + rbe_triton[grid]( + x, + out, + M, + K, + *x.stride(), + *out.stride(), + start_token_position=pos, + THETA=10000.0, + BLOCK_SIZE_M=2, + BLOCK_SIZE_K=1024 + ) + return out diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index 94d3e29..fc4c6c7 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -50,6 +50,19 @@ def convert_block( if freeze: block.requires_grad_(False) + if hasattr(block, "self_attn") and hasattr(block.self_attn, "qkv_proj"): + offset = 0 + for data in [ + block.self_attn.q_proj.weight.data, + block.self_attn.k_proj.weight.data, + block.self_attn.v_proj.weight.data, + ]: + block.self_attn.qkv_proj.weight.data[offset : offset + data.size(0)].copy_(data) + offset += data.size(0) + del block.self_attn.q_proj + del block.self_attn.k_proj + del block.self_attn.v_proj + block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device) if quant_type != QuantType.NONE:
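
The convert_block() hunk above folds the separate q_proj/k_proj/v_proj weights into the fused qkv_proj added by OptimizedLlamaAttention. A minimal sketch of the equivalence this relies on, using small hypothetical dimensions rather than a real checkpoint (not part of the patch):

    import torch
    import torch.nn as nn

    # Toy dimensions standing in for a Llama block (hypothetical values, not from the patch).
    hidden_size, num_heads, num_kv_heads, head_dim = 256, 8, 8, 32

    q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
    k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
    v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)

    qkv_proj = nn.Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_dim, bias=False)
    qkv_sizes = [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim]

    # Same copy pattern as the convert_block() hunk: q, k, v weight rows stacked in order.
    with torch.no_grad():
        offset = 0
        for data in (q_proj.weight.data, k_proj.weight.data, v_proj.weight.data):
            qkv_proj.weight.data[offset : offset + data.size(0)].copy_(data)
            offset += data.size(0)

    x = torch.randn(2, 5, hidden_size)
    q, k, v = torch.split(qkv_proj(x), qkv_sizes, dim=2)
    assert torch.allclose(q, q_proj(x), atol=1e-6)
    assert torch.allclose(k, k_proj(x), atol=1e-6)
    assert torch.allclose(v, v_proj(x), atol=1e-6)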
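
The Triton RMSNorm path is only taken by OptimizedLlamaRMSNorm under torch.inference_mode(); otherwise the stock LlamaRMSNorm forward runs. A rough way to check that the two paths agree, assuming a CUDA device with Triton available (tolerances are illustrative, not part of the patch):

    import torch
    from transformers.models.llama.modeling_llama import LlamaRMSNorm

    from petals.triton import rmsnorm_triton_wrapper

    hidden_size = 4096
    ref = LlamaRMSNorm(hidden_size, eps=1e-6).to("cuda", torch.float16)
    ref.weight.data.uniform_(0.5, 1.5)  # avoid the trivial all-ones weight

    x = torch.randn(2, 8, hidden_size, device="cuda", dtype=torch.float16)
    expected = ref(x)
    actual = rmsnorm_triton_wrapper(x, ref.weight, eps=1e-6)

    # The Triton kernel keeps the weight multiply in fp32, while the reference casts the
    # normalized activations back to fp16 first, so allow a loose fp16 tolerance here.
    assert torch.allclose(actual, expected, atol=1e-2, rtol=1e-2)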
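
attention_triton_wrapper is imported in block.py but not yet called in the forward shown in this WIP patch. Judging from the asserts in triton_fa (fp16 tensors shaped (batch, heads, seq, head_dim), head_dim in {16, 32, 64, 128}), the expected call shape is roughly as sketched below; the comparison against PyTorch SDPA is only a plausibility check, not a claim about the WIP kernel's final behavior:

    import torch
    import torch.nn.functional as F

    from petals.triton import attention_triton_wrapper

    bsz, num_heads, seq_len, head_dim = 1, 8, 256, 128
    q = torch.randn(bsz, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Causal attention with start_position=0; the wrapper applies the 1/sqrt(head_dim) scale.
    out = attention_triton_wrapper(q, k, v, head_dim)

    reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print("max abs diff vs. SDPA:", (out - reference).abs().max().item())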