@@ -1,8 +1,8 @@
 from typing import Optional
 
-import hivemind
 import torch
 import torch.nn as nn
+from hivemind import DHT
 from hivemind.utils.logging import get_logger
 from transformers.modeling_outputs import MoeModelOutputWithPast
 from transformers.models.mixtral import (
@@ -31,7 +31,7 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
 
     config_class = DistributedMixtralConfig
 
-    def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[hivemind.DHT] = None):
+    def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None):
         n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
         super().__init__(config)
         assert len(self.layers) == 0
@@ -122,18 +122,10 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
     def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
         return self.embed_tokens
 
-    @property
-    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
-        return nn.Identity()
-
     @property
     def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
         return self.layers
 
-    @property
-    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
-        return self.norm
-
 
 class DistributedMixtralForCausalLM(
     DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM