remove unnes

pull/553/head
Artem Chumachenko 3 months ago
parent 81a5e70c89
commit 866927d88c

@@ -1,8 +1,8 @@
from typing import Optional
import hivemind
import torch
import torch.nn as nn
from hivemind import DHT
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral import (
@@ -31,7 +31,7 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
    config_class = DistributedMixtralConfig

    def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[hivemind.DHT] = None):
    def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None):
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.layers) == 0
@@ -122,18 +122,10 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
        return self.embed_tokens

    @property
    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return nn.Identity()

    @property
    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
        return self.layers

    @property
    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return self.norm


class DistributedMixtralForCausalLM(
    DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
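
Not part of the commit: a minimal usage sketch of the keyword-only dht parameter touched by this change. The checkpoint name, the initial_peers attribute on the config, and the import paths are illustrative assumptions, not confirmed by this diff.

import hivemind
from petals.models.mixtral import DistributedMixtralConfig, DistributedMixtralModel

# Illustrative checkpoint name; in practice weights are loaded via from_pretrained.
config = DistributedMixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

# Build a client-side DHT and pass it explicitly; dht is keyword-only and defaults to None.
dht = hivemind.DHT(initial_peers=config.initial_peers, start=True, client_mode=True)
model = DistributedMixtralModel(config, dht=dht)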
