```diff
@@ -148,29 +148,27 @@ class OptimizedFalconAttention(FalconAttention):
         self._split_heads = partial(
             split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
         )
-        self.qkv_graph = None
+        self.split_graph = None
         self.input_surface = None
         self.static_outputs = None
 
-    def _optimized_apply_qkv(self, hidden_states):
-        if self.qkv_graph is None:
-            self.qkv_graph = torch.cuda.CUDAGraph()
-            self.input_surface = torch.randn_like(hidden_states)
+    def _optimized_split_heads(self, fused_qkv):
+        if self.split_graph is None:
+            self.split_graph = torch.cuda.CUDAGraph()
+            self.input_surface = fused_qkv
 
             s = torch.cuda.Stream()
             s.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(s):
                 for _ in range(3):
-                    fused_qkv = self.query_key_value(self.input_surface)
                     self._split_heads(fused_qkv)
             torch.cuda.current_stream().wait_stream(s)
 
-            with torch.cuda.graph(self.qkv_graph):
-                static_fused_qkv = self.query_key_value(self.input_surface)
-                self.static_outputs = self._split_heads(static_fused_qkv)
+            with torch.cuda.graph(self.split_graph):
+                self.static_outputs = self._split_heads(self.input_surface)
 
-        self.input_surface.copy_(hidden_states)
-        self.qkv_graph.replay()
+        self.input_surface.copy_(fused_qkv)
+        self.split_graph.replay()
         return tuple(o.detach() for o in self.static_outputs)
 
     def forward(
```
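For context, `_optimized_split_heads` follows the standard `torch.cuda.CUDAGraph` recipe from the PyTorch documentation: warm the workload up on a side stream, capture it once into a graph, and thereafter replay the graph after copying fresh data into the captured input buffer. Here is a minimal, self-contained sketch of that pattern, with an illustrative `Linear` layer and shapes standing in for the Falcon modules:

```python
import torch

# Stand-in for a fixed-shape module (e.g. the fused QKV projection).
layer = torch.nn.Linear(64, 192, device="cuda")
static_input = torch.randn(1, 1, 64, device="cuda")

# 1. Warm up on a side stream so one-time allocations and lazy kernel
#    initialization are not recorded into the graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        layer(static_input)
torch.cuda.current_stream().wait_stream(s)

# 2. Capture: kernels launched inside this context are recorded, not run.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = layer(static_input)

# 3. Replay: the graph always reads and writes the same memory, so new
#    inputs must be copied into the captured buffer before each replay.
static_input.copy_(torch.randn(1, 1, 64, device="cuda"))
graph.replay()
print(static_output.sum())  # static_output is overwritten in place by every replay
```

This is why the patch keeps `input_surface` and `static_outputs` as attributes: replay reuses exactly those tensors, so each call `copy_`s the new `fused_qkv` in and detaches the static outputs on the way out.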
```diff
@@ -185,15 +183,16 @@ class OptimizedFalconAttention(FalconAttention):
     ):
         assert not output_attentions
 
+        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
+
         if (
             self.new_decoder_architecture
             and hidden_states.size(1) == 1
             and torch.is_inference_mode_enabled()
             and hidden_states.device.type == "cuda"
         ):
-            query_layer, key_layer, value_layer = self._optimized_apply_qkv(hidden_states)
+            query_layer, key_layer, value_layer = self._optimized_split_heads(fused_qkv)
         else:
-            fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
             # 3 x [batch_size, seq_length, num_heads, head_dim]
             (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
 
```
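The guard also shows why the graphed path is so narrowly scoped: CUDA graph replay needs static shapes (hence only the single-token decode step, where `hidden_states.size(1) == 1`), capture should not record autograd bookkeeping (hence inference mode), and CUDA graphs exist only on CUDA devices. A sketch restating that gate in isolation; `should_replay` is a made-up helper name, not part of the patch:

```python
import torch

def should_replay(hidden_states: torch.Tensor, new_decoder_architecture: bool) -> bool:
    """Made-up helper restating the guard in forward() above."""
    return (
        new_decoder_architecture
        and hidden_states.size(1) == 1           # decode step: shapes stay static
        and torch.is_inference_mode_enabled()    # True inside `with torch.inference_mode():`
        and hidden_states.device.type == "cuda"  # CUDA graphs are CUDA-only
    )
```

Net effect of the two hunks: the QKV matmul now runs eagerly on every step, and only the head split is replayed from the captured graph.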