|
|
@ -12,10 +12,11 @@ from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, Mi
|
|
|
|
|
|
|
|
|
|
|
|
class WrappedMixtralBlock(MixtralDecoderLayer):
|
|
|
|
class WrappedMixtralBlock(MixtralDecoderLayer):
|
|
|
|
def __init__(self, config: MixtralConfig, layer_idx: int):
|
|
|
|
def __init__(self, config: MixtralConfig, layer_idx: int):
|
|
|
|
|
|
|
|
super().__init__(config, layer_idx)
|
|
|
|
|
|
|
|
|
|
|
|
self._attn_implementation = config._attn_implementation
|
|
|
|
self._attn_implementation = config._attn_implementation
|
|
|
|
self.sliding_window = config.sliding_window
|
|
|
|
self.sliding_window = config.sliding_window
|
|
|
|
self.layer_idx = layer_idx
|
|
|
|
self.layer_idx = layer_idx
|
|
|
|
super().__init__(config, layer_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
def forward(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|