|
|
|
@ -4,7 +4,6 @@ Based on https://github.com/huggingface/transformers/blob/main/src/transformers/
|
|
|
|
|
See commit history for authorship.
|
|
|
|
|
"""
|
|
|
|
|
import math
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
@ -17,25 +16,15 @@ from transformers.models.falcon.modeling_falcon import (
|
|
|
|
|
FalconLinear,
|
|
|
|
|
FalconMLP,
|
|
|
|
|
FalconModel,
|
|
|
|
|
FalconRotaryEmbedding,
|
|
|
|
|
LayerNorm,
|
|
|
|
|
build_alibi_tensor,
|
|
|
|
|
dropout_add,
|
|
|
|
|
rotate_half,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
|
|
|
|
INFERENCE_MAX_LENGTH = 8192
|
|
|
|
|
|
|
|
|
|
# @torch.jit.script
|
|
|
|
|
def rotate_half(x):
    """Rotate the two halves of the last dimension of ``x``.

    The last dimension is split into two equal halves ``(x1, x2)`` and the
    result is ``concat(-x2, x1)`` along that same dimension.

    The split is done along the last dimension (rather than a hard-coded
    ``dim=2``) so that it always matches the concatenation axis. For 3-D
    inputs the result is unchanged; inputs of other ranks are now handled
    consistently as well.

    Args:
        x: tensor whose last dimension has an even size.

    Returns:
        A tensor with the same shape as ``x``.
    """
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# @torch.jit.script
|
|
|
|
|
def apply_rotary(query, key, cos, sin):
    """Apply rotary position embeddings to ``query`` and ``key``.

    Each tensor ``t`` is mapped to ``t * cos + rotate_half(t) * sin``;
    ``cos`` and ``sin`` must broadcast against both inputs.

    Returns:
        A ``(rotated_query, rotated_key)`` tuple.
    """
    rotated_query = (query * cos) + (rotate_half(query) * sin)
    rotated_key = (key * cos) + (rotate_half(key) * sin)
    return rotated_query, rotated_key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OptimizedFalconRotaryEmbedding(nn.Module):
|
|
|
|
|
def __init__(self, head_dim: int, base=10000):
|
|
|
|
@ -45,38 +34,6 @@ class OptimizedFalconRotaryEmbedding(nn.Module):
|
|
|
|
|
self.head_dim = head_dim
|
|
|
|
|
self.seq_len_cached = -1
|
|
|
|
|
|
|
|
|
|
self.cuda_graph = None
|
|
|
|
|
self.input_surface = None
|
|
|
|
|
self.static_outputs = None
|
|
|
|
|
|
|
|
|
|
self.cos_sin(
|
|
|
|
|
seq_len=INFERENCE_MAX_LENGTH,
|
|
|
|
|
past_key_values_length=0,
|
|
|
|
|
device=self.inv_freq.device,
|
|
|
|
|
dtype=torch.get_default_dtype(),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _optimized_apply_rotary(self, query, key, cos, sin):
    """Apply rotary embeddings through a lazily captured CUDA graph.

    On the first call a CUDA graph of ``apply_rotary`` is captured (after
    three warm-up runs on a side stream). Every call then copies the
    arguments into the graph's static input buffers, replays the graph and
    returns the static outputs, detached.

    The static inputs are private clones of the first call's arguments.
    Capturing the caller's own tensors would make every later ``copy_``
    overwrite them in place, and ``cos``/``sin`` may be views into the
    cached cos/sin tables, which would then be corrupted.

    NOTE: the returned tensors are the graph's static outputs; their
    contents are overwritten by the next replay. Later calls must pass
    tensors that can be copied into the buffers captured on the first call.

    Args:
        query, key: tensors to rotate.
        cos, sin: rotary tables broadcastable against ``query``/``key``.

    Returns:
        A ``(rotated_query, rotated_key)`` tuple.
    """
    if self.cuda_graph is None:
        self.cuda_graph = torch.cuda.CUDAGraph()
        # Dedicated placeholders: never alias caller-owned memory.
        self.input_surface = tuple(t.clone() for t in (query, key, cos, sin))

        # Warm up on a side stream before capture.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                apply_rotary(*self.input_surface)
        torch.cuda.current_stream().wait_stream(s)

        with torch.cuda.graph(self.cuda_graph):
            self.static_outputs = apply_rotary(*self.input_surface)

    inputs = (query, key, cos, sin)
    for static_input, data in zip(self.input_surface, inputs):
        static_input.copy_(data)
    self.cuda_graph.replay()
    return tuple(o.detach() for o in self.static_outputs)
|
|
|
|
|
|
|
|
|
|
def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
|
|
|
|
|
total_length = seq_len + past_key_values_length
|
|
|
|
|
if self.seq_len_cached == -1:
|
|
|
|
@ -84,16 +41,17 @@ class OptimizedFalconRotaryEmbedding(nn.Module):
|
|
|
|
|
total_length = max(INFERENCE_MAX_LENGTH, total_length)
|
|
|
|
|
|
|
|
|
|
if total_length > self.seq_len_cached:
|
|
|
|
|
self.seq_len_cached = total_length
|
|
|
|
|
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
|
|
|
|
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
|
|
|
|
emb = torch.cat((freqs, freqs), dim=-1).to(device)
|
|
|
|
|
with torch.inference_mode(False):
|
|
|
|
|
self.seq_len_cached = total_length
|
|
|
|
|
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
|
|
|
|
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
|
|
|
|
emb = torch.cat((freqs, freqs), dim=-1).to(device)
|
|
|
|
|
|
|
|
|
|
if dtype in [torch.float16, torch.bfloat16]:
|
|
|
|
|
emb = emb.float()
|
|
|
|
|
if dtype in [torch.float16, torch.bfloat16]:
|
|
|
|
|
emb = emb.float()
|
|
|
|
|
|
|
|
|
|
self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False)
|
|
|
|
|
self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False)
|
|
|
|
|
self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False)
|
|
|
|
|
self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False)
|
|
|
|
|
|
|
|
|
|
return (
|
|
|
|
|
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
|
|
|
|
@ -103,25 +61,7 @@ class OptimizedFalconRotaryEmbedding(nn.Module):
|
|
|
|
|
def forward(self, query, key, past_key_values_length=0):
    """Rotate ``query`` and ``key`` using the cached cos/sin tables.

    Args:
        query: tensor whose shape unpacks into exactly three dimensions;
            the second one is used as the sequence length.
        key: tensor rotated with the same cos/sin values as ``query``.
        past_key_values_length: offset forwarded to ``cos_sin``.

    Returns:
        The ``(query, key)`` pair produced by ``apply_rotary``.
    """
    # Three-way unpacking also rejects inputs that are not 3-D.
    _, seq_len, _ = query.shape
    cos, sin = self.cos_sin(
        seq_len=seq_len,
        past_key_values_length=past_key_values_length,
        device=query.device,
        dtype=query.dtype,
    )
    # NOTE: the CUDA-graph fast path (_optimized_apply_rotary) for
    # single-token inference is currently disabled; always use the eager path.
    return apply_rotary(query, key, cos, sin)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# @torch.jit.script
|
|
|
|
|
def split_heads(
    fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split a fused QKV projection into per-head query, key and value.

    The last dimension of ``fused_qkv`` is viewed as groups of
    ``num_heads // num_kv_heads`` query heads followed by one key head and
    one value head. Key and value are broadcast to the query shape, then
    the group and per-group head axes are flattened into one head axis.

    Args:
        fused_qkv: tensor of shape ``[batch, seq_len, fused_dim]``.
        num_heads: total number of query heads.
        num_kv_heads: number of key/value heads (groups).
        head_dim: size of each head.

    Returns:
        ``(query, key, value)``, each with four dimensions
        ``[batch, seq_len, heads, head_dim]``.
    """
    queries_per_group = num_heads // num_kv_heads
    batch, seq_len, _ = fused_qkv.shape
    grouped = fused_qkv.view(batch, seq_len, -1, queries_per_group + 2, head_dim)

    # Layout along dim 3: [queries..., key, value]
    query = grouped[:, :, :, :queries_per_group]
    key = torch.broadcast_to(grouped[:, :, :, queries_per_group : queries_per_group + 1], query.shape)
    value = torch.broadcast_to(grouped[:, :, :, queries_per_group + 1 :], query.shape)

    return query.flatten(2, 3), key.flatten(2, 3), value.flatten(2, 3)
