WIP disable graphs

pull/500/head
Max Ryabinin 9 months ago
parent cfaf6c1975
commit 1f2ef79da3

@ -20,13 +20,19 @@ from transformers.models.falcon.modeling_falcon import (
LayerNorm,
build_alibi_tensor,
dropout_add,
rotate_half,
FalconRotaryEmbedding,
)
KVCache = Tuple[torch.Tensor, torch.Tensor]
INFERENCE_MAX_LENGTH = 8192
# @torch.jit.script
def rotate_half(x):
x1, x2 = torch.chunk(x, 2, dim=2)
return torch.cat((-x2, x1), dim=-1)
# @torch.jit.script
def apply_rotary(query, key, cos, sin):
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
@ -97,14 +103,15 @@ class OptimizedFalconRotaryEmbedding(nn.Module):
def forward(self, query, key, past_key_values_length=0):
batch, seq_len, head_dim = query.shape
cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
if seq_len == 1 and torch.is_inference_mode_enabled():
return self._optimized_apply_rotary(query, key, cos, sin)
else:
return apply_rotary(query, key, cos, sin)
# print(cos, sin)
# if seq_len == 1 and torch.is_inference_mode_enabled():
# return self._optimized_apply_rotary(query, key, cos, sin)
# else:
return apply_rotary(query, key, cos, sin)
# @torch.jit.script
def split_heads(
fused_qkv: torch.Tensor, num_heads, num_kv_heads, head_dim
fused_qkv: torch.Tensor, num_heads:int, num_kv_heads:int, head_dim:int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch, seq_len, _ = fused_qkv.shape
qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)
@ -161,21 +168,21 @@ class OptimizedFalconAttention(FalconAttention):
def _optimized_apply_qkv(self, hidden_states):
if self.qkv_graph is None:
self.qkv_graph = torch.cuda.CUDAGraph()
self.static_input = hidden_states
self.input_surface = torch.randn_like(hidden_states)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
fused_qkv = self.query_key_value(hidden_states)
fused_qkv = self.query_key_value(self.input_surface)
self._split_heads(fused_qkv)
torch.cuda.current_stream().wait_stream(s)
with torch.cuda.graph(self.qkv_graph):
static_fused_qkv = self.query_key_value(hidden_states)
static_fused_qkv = self.query_key_value(self.input_surface)
self.static_outputs = self._split_heads(static_fused_qkv)
self.static_input.copy_(hidden_states)
self.input_surface.copy_(hidden_states)
self.qkv_graph.replay()
return tuple(o.detach() for o in self.static_outputs)
@ -191,12 +198,12 @@ class OptimizedFalconAttention(FalconAttention):
):
assert not output_attentions
if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states)
else:
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
# if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
# query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states)
# else:
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
num_kv_heads = self.num_heads
batch_size, query_length, _, _ = query_layer.shape
@ -314,6 +321,7 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
self.ln_graph = None
self.static_input = None
self.static_outputs = None
def _optimized_apply_ln(self, hidden_states):
if self.ln_graph is None:

@ -116,9 +116,10 @@ def test_falcon():
cache = unopt_cache = None
for l in range(3):
dummy_input = torch.randn(1, 1, config.hidden_size, device=device, dtype=dtype)
block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)
unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), l
assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), l
assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), l
with torch.inference_mode():
dummy_input = torch.randn(1, 1, config.hidden_size, device=device, dtype=dtype)
block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)
unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), l
assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), l
assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), l

Loading…
Cancel
Save