Fix Mixtral-related issues (#570)

This PR fixes several problems related to #569:
- block initialization
- throughput calculation and cache usage
- Mixtral in tests

Beam search is disabled for Mixtral and Llama for now. These models use DynamicCache, which requires a special function to reorder its contents during beam search (see https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py#L161).
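
For reference, here is a minimal sketch of what such a reordering helper could look like, assuming transformers' `DynamicCache` layout with per-layer `key_cache`/`value_cache` lists (an illustration only, not code from this PR):

```python
import torch
from transformers import DynamicCache


def reorder_dynamic_cache(cache: DynamicCache, beam_idx: torch.LongTensor) -> None:
    """Reorder a DynamicCache in place for the selected beams (illustrative sketch)."""
    for layer_idx in range(len(cache.key_cache)):
        # Pick the batch entries that correspond to the chosen beams for every layer
        key_device = cache.key_cache[layer_idx].device
        cache.key_cache[layer_idx] = cache.key_cache[layer_idx].index_select(0, beam_idx.to(key_device))
        value_device = cache.value_cache[layer_idx].device
        cache.value_cache[layer_idx] = cache.value_cache[layer_idx].index_select(0, beam_idx.to(value_device))
```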

---------

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Artem Chumachenko 3 weeks ago committed by GitHub
parent d2fcbbc72e
commit d6f4f80f3f

@@ -16,6 +16,8 @@ jobs:
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
- { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.8' }
- { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.11' }
fail-fast: false
runs-on: ${{ matrix.os }}-latest
timeout-minutes: 20

@@ -38,7 +38,7 @@ class RemotePastKeyValues(Cache):
self.seen_tokens += new_seen
def reorder_cache(self, beam_idx):
pass
raise NotImplementedError("Beam search reordering is not implemented yet")
_skipped_tokens = ContextVar("skipped_tokens", default=0)

@@ -9,6 +9,8 @@ import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
from petals.utils.misc import is_dummy
class WrappedBloomBlock(BloomBlock):
def forward(
@@ -22,6 +24,10 @@ class WrappedBloomBlock(BloomBlock):
):
assert attention_mask is None, "Non-causal attention masks are not supported yet"
batch_size, seq_length = hidden_states.shape[:2]
if layer_past is not None and is_dummy(layer_past[0]):
# Bloom cannot use the cache if it was misconstructed (e.g., filled with dummy tensors).
# In this case, fall back to the old code path:
layer_past = None
past_length = 0 if layer_past is None else layer_past[0].shape[-1]
seq_length_with_past = seq_length + past_length
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)

@@ -1,3 +1,4 @@
import json
from typing import Optional, Tuple
import torch
@@ -33,16 +34,15 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
past_key_values_length = 0
past_key_value = layer_past
if past_key_value is not None:
past_key_values_length = past_key_value[0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
_past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
past_key_value = DynamicCache()
for idx in range(self.layer_idx):
past_key_value.update(
torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx
)
past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)
past_key_value.key_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[0]]
past_key_value.value_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[1]]
past_key_value._seen_tokens = past_key_values_length
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
@@ -83,7 +83,7 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
if use_cache:
present_key_value = outputs[-1]
present_key_value = present_key_value.to_legacy_cache()[self.layer_idx]
present_key_value = present_key_value[self.layer_idx]
present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
outputs = outputs[:-1] + (present_key_value,)

@@ -122,14 +122,20 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
def word_embeddings(self) -> nn.Embedding: # For compatibility with RemoteGenerationMixin
return self.embed_tokens
@property
def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests
return nn.Identity()
@property
def h(self) -> RemoteSequential: # For compatibility with RemoteGenerationMixin
return self.layers
@property
def ln_f(self) -> nn.Module: # For compatibility with RemoteGenerationMixin in tests
return self.norm
class DistributedMixtralForCausalLM(
DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
):
class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
@@ -151,9 +157,12 @@ class DistributedMixtralForCausalLM(
return self.model
class DistributedMixtralForSequenceClassification(
DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
):
class DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification):
_keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
config_class = DistributedMixtralConfig
def __init__(self, config: DistributedMixtralConfig):
MixtralPreTrainedModel.__init__(self, config)
self.num_labels = config.num_labels

@@ -2,8 +2,9 @@ from typing import Optional, Union
import torch
from accelerate import init_empty_weights
from transformers import PretrainedConfig
from transformers import PretrainedConfig, PreTrainedModel
from petals.models.mixtral.block import WrappedMixtralBlock
from petals.utils.convert_block import QuantType
from petals.utils.misc import get_size_in_bytes
@@ -32,7 +33,7 @@ def get_block_size(
), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'
with init_empty_weights(include_buffers=True):
block = config.block_class(config)
block = get_model_block(config)
n_params = sum(param.numel() for param in block.parameters())
if location == "memory":
@@ -50,3 +51,15 @@ def get_block_size(
bytes_per_value = get_size_in_bytes(dtype)
return round(n_params * bytes_per_value * (1 + eps))
def get_model_block(config, layer_idx: int = 0):
"""
Create a model block based on the config's block class.
The layer_idx argument is **only** necessary for specific classes, like Mixtral;
it is not passed to other block constructors.
"""
if config.block_class == WrappedMixtralBlock:
config = PreTrainedModel._autoset_attn_implementation(config)
return config.block_class(config, layer_idx)
return config.block_class(config)
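For illustration, a caller would now build a block roughly as in the sketch below (a hypothetical usage example: the model name and layer index are placeholders, and `AutoDistributedConfig` is the existing loader from `petals.utils.auto_config`):

```python
from petals.server.block_utils import get_model_block
from petals.utils.auto_config import AutoDistributedConfig

# Assumed example: load a distributed config and construct a single block.
# For Mixtral, get_model_block() also autosets the attention implementation
# and forwards layer_idx; other block classes receive only the config.
config = AutoDistributedConfig.from_pretrained("artek0chumak/TestMixtral")
block = get_model_block(config, layer_idx=3)
```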

@@ -24,7 +24,7 @@ from transformers.utils import get_file_from_repo
from petals.constants import DTYPE_MAP
from petals.models.mixtral import WrappedMixtralBlock
from petals.server.block_utils import resolve_block_dtype
from petals.server.block_utils import get_model_block, resolve_block_dtype
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.hf_auth import always_needs_auth
@@ -52,11 +52,7 @@ def load_pretrained_block(
torch_dtype = resolve_block_dtype(config, torch_dtype)
with init_empty_weights():
if config.block_class == WrappedMixtralBlock:
config = PreTrainedModel._autoset_attn_implementation(config)
block = config.block_class(config, block_index)
else:
block = config.block_class(config)
block = get_model_block(config, layer_idx=block_index)
block_prefix = f"{config.block_prefix}.{block_index}."
state_dict = _load_state_dict_from_repo(

@@ -13,9 +13,10 @@ import torch.mps
from hivemind.utils.logging import get_logger
from transformers import PretrainedConfig
from petals.server.block_utils import resolve_block_dtype
from petals.server.block_utils import get_model_block, resolve_block_dtype
from petals.utils.convert_block import QuantType, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
from petals.utils.misc import DUMMY_KEY_PAST
logger = get_logger(__name__)
@@ -201,18 +202,25 @@ def measure_compute_rps(
if not tensor_parallel_devices:
tensor_parallel_devices = (device,)
with torch.inference_mode():
block = config.block_class(config).to(dtype)
block = get_model_block(config)
block = block.to(dtype)
block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
cache = None
cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
elapsed = 0
dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
_, cache = block.forward(dummy_input, use_cache=True) # Skip the 1st step to exclude the initialization time
# Skip the 1st step to exclude the initialization time
def step(cache_):
outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
return outputs[1] if inference else None
cache = step(cache)
synchronize(device)
start_time = time.perf_counter()
for _ in range(n_steps):
_, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
cache = step(cache)
synchronize(device)
elapsed = time.perf_counter() - start_time
device_rps = n_steps * n_tokens / elapsed

@@ -4,6 +4,8 @@ DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter par
DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
DUMMY_KEY_PAST = torch.empty((0, 0, 0))
def is_dummy(tensor: torch.Tensor) -> bool:
return tensor.numel() == 0

@@ -17,7 +17,7 @@ from safetensors import safe_open
from safetensors.torch import load_file
from transformers.utils import get_file_from_repo
from petals.server.block_utils import resolve_block_dtype
from petals.server.block_utils import get_model_block, resolve_block_dtype
from petals.utils.convert_block import QuantType
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import get_size_in_bytes
@@ -273,7 +273,7 @@ def estimate_adapter_memory_per_block(
) -> int:
"""Get the number of extra bytes used to store a set of adapters per given block"""
with init_empty_weights(include_buffers=True):
block = block_config.block_class(block_config)
block = get_model_block(block_config)
base_block_parameters = sum(p.numel() for p in block.parameters())
create_lora_adapter(block, quant_type=QuantType.NONE)

@@ -10,6 +10,7 @@ import torch
from petals import AutoDistributedConfig
from petals.client.remote_sequential import RemoteSequential
from petals.server.from_pretrained import load_pretrained_block
from petals.utils.misc import DUMMY_KEY_PAST
from test_utils import *
@@ -54,12 +55,14 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = torch.cat(outputs_inference, dim=1)
dtype = torch.float32
ref_blocks = [
load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
load_pretrained_block(MODEL_NAME, 3, torch_dtype=dtype),
load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype),
]
outputs_ref = []
caches = [None, None]
cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
caches = [cache, cache]
for i in range(inputs.shape[1]):
new_caches = []
hidden_states = inputs[:, i : i + 1, :]

@@ -141,6 +141,10 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):
), f"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}"
@pytest.mark.skipif(
"bloom" not in MODEL_NAME.lower(),
reason="Mixtral and Llama use DynamicCache, which can change based on beam search choices",
)
@pytest.mark.forked
def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

@@ -7,6 +7,7 @@ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_m
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
from petals.server.block_utils import get_model_block
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, convert_block
from test_utils import MODEL_NAME
@@ -195,8 +196,9 @@ def test_optimized_block(device):
dtype = torch.bfloat16
quant_type = QuantType.NONE
block = config.block_class(config).to(dtype)
block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
block_idx = 1
block = get_model_block(config, layer_idx=block_idx).to(dtype)
block = convert_block(block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
if config.model_type == "falcon":
unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
@@ -206,7 +208,7 @@ def test_optimized_block(device):
pytest.skip(f"This test is not applicable to {config.model_type} models")
unopt_block = convert_block(
unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
unopt_block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
)
unopt_block.load_state_dict(block.state_dict())
