From 204855c7513451542af23d7fe8c085160dc33e98 Mon Sep 17 00:00:00 2001
From: Artem Chumachenko
Date: Mon, 8 Apr 2024 18:59:13 +0200
Subject: [PATCH] fix cache and throughput

---
 src/petals/models/bloom/block.py |  6 ++++++
 src/petals/server/throughput.py  | 13 ++++++++++---
 src/petals/utils/misc.py         |  2 ++
 3 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/src/petals/models/bloom/block.py b/src/petals/models/bloom/block.py
index 86fc4aa..439b9ca 100644
--- a/src/petals/models/bloom/block.py
+++ b/src/petals/models/bloom/block.py
@@ -9,6 +9,8 @@ import torch
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
 
+from petals.utils.misc import is_dummy
+
 
 class WrappedBloomBlock(BloomBlock):
     def forward(
@@ -22,6 +24,10 @@ class WrappedBloomBlock(BloomBlock):
     ):
         assert attention_mask is None, "Non-causal attention masks are not supported yet"
         batch_size, seq_length = hidden_states.shape[:2]
+        if layer_past is not None and is_dummy(layer_past[0]):
+            # Bloom cannot use the cache if it was misconstructed (e.g. filled with dummy tensors).
+            # In this case, fall back to the old code path:
+            layer_past = None
         past_length = 0 if layer_past is None else layer_past[0].shape[-1]
         seq_length_with_past = seq_length + past_length
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py
index eab392f..c6948c8 100644
--- a/src/petals/server/throughput.py
+++ b/src/petals/server/throughput.py
@@ -16,6 +16,7 @@ from transformers import PretrainedConfig
 from petals.server.block_utils import resolve_block_dtype, get_model_block
 from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.misc import DUMMY_KEY_PAST
 
 logger = get_logger(__name__)
 
@@ -205,15 +206,21 @@ def measure_compute_rps(
         block = block.to(dtype)
         block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
-        cache = None
+        cache = (DUMMY_KEY_PAST, DUMMY_KEY_PAST)
         elapsed = 0
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
-        _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
+
+        def step(cache_):
+            outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
+            return outputs[1] if inference else None
+
+        cache = step(cache)
+        # Skip the 1st step to exclude the initialization time
         synchronize(device)
 
         start_time = time.perf_counter()
         for _ in range(n_steps):
-            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
+            cache = step(cache)
         synchronize(device)
         elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
diff --git a/src/petals/utils/misc.py b/src/petals/utils/misc.py
index d0cfd7c..2d53bab 100644
--- a/src/petals/utils/misc.py
+++ b/src/petals/utils/misc.py
@@ -4,6 +4,8 @@ DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter par
 
 DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
 
+DUMMY_KEY_PAST = torch.empty((0, 0, 0))
+
 
 def is_dummy(tensor: torch.Tensor) -> bool:
     return tensor.numel() == 0
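
A minimal standalone sketch (not part of the patch above) of the dummy-cache convention the patch relies on: DUMMY_KEY_PAST has zero elements, so is_dummy() detects it and WrappedBloomBlock.forward is expected to fall back to layer_past=None. The two definitions are copied from the patched src/petals/utils/misc.py; the rest is purely illustrative.

import torch

# Copied from the patched src/petals/utils/misc.py
DUMMY_KEY_PAST = torch.empty((0, 0, 0))

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0

# Mirrors the fallback added to WrappedBloomBlock.forward: a dummy past cache
# is treated the same as having no cache at all.
layer_past = (DUMMY_KEY_PAST, DUMMY_KEY_PAST)
if layer_past is not None and is_dummy(layer_past[0]):
    layer_past = None

past_length = 0 if layer_past is None else layer_past[0].shape[-1]
assert past_length == 0  # the block proceeds as if no tokens were cached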