|
|
|
@ -20,13 +20,19 @@ from transformers.models.falcon.modeling_falcon import (
|
|
|
|
|
LayerNorm,
|
|
|
|
|
build_alibi_tensor,
|
|
|
|
|
dropout_add,
|
|
|
|
|
rotate_half,
|
|
|
|
|
FalconRotaryEmbedding,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
|
|
|
|
INFERENCE_MAX_LENGTH = 8192
|
|
|
|
|
|
|
|
|
|
# @torch.jit.script
|
|
|
|
|
def rotate_half(x):
|
|
|
|
|
x1, x2 = torch.chunk(x, 2, dim=2)
|
|
|
|
|
return torch.cat((-x2, x1), dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# @torch.jit.script
|
|
|
|
|
def apply_rotary(query, key, cos, sin):
|
|
|
|
|
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
|
|
|
|
|
|
|
|
|
@ -97,14 +103,15 @@ class OptimizedFalconRotaryEmbedding(nn.Module):
|
|
|
|
|
def forward(self, query, key, past_key_values_length=0):
|
|
|
|
|
batch, seq_len, head_dim = query.shape
|
|
|
|
|
cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
|
|
|
|
|
if seq_len == 1 and torch.is_inference_mode_enabled():
|
|
|
|
|
return self._optimized_apply_rotary(query, key, cos, sin)
|
|
|
|
|
else:
|
|
|
|
|
return apply_rotary(query, key, cos, sin)
|
|
|
|
|
|
|
|
|
|
# print(cos, sin)
|
|
|
|
|
# if seq_len == 1 and torch.is_inference_mode_enabled():
|
|
|
|
|
# return self._optimized_apply_rotary(query, key, cos, sin)
|
|
|
|
|
# else:
|
|
|
|
|
return apply_rotary(query, key, cos, sin)
|
|
|
|
|
|
|
|
|
|
# @torch.jit.script
|
|
|
|
|
def split_heads(
|
|
|
|
|
fused_qkv: torch.Tensor, num_heads, num_kv_heads, head_dim
|
|
|
|
|
fused_qkv: torch.Tensor, num_heads:int, num_kv_heads:int, head_dim:int
|
|
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
|
|
|
batch, seq_len, _ = fused_qkv.shape
|
|
|
|
|
qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)
|
|
|
|
@ -161,21 +168,21 @@ class OptimizedFalconAttention(FalconAttention):
|
|
|
|
|
def _optimized_apply_qkv(self, hidden_states):
|
|
|
|
|
if self.qkv_graph is None:
|
|
|
|
|
self.qkv_graph = torch.cuda.CUDAGraph()
|
|
|
|
|
self.static_input = hidden_states
|
|
|
|
|
self.input_surface = torch.randn_like(hidden_states)
|
|
|
|
|
|
|
|
|
|
s = torch.cuda.Stream()
|
|
|
|
|
s.wait_stream(torch.cuda.current_stream())
|
|
|
|
|
with torch.cuda.stream(s):
|
|
|
|
|
for _ in range(3):
|
|
|
|
|
fused_qkv = self.query_key_value(hidden_states)
|
|
|
|
|
fused_qkv = self.query_key_value(self.input_surface)
|
|
|
|
|
self._split_heads(fused_qkv)
|
|
|
|
|
torch.cuda.current_stream().wait_stream(s)
|
|
|
|
|
|
|
|
|
|
with torch.cuda.graph(self.qkv_graph):
|
|
|
|
|
static_fused_qkv = self.query_key_value(hidden_states)
|
|
|
|
|
static_fused_qkv = self.query_key_value(self.input_surface)
|
|
|
|
|
self.static_outputs = self._split_heads(static_fused_qkv)
|
|
|
|
|
|
|
|
|
|
self.static_input.copy_(hidden_states)
|
|
|
|
|
self.input_surface.copy_(hidden_states)
|
|
|
|
|
self.qkv_graph.replay()
|
|
|
|
|
return tuple(o.detach() for o in self.static_outputs)
|
|
|
|
|
|
|
|
|
@ -191,12 +198,12 @@ class OptimizedFalconAttention(FalconAttention):
|
|
|
|
|
):
|
|
|
|
|
assert not output_attentions
|
|
|
|
|
|
|
|
|
|
if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
|
|
|
|
|
query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states)
|
|
|
|
|
else:
|
|
|
|
|
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
|
|
|
|
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
|
|
|
|
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
|
|
|
|
# if self.new_decoder_architecture and hidden_states.size(1) == 1 and torch.is_inference_mode_enabled():
|
|
|
|
|
# query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states)
|
|
|
|
|
# else:
|
|
|
|
|
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
|
|
|
|
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
|
|
|
|
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
|
|
|
|
|
|
|
|
|
num_kv_heads = self.num_heads
|
|
|
|
|
batch_size, query_length, _, _ = query_layer.shape
|
|
|
|
@ -314,6 +321,7 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
|
|
|
|
|
|
|
|
|
|
self.ln_graph = None
|
|
|
|
|
self.static_input = None
|
|
|
|
|
self.static_outputs = None
|
|
|
|
|
|
|
|
|
|
def _optimized_apply_ln(self, hidden_states):
|
|
|
|
|
if self.ln_graph is None:
|
|
|
|
|