@@ -114,7 +114,7 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
         self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
-class WrappedLlamaBlock(LlamaDecoderLayer):
+class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
     def forward(
         self,
         hidden_states: torch.Tensor,