fix cache and throughput

pull/570/head
Artem Chumachenko 1 month ago
parent f06cfd2b97
commit 204855c751

@@ -9,6 +9,8 @@ import torch
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
 
+from petals.utils.misc import is_dummy
+
 
 class WrappedBloomBlock(BloomBlock):
     def forward(
@@ -22,6 +24,10 @@ class WrappedBloomBlock(BloomBlock):
     ):
         assert attention_mask is None, "Non-causal attention masks are not supported yet"
         batch_size, seq_length = hidden_states.shape[:2]
+        if layer_past is not None and is_dummy(layer_past[0]):
+            # Bloom cannot use the cache if it was misconstructed (e.g., contains dummy tensors);
+            # in this case, fall back to the old no-cache code path:
+            layer_past = None
         past_length = 0 if layer_past is None else layer_past[0].shape[-1]
         seq_length_with_past = seq_length + past_length
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
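A minimal, standalone sketch of the fallback this hunk adds (the placeholder constant and helper mirror petals.utils.misc from the last hunk below; the tuple of dummy tensors is only an illustration):

import torch

DUMMY_KEY_PAST = torch.empty((0, 0, 0))  # zero-element placeholder, as defined in petals.utils.misc

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0  # any zero-element tensor counts as a dummy

# A cache built from placeholders cannot be consumed by the Bloom block, so drop it:
layer_past = (DUMMY_KEY_PAST, DUMMY_KEY_PAST)
if layer_past is not None and is_dummy(layer_past[0]):
    layer_past = None
past_length = 0 if layer_past is None else layer_past[0].shape[-1]
assert past_length == 0  # the block now behaves as if no cache was passed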

@@ -16,6 +16,7 @@ from transformers import PretrainedConfig
 
 from petals.server.block_utils import resolve_block_dtype, get_model_block
 from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.misc import DUMMY_KEY_PAST
 
 logger = get_logger(__name__)
@@ -205,15 +206,21 @@ def measure_compute_rps(
     block = block.to(dtype)
     block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
-    cache = None
+    cache = (DUMMY_KEY_PAST, DUMMY_KEY_PAST)
     elapsed = 0
     dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
-    _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
+
+    def step(cache_):
+        outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
+        return outputs[1] if inference else None
+
+    cache = step(cache)  # Skip the 1st step to exclude the initialization time
     synchronize(device)
 
     start_time = time.perf_counter()
     for _ in range(n_steps):
-        _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
+        cache = step(cache)
     synchronize(device)
     elapsed = time.perf_counter() - start_time
     device_rps = n_steps * n_tokens / elapsed
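The step() closure keeps the warm-up call and the timed loop identical, so initialization cost stays out of the measurement and the cache is reused on every step. A simplified, self-contained sketch of the same timing pattern (measure_rps, step_fn, and the toy Linear layer are illustrative stand-ins, not part of the Petals API):

import time
import torch

def measure_rps(step_fn, n_steps: int, n_tokens: int, device: torch.device) -> float:
    state = step_fn(None)  # warm-up step, excluded from the timing
    if device.type == "cuda":
        torch.cuda.synchronize(device)  # stand-in for synchronize(device) above
    start_time = time.perf_counter()
    for _ in range(n_steps):
        state = step_fn(state)  # reuse the returned cache/state on every step
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start_time
    return n_steps * n_tokens / elapsed  # tokens processed per second

layer = torch.nn.Linear(64, 64)
dummy_input = torch.randn(1, 16, 64)
rps = measure_rps(lambda _state: layer(dummy_input), n_steps=10, n_tokens=16, device=torch.device("cpu"))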

@@ -4,6 +4,8 @@ DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter par
 DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
 
+DUMMY_KEY_PAST = torch.empty((0, 0, 0))
+
 
 def is_dummy(tensor: torch.Tensor) -> bool:
     return tensor.numel() == 0
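A quick self-contained check of what counts as a dummy under this helper (the second tensor is just an example, not a value from the repo):

import torch

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0

assert is_dummy(torch.empty((0, 0, 0)))      # DUMMY_KEY_PAST-style placeholder: zero elements
assert not is_dummy(torch.zeros((1, 2, 3)))  # a real tensor, even if all-zero, is not a dummy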
