Split long sequences into chunks (#403)

This PR is designed to avoid OOMs when processing long sequences that happen due to the huge attention logits matrices.

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
pull/407/head
justheuristic 10 months ago committed by GitHub
parent 30b94ef18b
commit 5af04524dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -37,7 +37,8 @@ jobs:
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
--new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
--torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --adapters $ADAPTER_NAME &> server1.log &
--torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
--adapters $ADAPTER_NAME &> server1.log &
SERVER1_PID=$!
sleep 5 # wait for the first server to initialize DHT

@ -74,6 +74,8 @@ def main():
parser.add_argument('--max_batch_size', type=int, default=None,
help='The total number of tokens in the same batch will not exceed this value. '
'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024,
help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks')
parser.add_argument('--attn_cache_tokens', type=int, default=None,
help='The number of past attention key/value pairs that will be stored between inference steps. '
'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')

@ -27,7 +27,13 @@ class TransformerBackend(ModuleBackend):
_peft_module = None
def __init__(
self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs
self,
*args,
config: PretrainedConfig,
memory_cache: MemoryCache,
backend_dtype: torch.dtype,
max_chunk_size_bytes: int,
**kwargs,
):
import petals.utils.peft as _peft_module
@ -37,6 +43,8 @@ class TransformerBackend(ModuleBackend):
assert isinstance(self.module, TensorParallel)
self.config = config
self.memory_cache = memory_cache
self.max_chunk_size_bytes = max_chunk_size_bytes
for name, param in self.module.named_parameters():
assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
for name, buf in self.module.named_buffers():
@ -55,6 +63,7 @@ class TransformerBackend(ModuleBackend):
)
self.dtype = backend_dtype
self.dtype_bytes = torch.finfo(self.dtype).bits // 8
self.shard_num_heads = []
for shard in self.module.module_shards:
for submodule in shard.modules():
@ -105,14 +114,40 @@ class TransformerBackend(ModuleBackend):
inference_info: InferenceMetadata,
) -> Tuple[torch.Tensor, ...]:
assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
seq_len = hidden_states.shape[1]
with self.memory_cache.use_cache(
*inference_info.cache_handles
) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
self._reorder_cache_inplace(cache_tensors, hypo_ids)
# We chunk the inputs so that peak memory for long sequences fits into `autograd_memory`
# reserved in `Server._choose_num_blocks()`. This saves us from OOMs if `max_chunk_size_bytes`
# is at least 4-6x less than `autograd_memory`.
max_chunk_length = self._estimate_max_chunk_length(hidden_states, inference_info)
output_hidden_states = torch.empty_like(hidden_states) if seq_len > max_chunk_length else None
layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
for offset in range(0, seq_len, max_chunk_length):
hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]
output_hidden_states_chunk, new_kvs = self.module.forward(
hidden_states_chunk, layer_past=layer_past, use_cache=True
)
if seq_len > max_chunk_length:
output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk
else:
output_hidden_states = output_hidden_states_chunk # saves one memcopy
layer_past = new_kvs
self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
return (hidden_states,)
return (output_hidden_states,)
def _estimate_max_chunk_length(self, hidden_states: torch.Tensor, inference_info: InferenceMetadata) -> int:
# We assume that attention logit matrices are the main thing that consumes memory, given that
# the model uses multi-query attention
batch_size, seq_length, hidden_size = hidden_states.shape
worst_case_length = inference_info.prefix_length + seq_length
attn_bytes_per_token = max(self.shard_num_heads) * batch_size * self.dtype_bytes * worst_case_length
return max(1, self.max_chunk_size_bytes // attn_bytes_per_token)
def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor):
"""If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""

@ -58,6 +58,7 @@ class Server:
inference_max_length: Optional[int] = None,
min_batch_size: int = 1,
max_batch_size: Optional[int] = None,
max_chunk_size_bytes: int = 256 * 1024 * 1024,
attn_cache_tokens: Optional[int] = None,
torch_dtype: str = "auto",
revision: Optional[str] = None,
@ -183,6 +184,7 @@ class Server:
inference_max_length = 8192 if is_multiquery_attn else 2048
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
self.inference_max_length = inference_max_length
self.max_chunk_size_bytes = max_chunk_size_bytes
# For attention cache in GPU or RAM
if attn_cache_tokens is None:
@ -312,6 +314,7 @@ class Server:
num_handlers=self.num_handlers,
min_batch_size=self.min_batch_size,
max_batch_size=self.max_batch_size,
max_chunk_size_bytes=self.max_chunk_size_bytes,
inference_max_length=self.inference_max_length,
torch_dtype=self.torch_dtype,
cache_dir=self.cache_dir,
@ -412,6 +415,7 @@ class ModuleContainer(threading.Thread):
block_indices: List[int],
min_batch_size: int,
max_batch_size: int,
max_chunk_size_bytes: int,
torch_dtype: torch.dtype,
cache_dir: str,
max_disk_space: int,
@ -477,6 +481,7 @@ class ModuleContainer(threading.Thread):
config=block_config,
memory_cache=memory_cache,
backend_dtype=torch_dtype,
max_chunk_size_bytes=max_chunk_size_bytes,
args_schema=(
BatchTensorDescriptor(
1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression

@ -28,7 +28,7 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f
assert isinstance(model, DistributedBloomForCausalLM)
assert len(model.transformer.h) == model.config.num_hidden_layers
test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"]
with torch.inference_mode():
parallel_outputs = model.forward(test_inputs).logits
@ -43,8 +43,14 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
for t in range(embs.shape[1]):
recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
if t == int(embs.shape[1] // 2) and pass_empty_tensors:
if t == 4:
recurrent_outputs.append(sess.step(embs[:, 4:9, :]))
elif 4 < t < 9:
continue
else:
recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
if t == 2 and pass_empty_tensors:
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))

Loading…
Cancel
Save