From d6f4f80f3f524dd7d961f409ed3d5b0c8bc957e9 Mon Sep 17 00:00:00 2001 From: Artem Chumachenko Date: Wed, 10 Apr 2024 13:49:50 +0200 Subject: [PATCH] Fix Mixtral-related issues (#570) This PR fixes problems related to #569: - block initialization - throughput calculation and cache usage - mixtral in tests Beam search is removed for Mixtral and Llama for now. Those models use DynamicCache, which requires special function to change: (see https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py#L161) --------- Co-authored-by: Max Ryabinin --- .github/workflows/run-tests.yaml | 2 ++ src/petals/client/remote_generation.py | 2 +- src/petals/models/bloom/block.py | 6 ++++++ src/petals/models/mixtral/block.py | 12 ++++++------ src/petals/models/mixtral/model.py | 21 +++++++++++++++------ src/petals/server/block_utils.py | 17 +++++++++++++++-- src/petals/server/from_pretrained.py | 8 ++------ src/petals/server/throughput.py | 18 +++++++++++++----- src/petals/utils/misc.py | 2 ++ src/petals/utils/peft.py | 4 ++-- tests/test_chained_calls.py | 9 ++++++--- tests/test_full_model.py | 4 ++++ tests/test_optimized_layers.py | 8 +++++--- 13 files changed, 79 insertions(+), 34 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 1e76235..d6316d4 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -16,6 +16,8 @@ jobs: - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' } - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' } - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' } + - { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.8' } + - { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.11' } fail-fast: false runs-on: ${{ matrix.os }}-latest timeout-minutes: 20 diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py index 0249c0b..0060ede 100644 --- a/src/petals/client/remote_generation.py +++ b/src/petals/client/remote_generation.py @@ -38,7 +38,7 @@ class RemotePastKeyValues(Cache): self.seen_tokens += new_seen def reorder_cache(self, beam_idx): - pass + raise NotImplementedError("Beam search reordering is not implemented yet") _skipped_tokens = ContextVar("skipped_tokens", default=0) diff --git a/src/petals/models/bloom/block.py b/src/petals/models/bloom/block.py index 86fc4aa..439b9ca 100644 --- a/src/petals/models/bloom/block.py +++ b/src/petals/models/bloom/block.py @@ -9,6 +9,8 @@ import torch from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor +from petals.utils.misc import is_dummy + class WrappedBloomBlock(BloomBlock): def forward( @@ -22,6 +24,10 @@ class WrappedBloomBlock(BloomBlock): ): assert attention_mask is None, "Non-causal attention masks are not supported yet" batch_size, seq_length = hidden_states.shape[:2] + if layer_past is not None and is_dummy(layer_past[0]): + # Bloom cannot use cache if it was misconsctructed(e.g. Dummy tensors) + # In this case, fallback to the old code: + layer_past = None past_length = 0 if layer_past is None else layer_past[0].shape[-1] seq_length_with_past = seq_length + past_length attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) diff --git a/src/petals/models/mixtral/block.py b/src/petals/models/mixtral/block.py index b90a39b..7a2bd9f 100644 --- a/src/petals/models/mixtral/block.py +++ b/src/petals/models/mixtral/block.py @@ -1,3 +1,4 @@ +import json from typing import Optional, Tuple import torch @@ -33,16 +34,15 @@ class WrappedMixtralBlock(MixtralDecoderLayer): past_key_values_length = 0 past_key_value = layer_past + if past_key_value is not None: past_key_values_length = past_key_value[0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length) past_key_value = DynamicCache() - for idx in range(self.layer_idx): - past_key_value.update( - torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx - ) - past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx) + past_key_value.key_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[0]] + past_key_value.value_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[1]] + past_key_value._seen_tokens = past_key_values_length if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers @@ -83,7 +83,7 @@ class WrappedMixtralBlock(MixtralDecoderLayer): if use_cache: present_key_value = outputs[-1] - present_key_value = present_key_value.to_legacy_cache()[self.layer_idx] + present_key_value = present_key_value[self.layer_idx] present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past) outputs = outputs[:-1] + (present_key_value,) diff --git a/src/petals/models/mixtral/model.py b/src/petals/models/mixtral/model.py index 7e127ab..dfbb6c2 100644 --- a/src/petals/models/mixtral/model.py +++ b/src/petals/models/mixtral/model.py @@ -122,14 +122,20 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin return self.embed_tokens + @property + def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests + return nn.Identity() + @property def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin return self.layers + @property + def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests + return self.norm -class DistributedMixtralForCausalLM( - DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM -): + +class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM): _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected @@ -151,9 +157,12 @@ class DistributedMixtralForCausalLM( return self.model -class DistributedMixtralForSequenceClassification( - DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification -): +class DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification): + _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedMixtralConfig + def __init__(self, config: DistributedMixtralConfig): MixtralPreTrainedModel.__init__(self, config) self.num_labels = config.num_labels diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py index ac0995d..63ba687 100644 --- a/src/petals/server/block_utils.py +++ b/src/petals/server/block_utils.py @@ -2,8 +2,9 @@ from typing import Optional, Union import torch from accelerate import init_empty_weights -from transformers import PretrainedConfig +from transformers import PretrainedConfig, PreTrainedModel +from petals.models.mixtral.block import WrappedMixtralBlock from petals.utils.convert_block import QuantType from petals.utils.misc import get_size_in_bytes @@ -32,7 +33,7 @@ def get_block_size( ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations' with init_empty_weights(include_buffers=True): - block = config.block_class(config) + block = get_model_block(config) n_params = sum(param.numel() for param in block.parameters()) if location == "memory": @@ -50,3 +51,15 @@ def get_block_size( bytes_per_value = get_size_in_bytes(dtype) return round(n_params * bytes_per_value * (1 + eps)) + + +def get_model_block(config, layer_idx: int = 0): + """ + The function to create a model block based on the block class + kwargs argument **only** is necessary for specific classes, like Mixtral. + They will not be passed to other block constructors. + """ + if config.block_class == WrappedMixtralBlock: + config = PreTrainedModel._autoset_attn_implementation(config) + return config.block_class(config, layer_idx) + return config.block_class(config) diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index 95cfbd8..4a3b150 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -24,7 +24,7 @@ from transformers.utils import get_file_from_repo from petals.constants import DTYPE_MAP from petals.models.mixtral import WrappedMixtralBlock -from petals.server.block_utils import resolve_block_dtype +from petals.server.block_utils import get_model_block, resolve_block_dtype from petals.utils.auto_config import AutoDistributedConfig from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for from petals.utils.hf_auth import always_needs_auth @@ -52,11 +52,7 @@ def load_pretrained_block( torch_dtype = resolve_block_dtype(config, torch_dtype) with init_empty_weights(): - if config.block_class == WrappedMixtralBlock: - config = PreTrainedModel._autoset_attn_implementation(config) - block = config.block_class(config, block_index) - else: - block = config.block_class(config) + block = get_model_block(config, layer_idx=block_index) block_prefix = f"{config.block_prefix}.{block_index}." state_dict = _load_state_dict_from_repo( diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index c42bdb9..d947179 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -13,9 +13,10 @@ import torch.mps from hivemind.utils.logging import get_logger from transformers import PretrainedConfig -from petals.server.block_utils import resolve_block_dtype +from petals.server.block_utils import get_model_block, resolve_block_dtype from petals.utils.convert_block import QuantType, convert_block from petals.utils.disk_cache import DEFAULT_CACHE_DIR +from petals.utils.misc import DUMMY_KEY_PAST logger = get_logger(__name__) @@ -201,18 +202,25 @@ def measure_compute_rps( if not tensor_parallel_devices: tensor_parallel_devices = (device,) with torch.inference_mode(): - block = config.block_class(config).to(dtype) + block = get_model_block(config) + block = block.to(dtype) block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) - cache = None + cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype)) elapsed = 0 dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype) - _, cache = block.forward(dummy_input, use_cache=True) # Skip the 1st step to exclude the initialization time + + # Skip the 1st step to exclude the initialization time + def step(cache_): + outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None) + return outputs[1] if inference else None + + cache = step(cache) synchronize(device) start_time = time.perf_counter() for _ in range(n_steps): - _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None) + cache = step(cache) synchronize(device) elapsed = time.perf_counter() - start_time device_rps = n_steps * n_tokens / elapsed diff --git a/src/petals/utils/misc.py b/src/petals/utils/misc.py index d0cfd7c..2d53bab 100644 --- a/src/petals/utils/misc.py +++ b/src/petals/utils/misc.py @@ -4,6 +4,8 @@ DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter par DUMMY_INT64 = torch.empty(0, dtype=torch.int64) +DUMMY_KEY_PAST = torch.empty((0, 0, 0)) + def is_dummy(tensor: torch.Tensor) -> bool: return tensor.numel() == 0 diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py index 5d285ef..149fda4 100644 --- a/src/petals/utils/peft.py +++ b/src/petals/utils/peft.py @@ -17,7 +17,7 @@ from safetensors import safe_open from safetensors.torch import load_file from transformers.utils import get_file_from_repo -from petals.server.block_utils import resolve_block_dtype +from petals.server.block_utils import get_model_block, resolve_block_dtype from petals.utils.convert_block import QuantType from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for from petals.utils.misc import get_size_in_bytes @@ -273,7 +273,7 @@ def estimate_adapter_memory_per_block( ) -> int: """Get the number of extra bytes used to store a set of adapters per given block""" with init_empty_weights(include_buffers=True): - block = block_config.block_class(block_config) + block = get_model_block(block_config) base_block_parameters = sum(p.numel() for p in block.parameters()) create_lora_adapter(block, quant_type=QuantType.NONE) diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py index d4b012c..e8b492a 100644 --- a/tests/test_chained_calls.py +++ b/tests/test_chained_calls.py @@ -10,6 +10,7 @@ import torch from petals import AutoDistributedConfig from petals.client.remote_sequential import RemoteSequential from petals.server.from_pretrained import load_pretrained_block +from petals.utils.misc import DUMMY_KEY_PAST from test_utils import * @@ -54,12 +55,14 @@ def test_chained_inference_exact_match(atol_inference=1e-4): outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) outputs_inference = torch.cat(outputs_inference, dim=1) + dtype = torch.float32 ref_blocks = [ - load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32), - load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32), + load_pretrained_block(MODEL_NAME, 3, torch_dtype=dtype), + load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype), ] outputs_ref = [] - caches = [None, None] + cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype)) + caches = [cache, cache] for i in range(inputs.shape[1]): new_caches = [] hidden_states = inputs[:, i : i + 1, :] diff --git a/tests/test_full_model.py b/tests/test_full_model.py index bbe6f08..9bced26 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -141,6 +141,10 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=10): ), f"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}" +@pytest.mark.skipif( + "bloom" not in MODEL_NAME.lower(), + reason="Mixtral and Llama use DynamicCache, which can change based on beam search choices", +) @pytest.mark.forked def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5): inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py index 70f763e..f81a260 100644 --- a/tests/test_optimized_layers.py +++ b/tests/test_optimized_layers.py @@ -7,6 +7,7 @@ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_m from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel +from petals.server.block_utils import get_model_block from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, convert_block from test_utils import MODEL_NAME @@ -195,8 +196,9 @@ def test_optimized_block(device): dtype = torch.bfloat16 quant_type = QuantType.NONE - block = config.block_class(config).to(dtype) - block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) + block_idx = 1 + block = get_model_block(config, layer_idx=block_idx).to(dtype) + block = convert_block(block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) if config.model_type == "falcon": unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype) @@ -206,7 +208,7 @@ def test_optimized_block(device): pytest.skip(f"This test is not applicable to {config.model_type} models") unopt_block = convert_block( - unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True + unopt_block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True ) unopt_block.load_state_dict(block.state_dict())