@@ -144,22 +144,12 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
             )
         return self.pre_attn_graph(hidden_states)

-    def _post_attn(self, residual, hidden_states):
-        hidden_states = residual + hidden_states
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
-
-    def _optimized_post_attn(self, residual, hidden_states):
+    def _optimized_output_layernorm(self, hidden_states):
         if self.post_attn_graph is None:
             self.post_attn_graph = make_inference_graphed_callable(
-                self._post_attn, sample_args=(residual, hidden_states)
+                self.post_attention_layernorm.forward, sample_args=(hidden_states,)
             )
-        return self.post_attn_graph(residual, hidden_states)
+        return self.post_attn_graph(hidden_states)

     def forward(
         self,
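
Note on the hunk above: the old code captured the entire post-attention path (residual add, `post_attention_layernorm`, MLP, second residual add) in a single CUDA graph via `_post_attn` / `_optimized_post_attn`; the new `_optimized_output_layernorm` captures only the layernorm and leaves the rest in eager mode. As a rough mental model of the helper, the sketch below shows how a CUDA-graph wrapper of this kind is typically built with PyTorch's public `torch.cuda.CUDAGraph` API. The name `make_graphed_inference_callable` and the warm-up structure are assumptions for illustration, not the actual `make_inference_graphed_callable` implementation.

    import torch

    def make_graphed_inference_callable(fn, sample_args):
        # Hypothetical stand-in for make_inference_graphed_callable:
        # capture fn(*sample_args) into a CUDA graph, return a replay wrapper.
        static_args = [a.clone() for a in sample_args]

        # Warm up on a side stream so capture sees initialized kernels
        # and a steady allocator state.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                fn(*static_args)
        torch.cuda.current_stream().wait_stream(s)

        # Record the kernel sequence once...
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_output = fn(*static_args)

        def replay(*args):
            # ...then on each call copy live inputs into the captured
            # buffers and replay; shapes must match the capture exactly.
            for static_a, a in zip(static_args, args):
                static_a.copy_(a)
            graph.replay()
            return static_output

        return replay
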
@@ -201,10 +191,18 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
             use_cache=use_cache,
         )

+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+
         if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
-            hidden_states = self._optimized_post_attn(residual, hidden_states)
+            hidden_states = self._optimized_output_layernorm(hidden_states)
         else:
-            hidden_states = self._post_attn(residual, hidden_states)
+            hidden_states = self.post_attention_layernorm(hidden_states)
+
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states

         outputs = (hidden_states,)

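
The dispatch in this hunk makes the graphed path opt-in per step: a CUDA graph replays a fixed kernel sequence on fixed shapes, so only single-token decoding (`hidden_states.size(1) == 1`) under inference mode on a CUDA device qualifies, while prefill, CPU execution, and anything with autograd enabled fall back to the eager branch. Restated as a predicate (the helper name `_can_use_cuda_graph` is hypothetical):

    import torch

    def _can_use_cuda_graph(hidden_states: torch.Tensor) -> bool:
        # Mirrors the inline condition in the hunk above.
        return (
            hidden_states.size(1) == 1               # one decode token: static shape
            and torch.is_inference_mode_enabled()    # no autograd bookkeeping
            and hidden_states.device.type == "cuda"  # CUDA graphs are CUDA-only
        )
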
@@ -239,6 +237,8 @@ class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
             seq_length_with_past = seq_length_with_past + past_key_values_length
             past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)

+        assert position_ids is None
+
         # embed positions
         if attention_mask is None:
             attention_mask = torch.ones(
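
For context on the surrounding lines: `_reorder_cache_from_bloom_to_llama` converts the Bloom-style KV cache layout used by the wrapper into the `(batch, num_heads, seq_len, head_dim)` layout that Llama attention expects, and the new `assert position_ids is None` makes explicit that callers must let the block derive positions from the past length rather than passing them in. A hedged sketch of such a reorder, assuming Bloom-style keys of shape `(batch * num_heads, head_dim, seq_len)` and values of shape `(batch * num_heads, seq_len, head_dim)`:

    def reorder_cache_from_bloom_to_llama(key_value, batch_size, seq_length):
        # Sketch under the layout assumptions stated above,
        # not the verbatim method from the diff.
        key_states, value_states = key_value
        key_states = key_states.permute(0, 2, 1)  # -> (batch*heads, seq_len, head_dim)
        key_states = key_states.reshape(batch_size, -1, seq_length, key_states.shape[-1])
        value_states = value_states.view(*key_states.shape)
        return (key_states, value_states)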