Disable the optimization

This commit is contained in:
Max Ryabinin 2023-09-03 00:49:23 +03:00
parent b2ab84cc33
commit a7f87b636b

View File

@ -114,7 +114,7 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
class WrappedLlamaBlock(OptimizedLlamaDecoderLayer): class WrappedLlamaBlock(LlamaDecoderLayer):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,