From 1ebd88ae7b8238fe1778409f55fa88aa542a2541 Mon Sep 17 00:00:00 2001
From: Max Ryabinin
Date: Mon, 4 Sep 2023 14:38:32 +0200
Subject: [PATCH] Optimize the Falcon block for inference (#500)

This PR optimizes the inference of Falcon models in the single-token setup by
removing most of the Python overhead and making several assumptions about the
setup. Specifically,

* Layer normalization, QKV projection (with splitting) and rotary embeddings
  are executed through CUDA graphs, which reduces most of the overhead related
  to small kernel launches.
* If no sin/cos tensors are cached by the rotary embedding layer, we cache
  them for 8192 tokens (INFERENCE_MAX_LENGTH) during the first forward pass.
  In general, it should be beneficial to always run a max-length sequence
  before starting a block, but this is a question for another PR.

The PR also adds a small test to ensure that the results (without
quantization) of the block before and after the optimization indeed match.

Lastly, the pull request makes the backward pass work (as discussed in
https://github.com/bigscience-workshop/petals/pull/499) by registering the
cached sin/cos tensors of the rotary embedding as buffers and disabling
inference mode during their creation.
---
 src/petals/models/falcon/block.py | 394 +++++++++++++++++++++++++++++-
 tests/test_optimized_layers.py    | 128 ++++++++++
 2 files changed, 518 insertions(+), 4 deletions(-)
 create mode 100644 tests/test_optimized_layers.py

diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py
index e677e06..a510aba 100644
--- a/src/petals/models/falcon/block.py
+++ b/src/petals/models/falcon/block.py
@@ -3,15 +3,399 @@ Falcon intermediate layer
 Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
 See commit history for authorship.
 """
+import math
+from functools import partial
 from typing import Optional, Tuple
 
 import torch
-from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.falcon.modeling_falcon import (
+    FalconAttention,
+    FalconConfig,
+    FalconDecoderLayer,
+    FalconLinear,
+    FalconMLP,
+    FalconModel,
+    LayerNorm,
+    build_alibi_tensor,
+    dropout_add,
+    rotate_half,
+)
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
+INFERENCE_MAX_LENGTH = 8192
 
 
-class WrappedFalconBlock(FalconDecoderLayer):
+def apply_rotary(query, key, cos, sin):
+    return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
+
+
+class OptimizedFalconRotaryEmbedding(nn.Module):
+    def __init__(self, head_dim: int, base=10000):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.head_dim = head_dim
+        self.seq_len_cached = -1
+
+        self.cuda_graph = None
+        self.input_surface = None
+        self.static_outputs = None
+
+    def _optimized_apply_rotary(self, query, key, cos, sin):
+        if self.cuda_graph is None:
+            self.cuda_graph = torch.cuda.CUDAGraph()
+            self.input_surface = (query, key, cos, sin)
+
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(3):
+                    apply_rotary(*self.input_surface)
+            torch.cuda.current_stream().wait_stream(s)
+
+            with torch.cuda.graph(self.cuda_graph):
+                self.static_outputs = apply_rotary(*self.input_surface)
+
+        inputs = (query, key, cos, sin)
+        for static_input, data in zip(self.input_surface, inputs):
+            static_input.copy_(data)
+        self.cuda_graph.replay()
+        return tuple(o.detach() for o in self.static_outputs)
+
+    def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
+        total_length = seq_len + past_key_values_length
+        if self.seq_len_cached == -1:
+            # warm up the cache
+            total_length = max(INFERENCE_MAX_LENGTH, total_length)
+
+        if total_length > self.seq_len_cached:
+            with torch.inference_mode(False):
+                self.seq_len_cached = total_length
+                t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
+                freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+                emb = torch.cat((freqs, freqs), dim=-1).to(device)
+
+                if dtype in [torch.float16, torch.bfloat16]:
+                    emb = emb.float()
+
+                self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False)
+                self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False)
+
+        return (
+            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
+            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
+        )
+
+    def forward(self, query, key, past_key_values_length=0):
+        batch, seq_len, head_dim = query.shape
+        cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
+        if seq_len == 1 and torch.is_inference_mode_enabled() and query.device.type == "cuda":
+            return self._optimized_apply_rotary(query, key, cos, sin)
+        else:
+            return apply_rotary(query, key, cos, sin)
+
+
+def split_heads(
+    fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    batch, seq_len, _ = fused_qkv.shape
+    qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)
+    query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3)
+    key = torch.broadcast_to(key, query.shape)
+    value = torch.broadcast_to(value, query.shape)
+
+    query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
+    return query, key, value
+
+
+class OptimizedFalconAttention(FalconAttention):
+    def __init__(self, config: FalconConfig):
+        nn.Module.__init__(self)
+
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.split_size = self.hidden_size
+        self.hidden_dropout = config.hidden_dropout
+
+        if self.head_dim * self.num_heads != self.hidden_size:
+            raise ValueError(
+                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+
+        self.maybe_rotary = OptimizedFalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
+
+        # Layer-wise attention scaling
+        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+        self.beta = self.inv_norm_factor
+        if config.new_decoder_architecture:
+            qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
+        elif config.multi_query:
+            qkv_out_dim = self.hidden_size + 2 * self.head_dim
+        else:
+            qkv_out_dim = 3 * self.hidden_size
+        self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
+        self.new_decoder_architecture = config.new_decoder_architecture
+        self.multi_query = config.multi_query
+        self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
+        self.attention_dropout = nn.Dropout(config.attention_dropout)
+        self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
+
+        if self.new_decoder_architecture:
+            self._split_heads = partial(
+                split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
+            )
+            self.split_graph = None
+            self.input_surface = None
+            self.static_outputs = None
+
+    def _optimized_split_heads(self, fused_qkv):
+        if self.split_graph is None:
+            self.split_graph = torch.cuda.CUDAGraph()
+            self.input_surface = fused_qkv
+
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(3):
+                    self._split_heads(fused_qkv)
+            torch.cuda.current_stream().wait_stream(s)
+
+            with torch.cuda.graph(self.split_graph):
+                self.static_outputs = self._split_heads(self.input_surface)
+
+        self.input_surface.copy_(fused_qkv)
+        self.split_graph.replay()
+        return tuple(o.detach() for o in self.static_outputs)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        alibi: Optional[torch.Tensor],
+        attention_mask: torch.Tensor,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
+    ):
+        assert not output_attentions
+
+        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
+
+        if (
+            self.new_decoder_architecture
+            and hidden_states.size(1) == 1
+            and torch.is_inference_mode_enabled()
+            and hidden_states.device.type == "cuda"
+        ):
+            query_layer, key_layer, value_layer = self._optimized_split_heads(fused_qkv)
+        else:
+            # 3 x [batch_size, seq_length, num_heads, head_dim]
+            (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+
+        num_kv_heads = self.num_heads
+        batch_size, query_length, _, _ = query_layer.shape
+
+        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
+        key_layer = key_layer.transpose(1, 2).reshape(
+            batch_size * num_kv_heads,
+            query_length,
+            self.head_dim,
+        )
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
+
+        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
+        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            # concatenate along seq_length dimension:
+            #  - key: [batch_size * self.num_heads, kv_length, head_dim]
+            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
+            key_layer = torch.cat((past_key, key_layer), dim=1)
+            value_layer = torch.cat((past_value, value_layer), dim=1)
+
+        _, kv_length, _ = key_layer.shape
+        if use_cache:
+            present = (key_layer, value_layer)
+        else:
+            present = None
+
+        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
+        key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
+        value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
+
+        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
+
+        if alibi is None:
+            attn_output = F.scaled_dot_product_attention(
+                query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False
+            )
+
+            attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
+            attn_output = attn_output.permute(0, 2, 1, 3)
+            attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
+
+            output_tensor = self.dense(attn_output)
+
+            return output_tensor, present
+        else:
+            matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
+
+            # change view to [batch_size, num_heads, q_length, kv_length]
+            attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
+
+            # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
+            input_dtype = attention_scores.dtype
+            # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+            if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
+                attention_scores = attention_scores.to(torch.float32)
+            # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
+            # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
+            # equivalent and more performant, but there might be a numerical difference. If you're reading this
+            # and you'd like to experiment and maybe file a PR, feel free!
+            attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
+            attention_logits *= self.inv_norm_factor
+            attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
+            # [batch_size, num_heads, q_length, kv_length]
+            attention_probs = self.attention_dropout(attention_probs)
+
+            if head_mask is not None:
+                attention_probs = attention_probs * head_mask
+
+            # change view [batch_size, num_heads, q_length, kv_length]
+            attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
+
+            # matmul: [batch_size * num_heads, q_length, head_dim]
+            context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
+
+            # change view [batch_size, q_length, num_heads * head_dim]
+            context_layer = self._merge_heads(context_layer)
+
+            output_tensor = self.dense(context_layer)
+
+            if output_attentions:
+                return output_tensor, present, attention_probs
+            else:
+                return output_tensor, present
+
+
+class OptimizedFalconDecoderLayer(FalconDecoderLayer):
+    def __init__(self, config: FalconConfig):
+        nn.Module.__init__(self)
+        hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+
+        self.mlp = FalconMLP(config)
+        self.hidden_dropout = config.hidden_dropout
+        self.config = config
+
+        self.self_attention = OptimizedFalconAttention(config)
+
+        if self.config.alibi or not config.new_decoder_architecture:
+            if config.new_decoder_architecture:
+                # The layer norm before self-attention
+                self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+                # The layer norm before the MLP
+                self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+            else:
+                self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+                if not config.parallel_attn:
+                    self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+        else:
+            self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+            self.ln_graph = None
+            self.static_input = None
+            self.static_outputs = None
+
+    def _optimized_apply_ln(self, hidden_states):
+        if self.ln_graph is None:
+            self.ln_graph = torch.cuda.CUDAGraph()
+            self.static_input = hidden_states
+
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(3):
+                    self.ln_attn(hidden_states)
+                    self.ln_mlp(hidden_states)
+            torch.cuda.current_stream().wait_stream(s)
+
+            with torch.cuda.graph(self.ln_graph):
+                ln_attn_output = self.ln_attn(hidden_states)
+                ln_mlp_output = self.ln_mlp(hidden_states)
+                self.static_outputs = (ln_attn_output, ln_mlp_output)
+
+        self.static_input.copy_(hidden_states)
+        self.ln_graph.replay()
+        return tuple(o.detach() for o in self.static_outputs)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        alibi: Optional[torch.Tensor],
+        attention_mask: torch.Tensor,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
+    ):
+        residual = hidden_states
+
+        if self.config.new_decoder_architecture:
+            if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+                attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)
+            else:
+                attention_layernorm_out = self.ln_attn(hidden_states)
+                mlp_layernorm_out = self.ln_mlp(hidden_states)
+        else:
+            attention_layernorm_out = self.input_layernorm(hidden_states)
+
+        attn_outputs = self.self_attention(
+            attention_layernorm_out,
+            layer_past=layer_past,
+            attention_mask=attention_mask,
+            alibi=alibi,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
+
+        attention_output = attn_outputs[0]
+
+        if not self.config.new_decoder_architecture:
+            if self.config.parallel_attn:
+                mlp_layernorm_out = attention_layernorm_out
+            else:
+                residual = dropout_add(
+                    attention_output, residual, self.config.attention_dropout, training=self.training
+                )
+                mlp_layernorm_out = self.post_attention_layernorm(residual)
+
+        outputs = attn_outputs[1:]
+
+        mlp_output = self.mlp(mlp_layernorm_out)
+
+        if self.config.new_decoder_architecture or self.config.parallel_attn:
+            mlp_output += attention_output
+
+        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
+
+        if use_cache:
+            outputs = (output,) + outputs
+        else:
+            outputs = (output,) + outputs[1:]
+
+        return outputs  # hidden_states, present, attentions
+
+
+class WrappedFalconBlock(OptimizedFalconDecoderLayer):
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -20,8 +404,10 @@ class WrappedFalconBlock(FalconDecoderLayer):
         alibi: Optional[torch.Tensor] = None,
         layer_past: Optional[KVCache] = None,
         use_cache: bool = False,
-        **kwargs
+        **kwargs,
     ):
+        assert attention_mask is None
+
         batch_size, seq_length = hidden_states.shape[:2]
 
         if layer_past is not None:
@@ -41,7 +427,7 @@ class WrappedFalconBlock(FalconDecoderLayer):
             alibi=alibi,
             layer_past=layer_past,
             use_cache=use_cache,
-            **kwargs
+            **kwargs,
         )
 
         if use_cache:
diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py
new file mode 100644
index 0000000..5baa1a2
--- /dev/null
+++ b/tests/test_optimized_layers.py
@@ -0,0 +1,128 @@
+from typing import Optional, Tuple
+
+import pytest
+import torch
+from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
+
+from petals.utils.auto_config import AutoDistributedConfig
+from petals.utils.convert_block import QuantType, convert_block
+from test_utils import MODEL_NAME
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        alibi: Optional[torch.Tensor] = None,
+        layer_past: Optional[KVCache] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ):
+        batch_size, seq_length = hidden_states.shape[:2]
+
+        if layer_past is not None:
+            layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
+        past_length = 0 if layer_past is None else layer_past[0].shape[1]
+        seq_length_with_past = seq_length + past_length
+
+        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+        if alibi is None and self.config.alibi:
+            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
+        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
+
+        outputs = super().forward(
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            alibi=alibi,
+            layer_past=layer_past,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        if use_cache:
+            present_key_value = outputs[-1]
+            present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
+            outputs = outputs[:-1] + (present_key_value,)
+
+        return outputs
+
+    def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
+        key_states, value_states = key_value
+
+        key_states = key_states.permute(0, 2, 1)
+        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
+
+        if self.config.new_decoder_architecture:
+            key_states = self._expand_states(key_states)
+            value_states = self._expand_states(value_states)
+
+        return (key_states, value_states)
+
+    def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
+        key_states, value_states = key_value
+
+        if self.config.new_decoder_architecture:
+            key_states = self._collapse_states(key_states)
+            value_states = self._collapse_states(value_states)
+
+        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
+        key_states = key_states.permute(0, 2, 1)
+
+        return (key_states, value_states)
+
+    def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
+        batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
+        batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
+
+        state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
+        state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1)  # No copy
+        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)  # Involves a copy
+        return state
+
+    def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
+        batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
+        batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
+
+        state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
+        state = state[:, :, 0]
+        state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
+        return state
+
+
+@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models")
+@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
+@pytest.mark.forked
+def test_falcon(device):
+    if device == "cuda:0" and not torch.cuda.is_available():
+        pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
+
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
+
+    tensor_parallel_devices = (device,)
+    dtype = torch.bfloat16
+    quant_type = QuantType.NONE
+
+    block = config.block_class(config).to(dtype)
+    block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+
+    unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
+    unopt_block = convert_block(
+        unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
+    )
+
+    unopt_block.load_state_dict(block.state_dict())
+    cache = unopt_cache = None
+
+    with torch.inference_mode():
+        for length in [10, 1, 1, 1]:
+            dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype)
+            block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)
+            unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
+            assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length
+            assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length
+            assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length
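For reference: the _optimized_apply_ln, _optimized_split_heads, and
_optimized_apply_rotary helpers in the patch all follow the same CUDA-graph
capture/replay recipe. A minimal standalone sketch of that recipe, illustrative
only (the `layer` module and hidden size are hypothetical, and it assumes a CUDA
device, fixed input shapes, and a PyTorch version that provides torch.cuda.graph):

    import torch
    import torch.nn as nn

    # Hypothetical shape-stable, side-effect-free computation to capture.
    layer = nn.LayerNorm(1024).cuda().requires_grad_(False)
    static_input = torch.randn(1, 1, 1024, device="cuda")

    # 1) Warm up on a side stream so lazy kernel/workspace initialization
    #    does not get baked into the graph.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            layer(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # 2) Capture a single forward pass; static_input/static_output become
    #    fixed buffers tied to the graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = layer(static_input)

    # 3) Replay: copy fresh data into the captured input buffer, replay the
    #    graph, then read the result out of the captured output buffer.
    def run(new_input: torch.Tensor) -> torch.Tensor:
        static_input.copy_(new_input)
        graph.replay()
        return static_output.detach()

Replay reuses the exact memory addresses recorded at capture time, which is why
both this sketch and the helpers in the patch copy new data into static buffers
instead of passing fresh tensors.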
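The optimized paths above engage only for single-token decoding on a GPU inside
torch.inference_mode(); longer prefill calls fall back to the regular eager code.
A hedged usage sketch (assuming a CUDA device and a Falcon checkpoint with the
new decoder architecture; "tiiuae/falcon-40b" is used purely as an example, and
the calls mirror the test above):

    import torch
    from petals.utils.auto_config import AutoDistributedConfig

    config = AutoDistributedConfig.from_pretrained("tiiuae/falcon-40b")  # example checkpoint
    block = config.block_class(config).to(torch.bfloat16).cuda()

    with torch.inference_mode():
        # Prefill with a longer prompt: seq_len > 1, so the eager path is used.
        prompt = torch.randn(1, 10, config.hidden_size, dtype=torch.bfloat16, device="cuda")
        _, cache = block(prompt, use_cache=True)

        # Single-token decoding step: seq_len == 1 under inference mode on CUDA,
        # which is exactly the condition that triggers the CUDA-graph helpers.
        step = torch.randn(1, 1, config.hidden_size, dtype=torch.bfloat16, device="cuda")
        hidden, cache = block(step, layer_past=cache, use_cache=True)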