fix optimized layers test

pull/554/head
Denis Mazur 4 months ago
parent 1e6bd07bf6
commit 61456d9968

@@ -4,9 +4,11 @@ import pytest
 import torch
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, convert_block
+from transformers.cache_utils import DynamicCache
 from test_utils import MODEL_NAME
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
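
For context on the two new imports: recent transformers releases pass per-layer attention caches around as DynamicCache objects instead of plain (key, value) tuples, and from_legacy_cache / to_legacy_cache convert between the two representations. A minimal sketch of that round trip, with made-up shapes (not taken from the test):

import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key_states, value_states) pair per layer,
# each of shape (batch, num_kv_heads, seq_len, head_dim) -- assumed sizes.
batch, heads, seq_len, head_dim = 1, 4, 8, 16
legacy = ((torch.zeros(batch, heads, seq_len, head_dim),
           torch.zeros(batch, heads, seq_len, head_dim)),)

cache = DynamicCache.from_legacy_cache(legacy)  # tuple-of-tuples -> Cache object
restored = cache.to_legacy_cache()              # Cache object -> tuple-of-tuples
assert torch.equal(restored[0][0], legacy[0][0])
assert cache.get_seq_length() == seq_len
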
@@ -116,6 +118,8 @@ class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
             past_key_values_length = past_key_value[0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
             past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
+        elif use_cache:
+            past_key_value = DynamicCache()
 
         if position_ids is None:
             device = hidden_states.device
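
The new elif branch hands the underlying layer an empty DynamicCache when there is no past but use_cache is requested, because newer LlamaDecoderLayer implementations append into a cache object rather than returning a fresh tuple. A small sketch of that behaviour (tensor sizes are assumptions):

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()            # what the wrapper now passes when use_cache=True
assert cache.get_seq_length() == 0

# Attention layers grow the cache in place via update(); layer_idx selects the slot.
k = torch.zeros(1, 4, 8, 16)      # (batch, num_kv_heads, seq_len, head_dim), assumed
v = torch.zeros(1, 4, 8, 16)
k_all, v_all = cache.update(k, v, layer_idx=0)
assert cache.get_seq_length() == 8
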
@@ -131,8 +135,9 @@ class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = LlamaModel._prepare_decoder_attention_mask(
-            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
         )
 
         outputs = super().forward(
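
LlamaModel._prepare_decoder_attention_mask no longer exists in current transformers; the test now calls the private helper _prepare_4d_causal_attention_mask, which expands a 2D padding mask into the 4D causal mask eager Llama attention expects. A self-contained sketch with assumed sizes:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, past_len, hidden = 1, 8, 0, 32   # assumed sizes
hidden_states = torch.zeros(batch_size, seq_length, hidden)
attention_mask = torch.ones(batch_size, seq_length + past_len, dtype=torch.bool)

# Returns a mask of shape (batch, 1, query_len, key_len) in the dtype of hidden_states.
mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_length), hidden_states, past_len
)
assert mask_4d.shape == (batch_size, 1, seq_length, seq_length + past_len)
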
@@ -156,19 +161,20 @@ class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
     def _reorder_cache_from_bloom_to_llama(
         self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
-    ) -> Tuple[torch.Tensor]:
+    ) -> DynamicCache:
         key_states, value_states = key_value
         key_states = key_states.permute(0, 2, 1)
         key_states = key_states.view(
             batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
         )
         value_states = value_states.view(*key_states.shape)
-        return (key_states, value_states)
+        past_key_values = ((key_states, value_states),)
+        return DynamicCache.from_legacy_cache(past_key_values)
 
     def _reorder_cache_from_llama_to_bloom(
-        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+        self, key_value: DynamicCache, batch_size: int, seq_length: int
     ) -> Tuple[torch.Tensor]:
-        key_states, value_states = key_value
+        key_states, value_states = key_value.to_legacy_cache()[0]
         value_states = value_states.view(
             batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
         )
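
These two helpers translate between the Bloom-style cache layout (keys as (batch * heads, head_dim, seq), values as (batch * heads, seq, head_dim)) and the (batch, heads, seq, head_dim) layout Llama layers use, which after this commit travels inside a DynamicCache. A standalone sketch of the same conversion on plain tensors, with assumed sizes and reshape in place of view:

import torch
from transformers.cache_utils import DynamicCache

batch, heads, seq_len, head_dim = 2, 4, 8, 16            # assumed sizes

# Bloom-style layout
bloom_keys = torch.randn(batch * heads, head_dim, seq_len)
bloom_values = torch.randn(batch * heads, seq_len, head_dim)

# bloom -> llama: move seq in front of head_dim, then split the fused batch*heads axis
llama_keys = bloom_keys.permute(0, 2, 1).reshape(batch, heads, seq_len, head_dim)
llama_values = bloom_values.reshape(batch, heads, seq_len, head_dim)
cache = DynamicCache.from_legacy_cache(((llama_keys, llama_values),))

# llama -> bloom: unwrap the cache, fuse batch and heads again, undo the key permute
k, v = cache.to_legacy_cache()[0]
restored_keys = k.reshape(batch * heads, seq_len, head_dim).permute(0, 2, 1)
restored_values = v.reshape(batch * heads, seq_len, head_dim)
assert torch.equal(restored_keys, bloom_keys) and torch.equal(restored_values, bloom_values)
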
@@ -178,7 +184,7 @@ class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
 @pytest.mark.parametrize("device", ["cpu", "cuda:0"])
-@pytest.mark.forked
+# @pytest.mark.forked
 def test_optimized_block(device):
     if device == "cuda:0" and not torch.cuda.is_available():
         pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
@@ -195,7 +201,7 @@ def test_optimized_block(device):
     if config.model_type == "falcon":
         unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
     elif config.model_type == "llama":
-        unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype)
+        unopt_block = UnoptimizedWrappedLlamaBlock(config, layer_idx=0).to(dtype)
     else:
         pytest.skip(f"This test is not applicable to {config.model_type} models")
