Disable the optimization

no_qkv_merge
Max Ryabinin 8 months ago
parent b2ab84cc33
commit a7f87b636b

@@ -114,7 +114,7 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
+class WrappedLlamaBlock(LlamaDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,

Loading…
Cancel
Save