Disable the optimization

no_qkv_merge
Max Ryabinin 9 months ago
parent b2ab84cc33
commit a7f87b636b

@@ -114,7 +114,7 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
         self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
+class WrappedLlamaBlock(LlamaDecoderLayer):
     def forward(
         self,
         hidden_states: torch.Tensor,

Loading…
Cancel
Save