Mirror of https://github.com/bigscience-workshop/petals (synced 2024-10-31 09:20:41 +00:00)
remove unnes
This commit is contained in:
parent 81a5e70c89
commit 866927d88c
@@ -1,8 +1,8 @@
 from typing import Optional
 
-import hivemind
 import torch
 import torch.nn as nn
+from hivemind import DHT
 from hivemind.utils.logging import get_logger
 from transformers.modeling_outputs import MoeModelOutputWithPast
 from transformers.models.mixtral import (
@@ -31,7 +31,7 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
 
     config_class = DistributedMixtralConfig
 
-    def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[hivemind.DHT] = None):
+    def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None):
         n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
         super().__init__(config)
         assert len(self.layers) == 0
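The only functional change in this hunk is the type hint: the keyword-only `dht` argument is now annotated with the directly imported `DHT` class instead of `hivemind.DHT`, matching the import change above. As a rough illustration of how that argument is used, here is a minimal sketch of constructing the model with an explicit DHT instance; the import path and checkpoint name are assumptions, not taken from this diff.

# Minimal sketch around the signature touched by this hunk.
# Assumptions: the import path and the checkpoint name below.
from hivemind import DHT

from petals.models.mixtral import DistributedMixtralConfig, DistributedMixtralModel  # path assumed

config = DistributedMixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")  # name assumed

# A lightweight client-mode DHT node; bootstrap peers are omitted here for brevity.
dht = DHT(start=True, client_mode=True)

# Matches the updated __init__ signature: dht is keyword-only and typed as Optional[DHT].
model = DistributedMixtralModel(config, dht=dht)

Note that constructing the model directly from a config like this leaves the local weights uninitialized; a real client would normally go through from_pretrained instead.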
@@ -122,18 +122,10 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
        return self.embed_tokens

    @property
    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return nn.Identity()

    @property
    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
        return self.layers

    @property
    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return self.norm


class DistributedMixtralForCausalLM(
    DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
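For context, these properties alias Mixtral's own attribute names (embed_tokens, layers, norm) to the generic names that RemoteGenerationMixin expects (word_embeddings, word_embeddings_layernorm, h, ln_f). The sketch below only illustrates that mapping; it is not code from the Petals source, and the helper function and its name are made up.

# Illustrative only: generic code written against the aliased names works for
# any Distributed* model, regardless of what the backbone calls its attributes.
import torch

def forward_through_aliases(model, input_ids: torch.LongTensor) -> torch.Tensor:
    hidden = model.word_embeddings(input_ids)         # Mixtral: self.embed_tokens
    hidden = model.word_embeddings_layernorm(hidden)  # Mixtral: nn.Identity() (no pre-block layernorm)
    hidden = model.h(hidden)                          # Mixtral: self.layers (a RemoteSequential)
    return model.ln_f(hidden)                         # Mixtral: self.norm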
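The class being defined at the end of the last hunk, DistributedMixtralForCausalLM, is the one a client would normally load. Below is a hedged usage sketch assuming the usual Petals from_pretrained/generate flow; the checkpoint name and its availability in a public swarm are assumptions.

# Hedged usage sketch (checkpoint name and swarm availability are assumptions).
import torch
from transformers import AutoTokenizer

from petals import AutoDistributedModelForCausalLM  # dispatches to DistributedMixtralForCausalLM for Mixtral configs

model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # assumed
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

inputs = tokenizer("A distributed Mixtral can", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0]))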