Add Mixtral models (#553)
* Add somehow workable version * Fix generation * Fixes * Choose right attn * style * fix bloom * remove unnes * Update src/petals/models/mixtral/model.py Co-authored-by: Max Ryabinin <mryabinin0@gmail.com> * fix order of init --------- Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>pull/557/merge
parent
2ad0b2b936
commit
d2fcbbc72e
@ -1,3 +1,4 @@
|
|||||||
from petals.models.bloom import *
|
from petals.models.bloom import *
|
||||||
from petals.models.falcon import *
|
from petals.models.falcon import *
|
||||||
from petals.models.llama import *
|
from petals.models.llama import *
|
||||||
|
from petals.models.mixtral import *
|
||||||
|
@ -0,0 +1,15 @@
|
|||||||
|
from petals.models.mixtral.block import WrappedMixtralBlock
|
||||||
|
from petals.models.mixtral.config import DistributedMixtralConfig
|
||||||
|
from petals.models.mixtral.model import (
|
||||||
|
DistributedMixtralForCausalLM,
|
||||||
|
DistributedMixtralForSequenceClassification,
|
||||||
|
DistributedMixtralModel,
|
||||||
|
)
|
||||||
|
from petals.utils.auto_config import register_model_classes
|
||||||
|
|
||||||
|
register_model_classes(
|
||||||
|
config=DistributedMixtralConfig,
|
||||||
|
model=DistributedMixtralModel,
|
||||||
|
model_for_causal_lm=DistributedMixtralForCausalLM,
|
||||||
|
model_for_sequence_classification=DistributedMixtralForSequenceClassification,
|
||||||
|
)
|
@ -0,0 +1,114 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import MixtralConfig
|
||||||
|
from transformers.cache_utils import DynamicCache
|
||||||
|
from transformers.modeling_attn_mask_utils import (
|
||||||
|
_prepare_4d_causal_attention_mask,
|
||||||
|
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||||
|
)
|
||||||
|
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedMixtralBlock(MixtralDecoderLayer):
|
||||||
|
def __init__(self, config: MixtralConfig, layer_idx: int):
|
||||||
|
super().__init__(config, layer_idx)
|
||||||
|
|
||||||
|
self._attn_implementation = config._attn_implementation
|
||||||
|
self.sliding_window = config.sliding_window
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
*args,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
batch_size, seq_length, _ = hidden_states.shape
|
||||||
|
|
||||||
|
seq_length_with_past = seq_length
|
||||||
|
past_key_values_length = 0
|
||||||
|
|
||||||
|
past_key_value = layer_past
|
||||||
|
if past_key_value is not None:
|
||||||
|
past_key_values_length = past_key_value[0].shape[2]
|
||||||
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
_past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
|
||||||
|
past_key_value = DynamicCache()
|
||||||
|
for idx in range(self.layer_idx):
|
||||||
|
past_key_value.update(
|
||||||
|
torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx
|
||||||
|
)
|
||||||
|
past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)
|
||||||
|
|
||||||
|
if self._attn_implementation == "flash_attention_2":
|
||||||
|
# 2d mask is passed through the layers
|
||||||
|
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||||
|
elif self._attn_implementation == "sdpa":
|
||||||
|
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||||
|
# the manual implementation that requires a 4D causal mask in all cases.
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
hidden_states,
|
||||||
|
past_key_values_length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 4d mask is passed through the layers
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
hidden_states,
|
||||||
|
past_key_values_length,
|
||||||
|
sliding_window=self.sliding_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=hidden_states.device
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
|
||||||
|
outputs = super().forward(
|
||||||
|
hidden_states,
|
||||||
|
*args,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
use_cache=use_cache,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
present_key_value = outputs[-1]
|
||||||
|
present_key_value = present_key_value.to_legacy_cache()[self.layer_idx]
|
||||||
|
present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
|
||||||
|
outputs = outputs[:-1] + (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def _reorder_cache_from_bloom(
|
||||||
|
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
|
||||||
|
) -> Tuple[torch.Tensor]:
|
||||||
|
# TODO: Move to mixin
|
||||||
|
key_states, value_states = key_value
|
||||||
|
key_states = key_states.permute(0, 2, 1)
|
||||||
|
key_states = key_states.view(
|
||||||
|
batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
|
||||||
|
)
|
||||||
|
value_states = value_states.view(*key_states.shape)
|
||||||
|
return (key_states, value_states)
|
||||||
|
|
||||||
|
def _reorder_cache_to_bloom(
|
||||||
|
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
|
||||||
|
) -> Tuple[torch.Tensor]:
|
||||||
|
# TODO: Move to mixin
|
||||||
|
key_states, value_states = key_value
|
||||||
|
value_states = value_states.view(
|
||||||
|
batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
|
||||||
|
)
|
||||||
|
key_states = key_states.view(*value_states.shape)
|
||||||
|
key_states = key_states.permute(0, 2, 1)
|
||||||
|
return (key_states, value_states)
|
@ -0,0 +1,36 @@
|
|||||||
|
import os
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from hivemind import get_logger
|
||||||
|
from transformers.models.mixtral import MixtralConfig
|
||||||
|
from transformers.models.mixtral.modeling_mixtral import MixtralAttention
|
||||||
|
|
||||||
|
from petals.client.config import ClientConfig
|
||||||
|
from petals.client.lm_head import LMHeadConfig
|
||||||
|
from petals.client.ptune import PTuneConfig
|
||||||
|
from petals.models.mixtral.block import WrappedMixtralBlock
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedMixtralConfig(MixtralConfig, ClientConfig, PTuneConfig, LMHeadConfig):
|
||||||
|
block_class = WrappedMixtralBlock
|
||||||
|
attn_class = MixtralAttention
|
||||||
|
block_prefix = "model.layers"
|
||||||
|
|
||||||
|
num_key_value_groups = 1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
|
||||||
|
):
|
||||||
|
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
|
||||||
|
if loading_from_repo and dht_prefix is None:
|
||||||
|
dht_prefix = str(model_name_or_path)
|
||||||
|
dht_prefix = dht_prefix.replace(".", "-")
|
||||||
|
logger.info(f"Using DHT prefix: {dht_prefix}")
|
||||||
|
result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
|
||||||
|
config = result[0] if isinstance(result, tuple) else result
|
||||||
|
if config.pad_token_id is None:
|
||||||
|
config.pad_token_id = 0
|
||||||
|
return result
|
@ -0,0 +1,169 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from hivemind import DHT
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
from transformers.modeling_outputs import MoeModelOutputWithPast
|
||||||
|
from transformers.models.mixtral import (
|
||||||
|
MixtralForCausalLM,
|
||||||
|
MixtralForSequenceClassification,
|
||||||
|
MixtralModel,
|
||||||
|
MixtralPreTrainedModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
from petals.client.from_pretrained import FromPretrainedMixin
|
||||||
|
from petals.client.lm_head import LMHead
|
||||||
|
from petals.client.ptune import PTuneMixin
|
||||||
|
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
|
||||||
|
from petals.client.remote_sequential import RemoteSequential
|
||||||
|
from petals.models.mixtral.config import DistributedMixtralConfig
|
||||||
|
from petals.utils.auto_config import DefaultRevisionMixin
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, MixtralModel):
|
||||||
|
"""MixtralModel, but all transformer layers are hosted by the swarm"""
|
||||||
|
|
||||||
|
_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
|
||||||
|
_keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]
|
||||||
|
|
||||||
|
config_class = DistributedMixtralConfig
|
||||||
|
|
||||||
|
def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None):
|
||||||
|
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
|
||||||
|
super().__init__(config)
|
||||||
|
assert len(self.layers) == 0
|
||||||
|
config.num_hidden_layers = n_layer
|
||||||
|
|
||||||
|
self.layers = RemoteSequential(config, dht=dht)
|
||||||
|
|
||||||
|
self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm
|
||||||
|
self.init_prompts(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[RemotePastKeyValues] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
head_mask: Optional[torch.LongTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
output_router_logits: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
# The causal mask will be added on the server-side
|
||||||
|
assert (
|
||||||
|
attention_mask is None or (attention_mask == 1).all()
|
||||||
|
), f"Custom attention masks are not supported, {attention_mask=}"
|
||||||
|
assert (
|
||||||
|
position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
|
||||||
|
), f"Non-consecutive position_ids are not supported, {position_ids=}"
|
||||||
|
assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
|
||||||
|
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
|
||||||
|
assert not output_attentions, f"{output_attentions=} is not supported"
|
||||||
|
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
|
||||||
|
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
|
||||||
|
assert not output_router_logits, f"{output_router_logits=} is not supported"
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
|
||||||
|
if use_prompts:
|
||||||
|
batch_size = inputs_embeds.shape[0]
|
||||||
|
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
||||||
|
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
||||||
|
else:
|
||||||
|
prompts = intermediate_prompts = None
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
||||||
|
if past_key_values is None:
|
||||||
|
past_key_values = RemotePastKeyValues()
|
||||||
|
past_key_values.update_seen(hidden_states.size(1))
|
||||||
|
|
||||||
|
hidden_states = self.layers(
|
||||||
|
hidden_states,
|
||||||
|
prompts=intermediate_prompts,
|
||||||
|
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove prefix
|
||||||
|
if use_prompts:
|
||||||
|
hidden_states = hidden_states[:, self.pre_seq_len :]
|
||||||
|
|
||||||
|
# Add last hidden state
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
hidden_states = hidden_states.view(output_shape)
|
||||||
|
return MoeModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
hidden_states=None,
|
||||||
|
attentions=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin
|
||||||
|
return self.layers
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedMixtralForCausalLM(
|
||||||
|
DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
|
||||||
|
):
|
||||||
|
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
|
||||||
|
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
|
||||||
|
|
||||||
|
config_class = DistributedMixtralConfig
|
||||||
|
|
||||||
|
def __init__(self, config: DistributedMixtralConfig):
|
||||||
|
MixtralPreTrainedModel.__init__(self, config)
|
||||||
|
self.model = DistributedMixtralModel(config)
|
||||||
|
self.lm_head = LMHead(config)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
@property
|
||||||
|
def transformer(self) -> DistributedMixtralModel: # For compatibility with RemoteGenerationMixin
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedMixtralForSequenceClassification(
|
||||||
|
DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
|
||||||
|
):
|
||||||
|
def __init__(self, config: DistributedMixtralConfig):
|
||||||
|
MixtralPreTrainedModel.__init__(self, config)
|
||||||
|
self.num_labels = config.num_labels
|
||||||
|
|
||||||
|
self.model = DistributedMixtralModel(config)
|
||||||
|
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def transformer(self) -> DistributedMixtralModel: # For compatibility with RemoteGenerationMixin
|
||||||
|
return self.model
|
Loading…
Reference in New Issue