From 5af04524dd8bc5ce6a9a11a35ee71714f310aad6 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 22 Jul 2023 22:10:46 +0300 Subject: [PATCH] Split long sequences into chunks (#403) This PR is designed to avoid OOMs when processing long sequences that happen due to the huge attention logits matrices. Co-authored-by: Alexander Borzunov --- .github/workflows/run-tests.yaml | 3 ++- src/petals/cli/run_server.py | 2 ++ src/petals/server/backend.py | 41 +++++++++++++++++++++++++++++--- src/petals/server/server.py | 5 ++++ tests/test_full_model.py | 12 +++++++--- 5 files changed, 56 insertions(+), 7 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 7ec5bf3..3bccda3 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -37,7 +37,8 @@ jobs: python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \ --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \ - --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --adapters $ADAPTER_NAME &> server1.log & + --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \ + --adapters $ADAPTER_NAME &> server1.log & SERVER1_PID=$! sleep 5 # wait for the first server to initialize DHT diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index a33e233..8132a39 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -74,6 +74,8 @@ def main(): parser.add_argument('--max_batch_size', type=int, default=None, help='The total number of tokens in the same batch will not exceed this value. ' 'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)') + parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024, + help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks') parser.add_argument('--attn_cache_tokens', type=int, default=None, help='The number of past attention key/value pairs that will be stored between inference steps. 
' 'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)') diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index d61470a..8b788b0 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -27,7 +27,13 @@ class TransformerBackend(ModuleBackend): _peft_module = None def __init__( - self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs + self, + *args, + config: PretrainedConfig, + memory_cache: MemoryCache, + backend_dtype: torch.dtype, + max_chunk_size_bytes: int, + **kwargs, ): import petals.utils.peft as _peft_module @@ -37,6 +43,8 @@ class TransformerBackend(ModuleBackend): assert isinstance(self.module, TensorParallel) self.config = config self.memory_cache = memory_cache + self.max_chunk_size_bytes = max_chunk_size_bytes + for name, param in self.module.named_parameters(): assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does" for name, buf in self.module.named_buffers(): @@ -55,6 +63,7 @@ class TransformerBackend(ModuleBackend): ) self.dtype = backend_dtype + self.dtype_bytes = torch.finfo(self.dtype).bits // 8 self.shard_num_heads = [] for shard in self.module.module_shards: for submodule in shard.modules(): @@ -105,14 +114,40 @@ class TransformerBackend(ModuleBackend): inference_info: InferenceMetadata, ) -> Tuple[torch.Tensor, ...]: assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" + seq_len = hidden_states.shape[1] + with self.memory_cache.use_cache( *inference_info.cache_handles ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter): self._reorder_cache_inplace(cache_tensors, hypo_ids) + + # We chunk the inputs so that peak memory for long sequences fits into `autograd_memory` + # reserved in `Server._choose_num_blocks()`. This saves us from OOMs if `max_chunk_size_bytes` + # is at least 4-6x less than `autograd_memory`. 
+ max_chunk_length = self._estimate_max_chunk_length(hidden_states, inference_info) + output_hidden_states = torch.empty_like(hidden_states) if seq_len > max_chunk_length else None layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length) - hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True) + for offset in range(0, seq_len, max_chunk_length): + hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :] + output_hidden_states_chunk, new_kvs = self.module.forward( + hidden_states_chunk, layer_past=layer_past, use_cache=True + ) + if seq_len > max_chunk_length: + output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk + else: + output_hidden_states = output_hidden_states_chunk # saves one memcopy + layer_past = new_kvs + self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length) - return (hidden_states,) + return (output_hidden_states,) + + def _estimate_max_chunk_length(self, hidden_states: torch.Tensor, inference_info: InferenceMetadata) -> int: + # We assume that attention logit matrices are the main thing that consumes memory, given that + # the model uses multi-query attention + batch_size, seq_length, hidden_size = hidden_states.shape + worst_case_length = inference_info.prefix_length + seq_length + attn_bytes_per_token = max(self.shard_num_heads) * batch_size * self.dtype_bytes * worst_case_length + return max(1, self.max_chunk_size_bytes // attn_bytes_per_token) def _reorder_cache_inplace(self, cache_tensors: torch.Tensor, hypo_ids: torch.Tensor): """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids""" diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 6d5c293..5cb9b91 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -58,6 +58,7 @@ class Server: inference_max_length: Optional[int] = None, min_batch_size: int = 1, max_batch_size: Optional[int] = None, + max_chunk_size_bytes: int = 256 * 1024 * 1024, attn_cache_tokens: Optional[int] = None, torch_dtype: str = "auto", revision: Optional[str] = None, @@ -183,6 +184,7 @@ class Server: inference_max_length = 8192 if is_multiquery_attn else 2048 self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size self.inference_max_length = inference_max_length + self.max_chunk_size_bytes = max_chunk_size_bytes # For attention cache in GPU or RAM if attn_cache_tokens is None: @@ -312,6 +314,7 @@ class Server: num_handlers=self.num_handlers, min_batch_size=self.min_batch_size, max_batch_size=self.max_batch_size, + max_chunk_size_bytes=self.max_chunk_size_bytes, inference_max_length=self.inference_max_length, torch_dtype=self.torch_dtype, cache_dir=self.cache_dir, @@ -412,6 +415,7 @@ class ModuleContainer(threading.Thread): block_indices: List[int], min_batch_size: int, max_batch_size: int, + max_chunk_size_bytes: int, torch_dtype: torch.dtype, cache_dir: str, max_disk_space: int, @@ -477,6 +481,7 @@ class ModuleContainer(threading.Thread): config=block_config, memory_cache=memory_cache, backend_dtype=torch_dtype, + max_chunk_size_bytes=max_chunk_size_bytes, args_schema=( BatchTensorDescriptor( 1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression diff --git a/tests/test_full_model.py b/tests/test_full_model.py index acd5e6a..511604b 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -28,7 +28,7 @@ def test_full_model_exact_match(use_peft: 
bool, pass_empty_tensors: bool, atol_f assert isinstance(model, DistributedBloomForCausalLM) assert len(model.transformer.h) == model.config.num_hidden_layers - test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"] + test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"] with torch.inference_mode(): parallel_outputs = model.forward(test_inputs).logits @@ -43,8 +43,14 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size))) for t in range(embs.shape[1]): - recurrent_outputs.append(sess.step(embs[:, t : t + 1, :])) - if t == int(embs.shape[1] // 2) and pass_empty_tensors: + if t == 4: + recurrent_outputs.append(sess.step(embs[:, 4:9, :])) + elif 4 < t < 9: + continue + else: + recurrent_outputs.append(sess.step(embs[:, t : t + 1, :])) + + if t == 2 and pass_empty_tensors: recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size))) recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
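
Notes for reviewers (illustration only, not part of the diff above):

The chunking loop added to `TransformerBackend.inference_step` can be summarized by the minimal, self-contained sketch below. It is written under simplifying assumptions: `toy_block_forward` is a hypothetical stand-in for `self.module.forward(..., layer_past=..., use_cache=True)` and fakes the key/value cache; cache handles, adapters, and hypo_id reordering are omitted.

import torch

def toy_block_forward(hidden_states, layer_past=None):
    # Stand-in for one transformer block: returns activations unchanged and a fake (key, value) cache
    batch_size, chunk_len, hidden_size = hidden_states.shape
    past_len = 0 if layer_past is None else layer_past[0].shape[1]
    fake_kvs = (torch.zeros(batch_size, past_len + chunk_len, hidden_size),) * 2
    return hidden_states, fake_kvs

def chunked_inference_step(hidden_states: torch.Tensor, max_chunk_length: int) -> torch.Tensor:
    seq_len = hidden_states.shape[1]
    # Allocate a separate output buffer only if the sequence is actually split (saves one memcopy otherwise)
    output_hidden_states = torch.empty_like(hidden_states) if seq_len > max_chunk_length else None
    layer_past = None
    for offset in range(0, seq_len, max_chunk_length):
        hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]
        output_chunk, new_kvs = toy_block_forward(hidden_states_chunk, layer_past=layer_past)
        if seq_len > max_chunk_length:
            output_hidden_states[:, offset : offset + max_chunk_length] = output_chunk
        else:
            output_hidden_states = output_chunk
        layer_past = new_kvs  # the next chunk attends to everything processed so far
    return output_hidden_states

if __name__ == "__main__":
    x = torch.randn(1, 10, 16)
    assert torch.allclose(chunked_inference_step(x, max_chunk_length=3), x)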
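The chunk length itself is chosen so that the attention logit matrices for one chunk stay within `--max_chunk_size_bytes`, assuming (as the in-code comment says) that those matrices dominate memory use for multi-query attention models. A back-of-the-envelope version of `_estimate_max_chunk_length` follows; the concrete numbers (8 heads per shard, fp16, batch size 1) are illustrative assumptions, not values taken from any particular model.

def estimate_max_chunk_length(max_chunk_size_bytes, num_heads, batch_size, dtype_bytes,
                              prefix_length, seq_length):
    # Worst case: the last token of the chunk attends to the whole prefix plus all new tokens
    worst_case_length = prefix_length + seq_length
    attn_bytes_per_token = num_heads * batch_size * dtype_bytes * worst_case_length
    return max(1, max_chunk_size_bytes // attn_bytes_per_token)

# With the default 256 MiB budget and a fresh 4096-token prompt, each token's attention logits
# take 8 * 1 * 2 * 4096 = 64 KiB, so the chunk is capped at 256 MiB / 64 KiB = 4096 tokens.
print(estimate_max_chunk_length(256 * 1024 * 1024, num_heads=8, batch_size=1, dtype_bytes=2,
                                prefix_length=0, seq_length=4096))

The CI change above sets `--max_chunk_size_bytes 1024`, presumably so that even the short test sequences are forced through the chunked code path.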