Enable CUDA graphs only on CUDA

pull/500/head
Max Ryabinin 9 months ago
parent 91f6248535
commit ea6c037c8b

@@ -4,6 +4,7 @@ Based on https://github.com/huggingface/transformers/blob/main/src/transformers/
 See commit history for authorship.
 """
 import math
+from functools import partial
 from typing import Optional, Tuple
 
 import torch
@@ -26,6 +27,10 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
 INFERENCE_MAX_LENGTH = 8192
 
 
+def apply_rotary(query, key, cos, sin):
+    return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
+
+
 class OptimizedFalconRotaryEmbedding(nn.Module):
     def __init__(self, head_dim: int, base=10000):
         super().__init__()
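
The hunk above factors the rotary application out of `forward` into a free function so the same code path can be run eagerly or captured into a CUDA graph. A minimal sketch of what it computes, assuming the `rotate_half` helper from transformers' Falcon modeling code (which swaps the two halves of the last dimension and negates one); the toy shapes are hypothetical:

import torch

def rotate_half(x):
    # Swap the two halves of the last dimension, negating the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(query, key, cos, sin):
    return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)

# Toy shapes: batch=1, seq_len=1, head_dim=8.
q, k = torch.randn(2, 1, 1, 8).unbind(0)
angle = torch.rand(1, 1, 8)
q_rot, k_rot = apply_rotary(q, k, angle.cos(), angle.sin())
assert q_rot.shape == q.shape and k_rot.shape == k.shape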
@@ -34,6 +39,31 @@ class OptimizedFalconRotaryEmbedding(nn.Module):
         self.head_dim = head_dim
         self.seq_len_cached = -1
+        self.cuda_graph = None
+        self.input_surface = None
+        self.static_outputs = None
+
+    def _optimized_apply_rotary(self, query, key, cos, sin):
+        if self.cuda_graph is None:
+            self.cuda_graph = torch.cuda.CUDAGraph()
+            self.input_surface = (query, key, cos, sin)
+
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(3):
+                    apply_rotary(*self.input_surface)
+            torch.cuda.current_stream().wait_stream(s)
+
+            with torch.cuda.graph(self.cuda_graph):
+                self.static_outputs = apply_rotary(*self.input_surface)
+
+        inputs = (query, key, cos, sin)
+        for static_input, data in zip(self.input_surface, inputs):
+            static_input.copy_(data)
+        self.cuda_graph.replay()
+
+        return tuple(o.detach() for o in self.static_outputs)
 
     def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
         total_length = seq_len + past_key_values_length
         if self.seq_len_cached == -1:
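
The method added above follows PyTorch's stream-capture recipe: warm up on a side stream so the caching allocator and lazy initialization settle, capture one invocation into a `torch.cuda.CUDAGraph`, then serve later calls by copying fresh data into the captured input tensors and replaying. A standalone sketch of the same pattern, assuming all inputs are CUDA tensors of fixed shape and `fn` returns a tuple of tensors (the `make_graphed` name is ours, not from this commit):

import torch

def make_graphed(fn, *example_inputs, warmup_iters=3):
    # Warm up on a side stream so allocations happen before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(warmup_iters):
            fn(*example_inputs)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        # Outputs are allocated once; every replay rewrites them in place.
        static_outputs = fn(*example_inputs)

    def replay(*inputs):
        for static_input, data in zip(example_inputs, inputs):
            static_input.copy_(data)  # feed the captured input buffers
        graph.replay()
        return tuple(o.detach() for o in static_outputs)

    return replay

# Usage (requires a CUDA device):
if torch.cuda.is_available():
    ln = torch.nn.LayerNorm(64, device="cuda")
    x = torch.randn(8, 1, 64, device="cuda")
    with torch.inference_mode():
        graphed_ln = make_graphed(lambda t: (ln(t),), x.clone())
        (y,) = graphed_ln(x)
        torch.testing.assert_close(y, ln(x))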
@@ -61,7 +91,23 @@ class OptimizedFalconRotaryEmbedding(nn.Module):
     def forward(self, query, key, past_key_values_length=0):
         batch, seq_len, head_dim = query.shape
         cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
-        return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
+        if seq_len == 1 and torch.is_inference_mode_enabled() and query.device.type == "cuda":
+            return self._optimized_apply_rotary(query, key, cos, sin)
+        else:
+            return apply_rotary(query, key, cos, sin)
+
+
+def split_heads(
+    fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    batch, seq_len, _ = fused_qkv.shape
+    qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)
+    query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3)
+    key = torch.broadcast_to(key, query.shape)
+    value = torch.broadcast_to(value, query.shape)
+
+    query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
+    return query, key, value
 
 
 class OptimizedFalconAttention(FalconAttention):
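
The new free-standing `split_heads` handles Falcon's grouped-query layout: the fused QKV projection emits, per KV head, a group of `num_heads // num_kv_heads` query heads plus one key and one value head, and the key/value heads are then broadcast across their group. A shape walk-through with hypothetical Falcon-40B-like sizes:

import torch

batch, seq_len, num_heads, num_kv_heads, head_dim = 2, 1, 128, 8, 64
group = num_heads // num_kv_heads  # query heads per KV head (16 here)

# The fused projection packs (group + 2) heads per KV group.
fused_qkv = torch.randn(batch, seq_len, num_kv_heads * (group + 2) * head_dim)

qkv = fused_qkv.view(batch, seq_len, -1, group + 2, head_dim)  # -1 == num_kv_heads
query, key, value = torch.split(qkv, [group, 1, 1], dim=3)
key = torch.broadcast_to(key, query.shape)      # repeat K across its group
value = torch.broadcast_to(value, query.shape)  # repeat V across its group
query, key, value = [x.flatten(2, 3) for x in (query, key, value)]

assert query.shape == (batch, seq_len, num_heads, head_dim)
assert key.shape == value.shape == query.shape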
@@ -98,6 +144,35 @@ class OptimizedFalconAttention(FalconAttention):
         self.attention_dropout = nn.Dropout(config.attention_dropout)
         self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
 
+        if self.new_decoder_architecture:
+            self._split_heads = partial(
+                split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
+            )
+            self.qkv_graph = None
+            self.input_surface = None
+            self.static_outputs = None
+
+    def _optimized_apply_qkv(self, hidden_states):
+        if self.qkv_graph is None:
+            self.qkv_graph = torch.cuda.CUDAGraph()
+            self.input_surface = torch.randn_like(hidden_states)
+
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(3):
+                    fused_qkv = self.query_key_value(self.input_surface)
+                    self._split_heads(fused_qkv)
+            torch.cuda.current_stream().wait_stream(s)
+
+            with torch.cuda.graph(self.qkv_graph):
+                static_fused_qkv = self.query_key_value(self.input_surface)
+                self.static_outputs = self._split_heads(static_fused_qkv)
+
+        self.input_surface.copy_(hidden_states)
+        self.qkv_graph.replay()
+        return tuple(o.detach() for o in self.static_outputs)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -110,9 +185,17 @@ class OptimizedFalconAttention(FalconAttention):
     ):
         assert not output_attentions
 
-        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
-        # 3 x [batch_size, seq_length, num_heads, head_dim]
-        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+        if (
+            self.new_decoder_architecture
+            and hidden_states.size(1) == 1
+            and torch.is_inference_mode_enabled()
+            and hidden_states.device.type == "cuda"
+        ):
+            query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states)
+        else:
+            fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
+            # 3 x [batch_size, seq_length, num_heads, head_dim]
+            (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
 
         num_kv_heads = self.num_heads
         batch_size, query_length, _, _ = query_layer.shape
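
This guard is what the commit title refers to: the graph-replay path is taken only when the step has a static shape (`hidden_states.size(1) == 1`, i.e. single-token decode), autograd is off (`torch.is_inference_mode_enabled()`), and the activations actually live on a CUDA device, so CPU inference falls through to the eager path instead of hitting `torch.cuda` calls; judging by the title, the device check is the operative addition. A sketch of the same predicate as a reusable helper (the `should_use_cuda_graph` name is ours):

import torch

def should_use_cuda_graph(hidden_states: torch.Tensor) -> bool:
    # Replay is only safe when shapes are static (single-token decode),
    # autograd bookkeeping is off, and the tensors are on a CUDA device.
    return (
        hidden_states.size(1) == 1
        and torch.is_inference_mode_enabled()
        and hidden_states.device.type == "cuda"
    )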
@@ -228,6 +311,32 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
             self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
             self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
 
+            self.ln_graph = None
+            self.static_input = None
+            self.static_outputs = None
+
+    def _optimized_apply_ln(self, hidden_states):
+        if self.ln_graph is None:
+            self.ln_graph = torch.cuda.CUDAGraph()
+            self.static_input = hidden_states
+
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(3):
+                    self.ln_attn(hidden_states)
+                    self.ln_mlp(hidden_states)
+            torch.cuda.current_stream().wait_stream(s)
+
+            with torch.cuda.graph(self.ln_graph):
+                ln_attn_output = self.ln_attn(hidden_states)
+                ln_mlp_output = self.ln_mlp(hidden_states)
+                self.static_outputs = (ln_attn_output, ln_mlp_output)
+
+        self.static_input.copy_(hidden_states)
+        self.ln_graph.replay()
+        return tuple(o.detach() for o in self.static_outputs)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -241,8 +350,11 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
         residual = hidden_states
 
         if self.config.new_decoder_architecture:
-            attention_layernorm_out = self.ln_attn(hidden_states)
-            mlp_layernorm_out = self.ln_mlp(hidden_states)
+            if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+                attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)
+            else:
+                attention_layernorm_out = self.ln_attn(hidden_states)
+                mlp_layernorm_out = self.ln_mlp(hidden_states)
         else:
             attention_layernorm_out = self.input_layernorm(hidden_states)
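
Taken together, the first single-token step through each wrapped module pays the warmup-and-capture cost, and every later step is a cheap replay. A hedged smoke test of the rotary path, assuming the full petals module (including the `cos_sin` cache body elided from this diff) is importable and a CUDA device is available; `head_dim=64` is an arbitrary choice:

import torch

rope = OptimizedFalconRotaryEmbedding(head_dim=64).cuda()
q = torch.randn(1, 1, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 1, 64, device="cuda", dtype=torch.bfloat16)

with torch.inference_mode():
    cos, sin = rope.cos_sin(1, 0, q.device, q.dtype)
    eager = apply_rotary(q, k, cos, sin)  # reference, eager path
    graphed = rope(q, k)                  # first call: warm up, capture, replay

for a, b in zip(graphed, eager):
    torch.testing.assert_close(a, b)

One caveat of the pattern: the replayed outputs are detached views of the graph's static buffers, so the next replay overwrites them; callers must consume or clone the results before invoking the module again.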