|
|
|
|
|
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OptimizedFalconAttention(FalconAttention):
|
|
|
|
@ -158,35 +98,6 @@ class OptimizedFalconAttention(FalconAttention):
|
|
|
|
|
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
|
|
|
|
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
|
|
|
|
|
|
|
|
|
|
if self.new_decoder_architecture:
|
|
|
|
|
self._split_heads = partial(
|
|
|
|
|
split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
|
|
|
|
|
)
|
|
|
|
|
self.qkv_graph = None
|
|
|
|
|
self.input_surface = None
|
|
|
|
|
self.static_outputs = None
|
|
|
|
|
|
|
|
|
|
def _optimized_apply_qkv(self, hidden_states):
    """Compute the split query/key/value heads through a cached CUDA graph.

    The first call allocates a static input buffer shaped like
    ``hidden_states``, warms up ``query_key_value`` + ``_split_heads`` three
    times on a side stream and captures them into ``self.qkv_graph``. Every
    call copies ``hidden_states`` into the static buffer, replays the graph
    and returns the static outputs, detached.

    Returns:
        The tuple produced by ``self._split_heads`` during capture.
    """
    if self.qkv_graph is None:
        self.qkv_graph = torch.cuda.CUDAGraph()
        # Private placeholder the graph reads from on every replay.
        self.input_surface = torch.randn_like(hidden_states)

        warmup_stream = torch.cuda.Stream()
        warmup_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(warmup_stream):
            for _ in range(3):
                warmup_qkv = self.query_key_value(self.input_surface)
                self._split_heads(warmup_qkv)
        torch.cuda.current_stream().wait_stream(warmup_stream)

        with torch.cuda.graph(self.qkv_graph):
            captured_qkv = self.query_key_value(self.input_surface)
            self.static_outputs = self._split_heads(captured_qkv)

    self.input_surface.copy_(hidden_states)
    self.qkv_graph.replay()
    return tuple(out.detach() for out in self.static_outputs)
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
|
hidden_states: torch.Tensor,
|
|
|
|
@ -199,9 +110,6 @@ class OptimizedFalconAttention(FalconAttention):
|
|
|
|
|
):
|
|
|
|
|
assert not output_attentions
|
|
|
|
|
|
|
|
|
|
# if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
|
|
|
|
|
# query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states)
|
|
|
|
|
# else:
|
|
|
|
|
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
|
|
|
|
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
|
|
|
|
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
|
|
|
@ -320,32 +228,6 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
|
|
|
|
|
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
|
|
|
|
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
|
|
|
|
|
|
|
|
|
self.ln_graph = None
|
|
|
|
|
self.static_input = None
|
|
|
|
|
self.static_outputs = None
|
|
|
|
|
|
|
|
|
|
def _optimized_apply_ln(self, hidden_states):
    """Run ``ln_attn`` and ``ln_mlp`` on ``hidden_states`` via a cached CUDA graph.

    The first call warms up both layer norms three times on a side stream
    and captures them into ``self.ln_graph``. Every call copies
    ``hidden_states`` into the static input buffer, replays the graph and
    returns the static outputs, detached.

    The static input is a private clone rather than the caller's tensor.
    Keeping a reference to the first ``hidden_states`` would make every
    later ``copy_`` silently overwrite that tensor, which the caller may
    still hold. This matches how ``_optimized_apply_qkv`` owns its buffer.

    NOTE: the returned tensors are the graph's static outputs; their
    contents are overwritten by the next replay.

    Returns:
        ``(ln_attn_output, ln_mlp_output)``.
    """
    if self.ln_graph is None:
        self.ln_graph = torch.cuda.CUDAGraph()
        # Dedicated placeholder: never alias caller-owned memory.
        self.static_input = hidden_states.clone()

        # Warm up on a side stream before capture.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                self.ln_attn(self.static_input)
                self.ln_mlp(self.static_input)
        torch.cuda.current_stream().wait_stream(s)

        with torch.cuda.graph(self.ln_graph):
            ln_attn_output = self.ln_attn(self.static_input)
            ln_mlp_output = self.ln_mlp(self.static_input)
            self.static_outputs = (ln_attn_output, ln_mlp_output)

    self.static_input.copy_(hidden_states)
    self.ln_graph.replay()
    return tuple(o.detach() for o in self.static_outputs)
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
|
hidden_states: torch.Tensor,
|
|
|
|
@ -359,11 +241,8 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
|
|
|
|
|
residual = hidden_states
|
|
|
|
|
|
|
|
|
|
if self.config.new_decoder_architecture:
|
|
|
|
|
if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
|
|
|
|
|
attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)
|
|
|
|
|
else:
|
|
|
|
|
attention_layernorm_out = self.ln_attn(hidden_states)
|
|
|
|
|
mlp_layernorm_out = self.ln_mlp(hidden_states)
|
|
|
|
|
attention_layernorm_out = self.ln_attn(hidden_states)
|
|
|
|
|
mlp_layernorm_out = self.ln_mlp(hidden_states)
|
|
|
|
|
else:
|
|
|
|
|
attention_layernorm_out = self.input_layernorm(hidden_states)
|
|
|
|
|
|
|
|
|
|