From d2fcbbc72e02b88cc34f2da8b3ae7de2873204a9 Mon Sep 17 00:00:00 2001 From: Artem Chumachenko Date: Fri, 29 Mar 2024 12:07:42 +0100 Subject: [PATCH] Add Mixtral models (#553) * Add somehow workable version * Fix generation * Fixes * Choose right attn * style * fix bloom * remove unnes * Update src/petals/models/mixtral/model.py Co-authored-by: Max Ryabinin * fix order of init --------- Co-authored-by: Max Ryabinin --- src/petals/models/__init__.py | 1 + src/petals/models/mixtral/__init__.py | 15 +++ src/petals/models/mixtral/block.py | 114 +++++++++++++++++ src/petals/models/mixtral/config.py | 36 ++++++ src/petals/models/mixtral/model.py | 169 ++++++++++++++++++++++++++ src/petals/server/backend.py | 2 + src/petals/server/from_pretrained.py | 9 +- 7 files changed, 344 insertions(+), 2 deletions(-) create mode 100644 src/petals/models/mixtral/__init__.py create mode 100644 src/petals/models/mixtral/block.py create mode 100644 src/petals/models/mixtral/config.py create mode 100644 src/petals/models/mixtral/model.py diff --git a/src/petals/models/__init__.py b/src/petals/models/__init__.py index f52a429..4966725 100644 --- a/src/petals/models/__init__.py +++ b/src/petals/models/__init__.py @@ -1,3 +1,4 @@ from petals.models.bloom import * from petals.models.falcon import * from petals.models.llama import * +from petals.models.mixtral import * diff --git a/src/petals/models/mixtral/__init__.py b/src/petals/models/mixtral/__init__.py new file mode 100644 index 0000000..0ad85fa --- /dev/null +++ b/src/petals/models/mixtral/__init__.py @@ -0,0 +1,15 @@ +from petals.models.mixtral.block import WrappedMixtralBlock +from petals.models.mixtral.config import DistributedMixtralConfig +from petals.models.mixtral.model import ( + DistributedMixtralForCausalLM, + DistributedMixtralForSequenceClassification, + DistributedMixtralModel, +) +from petals.utils.auto_config import register_model_classes + +register_model_classes( + config=DistributedMixtralConfig, + model=DistributedMixtralModel, + model_for_causal_lm=DistributedMixtralForCausalLM, + model_for_sequence_classification=DistributedMixtralForSequenceClassification, +) diff --git a/src/petals/models/mixtral/block.py b/src/petals/models/mixtral/block.py new file mode 100644 index 0000000..b90a39b --- /dev/null +++ b/src/petals/models/mixtral/block.py @@ -0,0 +1,114 @@ +from typing import Optional, Tuple + +import torch +from transformers import MixtralConfig +from transformers.cache_utils import DynamicCache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel + + +class WrappedMixtralBlock(MixtralDecoderLayer): + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__(config, layer_idx) + + self._attn_implementation = config._attn_implementation + self.sliding_window = config.sliding_window + self.layer_idx = layer_idx + + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + **kwargs + ): + batch_size, seq_length, _ = hidden_states.shape + + seq_length_with_past = seq_length + past_key_values_length = 0 + + past_key_value = layer_past + if past_key_value is not None: + past_key_values_length = past_key_value[0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length) + past_key_value = DynamicCache() + for idx in range(self.layer_idx): + past_key_value.update( + torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx + ) + past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx) + + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa": + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + sliding_window=self.sliding_window, + ) + + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=hidden_states.device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + + outputs = super().forward( + hidden_states, + *args, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + **kwargs + ) + + if use_cache: + present_key_value = outputs[-1] + present_key_value = present_key_value.to_legacy_cache()[self.layer_idx] + present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past) + outputs = outputs[:-1] + (present_key_value,) + + return outputs + + def _reorder_cache_from_bloom( + self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int + ) -> Tuple[torch.Tensor]: + # TODO: Move to mixin + key_states, value_states = key_value + key_states = key_states.permute(0, 2, 1) + key_states = key_states.view( + batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim + ) + value_states = value_states.view(*key_states.shape) + return (key_states, value_states) + + def _reorder_cache_to_bloom( + self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int + ) -> Tuple[torch.Tensor]: + # TODO: Move to mixin + key_states, value_states = key_value + value_states = value_states.view( + batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim + ) + key_states = key_states.view(*value_states.shape) + key_states = key_states.permute(0, 2, 1) + return (key_states, value_states) diff --git a/src/petals/models/mixtral/config.py b/src/petals/models/mixtral/config.py new file mode 100644 index 0000000..a93c8df --- /dev/null +++ b/src/petals/models/mixtral/config.py @@ -0,0 +1,36 @@ +import os +from typing import Optional, Union + +from hivemind import get_logger +from transformers.models.mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralAttention + +from petals.client.config import ClientConfig +from petals.client.lm_head import LMHeadConfig +from petals.client.ptune import PTuneConfig +from petals.models.mixtral.block import WrappedMixtralBlock + +logger = get_logger(__name__) + + +class DistributedMixtralConfig(MixtralConfig, ClientConfig, PTuneConfig, LMHeadConfig): + block_class = WrappedMixtralBlock + attn_class = MixtralAttention + block_prefix = "model.layers" + + num_key_value_groups = 1 + + @classmethod + def from_pretrained( + cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs + ): + loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) + if loading_from_repo and dht_prefix is None: + dht_prefix = str(model_name_or_path) + dht_prefix = dht_prefix.replace(".", "-") + logger.info(f"Using DHT prefix: {dht_prefix}") + result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) + config = result[0] if isinstance(result, tuple) else result + if config.pad_token_id is None: + config.pad_token_id = 0 + return result diff --git a/src/petals/models/mixtral/model.py b/src/petals/models/mixtral/model.py new file mode 100644 index 0000000..7e127ab --- /dev/null +++ b/src/petals/models/mixtral/model.py @@ -0,0 +1,169 @@ +from typing import Optional + +import torch +import torch.nn as nn +from hivemind import DHT +from hivemind.utils.logging import get_logger +from transformers.modeling_outputs import MoeModelOutputWithPast +from transformers.models.mixtral import ( + MixtralForCausalLM, + MixtralForSequenceClassification, + MixtralModel, + MixtralPreTrainedModel, +) + +from petals.client.from_pretrained import FromPretrainedMixin +from petals.client.lm_head import LMHead +from petals.client.ptune import PTuneMixin +from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues +from petals.client.remote_sequential import RemoteSequential +from petals.models.mixtral.config import DistributedMixtralConfig +from petals.utils.auto_config import DefaultRevisionMixin + +logger = get_logger(__name__) + + +class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, MixtralModel): + """MixtralModel, but all transformer layers are hosted by the swarm""" + + _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = [r"^model\.layers\."] + + config_class = DistributedMixtralConfig + + def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None): + n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization + super().__init__(config) + assert len(self.layers) == 0 + config.num_hidden_layers = n_layer + + self.layers = RemoteSequential(config, dht=dht) + + self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm + self.init_prompts(config) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[RemotePastKeyValues] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # The causal mask will be added on the server-side + assert ( + attention_mask is None or (attention_mask == 1).all() + ), f"Custom attention masks are not supported, {attention_mask=}" + assert ( + position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all() + ), f"Non-consecutive position_ids are not supported, {position_ids=}" + assert head_mask is None, f"Custom head masks are not supported, {head_mask=}" + assert use_cache is None or use_cache, f"{use_cache=} is not supported" + assert not output_attentions, f"{output_attentions=} is not supported" + assert not output_hidden_states, f"{output_hidden_states=} is not supported" + assert return_dict is None or return_dict, f"{return_dict=} is not supported" + assert not output_router_logits, f"{output_router_logits=} is not supported" + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0 + if use_prompts: + batch_size = inputs_embeds.shape[0] + prompts, intermediate_prompts = self.get_prompt(batch_size) + inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) + else: + prompts = intermediate_prompts = None + + hidden_states = inputs_embeds + output_shape = input_shape + (hidden_states.size(-1),) + + if past_key_values is None: + past_key_values = RemotePastKeyValues() + past_key_values.update_seen(hidden_states.size(1)) + + hidden_states = self.layers( + hidden_states, + prompts=intermediate_prompts, + hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None, + ) + + # Remove prefix + if use_prompts: + hidden_states = hidden_states[:, self.pre_seq_len :] + + # Add last hidden state + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.view(output_shape) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=None, + attentions=None, + ) + + @property + def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin + return self.embed_tokens + + @property + def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin + return self.layers + + +class DistributedMixtralForCausalLM( + DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM +): + _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedMixtralConfig + + def __init__(self, config: DistributedMixtralConfig): + MixtralPreTrainedModel.__init__(self, config) + self.model = DistributedMixtralModel(config) + self.lm_head = LMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + @property + def transformer(self) -> DistributedMixtralModel: # For compatibility with RemoteGenerationMixin + return self.model + + +class DistributedMixtralForSequenceClassification( + DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification +): + def __init__(self, config: DistributedMixtralConfig): + MixtralPreTrainedModel.__init__(self, config) + self.num_labels = config.num_labels + + self.model = DistributedMixtralModel(config) + self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @property + def transformer(self) -> DistributedMixtralModel: # For compatibility with RemoteGenerationMixin + return self.model diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 3a9b63e..fe1c474 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -91,6 +91,8 @@ class TransformerBackend(ModuleBackend): cache_tensors = [] for device, num_heads in zip(self.module.devices, self.shard_num_heads): num_heads //= self.config.num_key_value_groups + if hasattr(self.config, "num_key_value_heads"): + num_heads = self.config.num_key_value_heads keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device) values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device) cache_tensors.extend((keys, values)) diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index 73956fe..95cfbd8 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -19,10 +19,11 @@ from accelerate.utils import set_module_tensor_to_device from hivemind.utils.logging import get_logger from huggingface_hub import get_hf_file_metadata, hf_hub_url from huggingface_hub.utils import EntryNotFoundError -from transformers import PretrainedConfig +from transformers import PretrainedConfig, PreTrainedModel from transformers.utils import get_file_from_repo from petals.constants import DTYPE_MAP +from petals.models.mixtral import WrappedMixtralBlock from petals.server.block_utils import resolve_block_dtype from petals.utils.auto_config import AutoDistributedConfig from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for @@ -51,7 +52,11 @@ def load_pretrained_block( torch_dtype = resolve_block_dtype(config, torch_dtype) with init_empty_weights(): - block = config.block_class(config) + if config.block_class == WrappedMixtralBlock: + config = PreTrainedModel._autoset_attn_implementation(config) + block = config.block_class(config, block_index) + else: + block = config.block_class(config) block_prefix = f"{config.block_prefix}.{block_index}." state_dict = _load_state_dict_from_repo(