Optimize the Falcon block for inference (#500)

This PR attempts to optimize the inference of Falcon models in the single-token setup by reducing the majority of Python overhead and making several assumptions about the setup. Specifically,

* Layer normalization, QKV projection (with splitting) and rotary embeddings are executed through CUDA graphs, which reduces most overhead related to small kernel launche
* If no sin/cos tensors are cached by the rotary embedding layer, we cache them for 8192 tokens (INFERENCE_MAX_LENGTH) during the first forward pass. In general, it should be beneficial to always run a max-length sequence before starting a block, but this is a question for another PR

The PR also adds a small test to ensure that the results (without quantization) of the block before and after quantization indeed match.

Lastly, the pull request makes the backward pass work (as discussed in https://github.com/bigscience-workshop/petals/pull/499) by making cached sin/cos for RotaryEmbedding into buffers and disabling the inference mode during their creation.
pull/504/head
Max Ryabinin 8 months ago committed by GitHub
parent d40eb6c701
commit 1ebd88ae7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,15 +3,399 @@ Falcon intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
See commit history for authorship.
"""
import math
from functools import partial
from typing import Optional, Tuple
import torch
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.falcon.modeling_falcon import (
FalconAttention,
FalconConfig,
FalconDecoderLayer,
FalconLinear,
FalconMLP,
FalconModel,
LayerNorm,
build_alibi_tensor,
dropout_add,
rotate_half,
)
KVCache = Tuple[torch.Tensor, torch.Tensor]
INFERENCE_MAX_LENGTH = 8192
class WrappedFalconBlock(FalconDecoderLayer):
def apply_rotary(query, key, cos, sin):
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
class OptimizedFalconRotaryEmbedding(nn.Module):
def __init__(self, head_dim: int, base=10000):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.head_dim = head_dim
self.seq_len_cached = -1
self.cuda_graph = None
self.input_surface = None
self.static_outputs = None
def _optimized_apply_rotary(self, query, key, cos, sin):
if self.cuda_graph is None:
self.cuda_graph = torch.cuda.CUDAGraph()
self.input_surface = (query, key, cos, sin)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
apply_rotary(*self.input_surface)
torch.cuda.current_stream().wait_stream(s)
with torch.cuda.graph(self.cuda_graph):
self.static_outputs = apply_rotary(*self.input_surface)
inputs = (query, key, cos, sin)
for static_input, data in zip(self.input_surface, inputs):
static_input.copy_(data)
self.cuda_graph.replay()
return tuple(o.detach() for o in self.static_outputs)
def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
total_length = seq_len + past_key_values_length
if self.seq_len_cached == -1:
# warm up the cache
total_length = max(INFERENCE_MAX_LENGTH, total_length)
if total_length > self.seq_len_cached:
with torch.inference_mode(False):
self.seq_len_cached = total_length
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)
if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()
self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False)
return (
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
)
def forward(self, query, key, past_key_values_length=0):
batch, seq_len, head_dim = query.shape
cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
if seq_len == 1 and torch.is_inference_mode_enabled() and query.device.type == "cuda":
return self._optimized_apply_rotary(query, key, cos, sin)
else:
return apply_rotary(query, key, cos, sin)
def split_heads(
fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch, seq_len, _ = fused_qkv.shape
qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)
query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3)
key = torch.broadcast_to(key, query.shape)
value = torch.broadcast_to(value, query.shape)
query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
return query, key, value
class OptimizedFalconAttention(FalconAttention):
def __init__(self, config: FalconConfig):
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.split_size = self.hidden_size
self.hidden_dropout = config.hidden_dropout
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
f" {self.num_heads})."
)
self.maybe_rotary = OptimizedFalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.beta = self.inv_norm_factor
if config.new_decoder_architecture:
qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
elif config.multi_query:
qkv_out_dim = self.hidden_size + 2 * self.head_dim
else:
qkv_out_dim = 3 * self.hidden_size
self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
self.new_decoder_architecture = config.new_decoder_architecture
self.multi_query = config.multi_query
self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
self.attention_dropout = nn.Dropout(config.attention_dropout)
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
if self.new_decoder_architecture:
self._split_heads = partial(
split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
)
self.split_graph = None
self.input_surface = None
self.static_outputs = None
def _optimized_split_heads(self, fused_qkv):
if self.split_graph is None:
self.split_graph = torch.cuda.CUDAGraph()
self.input_surface = fused_qkv
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
self._split_heads(fused_qkv)
torch.cuda.current_stream().wait_stream(s)
with torch.cuda.graph(self.split_graph):
self.static_outputs = self._split_heads(self.input_surface)
self.input_surface.copy_(fused_qkv)
self.split_graph.replay()
return tuple(o.detach() for o in self.static_outputs)
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
assert not output_attentions
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
if (
self.new_decoder_architecture
and hidden_states.size(1) == 1
and torch.is_inference_mode_enabled()
and hidden_states.device.type == "cuda"
):
query_layer, key_layer, value_layer = self._optimized_split_heads(fused_qkv)
else:
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
num_kv_heads = self.num_heads
batch_size, query_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
_, kv_length, _ = key_layer.shape
if use_cache:
present = (key_layer, value_layer)
else:
present = None
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
if alibi is None:
attn_output = F.scaled_dot_product_attention(
query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False
)
attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
attn_output = attn_output.permute(0, 2, 1, 3)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
output_tensor = self.dense(attn_output)
return output_tensor, present
else:
matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
attention_scores = attention_scores.to(torch.float32)
# Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
# adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
# equivalent and more performant, but there might be a numerical difference. If you're reading this
# and you'd like to experiment and maybe file a PR, feel free!
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
attention_logits *= self.inv_norm_factor
attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# change view [batch_size, num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
# change view [batch_size, q_length, num_heads * head_dim]
context_layer = self._merge_heads(context_layer)
output_tensor = self.dense(context_layer)
if output_attentions:
return output_tensor, present, attention_probs
else:
return output_tensor, present
class OptimizedFalconDecoderLayer(FalconDecoderLayer):
def __init__(self, config: FalconConfig):
nn.Module.__init__(self)
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.mlp = FalconMLP(config)
self.hidden_dropout = config.hidden_dropout
self.config = config
self.self_attention = OptimizedFalconAttention(config)
if self.config.alibi or not config.new_decoder_architecture:
if config.new_decoder_architecture:
# The layer norm before self-attention
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
# The layer norm before the MLP
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
else:
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if not config.parallel_attn:
self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
else:
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.ln_graph = None
self.static_input = None
self.static_outputs = None
def _optimized_apply_ln(self, hidden_states):
if self.ln_graph is None:
self.ln_graph = torch.cuda.CUDAGraph()
self.static_input = hidden_states
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
self.ln_attn(hidden_states)
self.ln_mlp(hidden_states)
torch.cuda.current_stream().wait_stream(s)
with torch.cuda.graph(self.ln_graph):
ln_attn_output = self.ln_attn(hidden_states)
ln_mlp_output = self.ln_mlp(hidden_states)
self.static_outputs = (ln_attn_output, ln_mlp_output)
self.static_input.copy_(hidden_states)
self.ln_graph.replay()
return tuple(o.detach() for o in self.static_outputs)
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
residual = hidden_states
if self.config.new_decoder_architecture:
if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)
else:
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
attention_layernorm_out = self.input_layernorm(hidden_states)
attn_outputs = self.self_attention(
attention_layernorm_out,
layer_past=layer_past,
attention_mask=attention_mask,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attention_output = attn_outputs[0]
if not self.config.new_decoder_architecture:
if self.config.parallel_attn:
mlp_layernorm_out = attention_layernorm_out
else:
residual = dropout_add(
attention_output, residual, self.config.attention_dropout, training=self.training
)
mlp_layernorm_out = self.post_attention_layernorm(residual)
outputs = attn_outputs[1:]
mlp_output = self.mlp(mlp_layernorm_out)
if self.config.new_decoder_architecture or self.config.parallel_attn:
mlp_output += attention_output
output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
class WrappedFalconBlock(OptimizedFalconDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,
@ -20,8 +404,10 @@ class WrappedFalconBlock(FalconDecoderLayer):
alibi: Optional[torch.Tensor] = None,
layer_past: Optional[KVCache] = None,
use_cache: bool = False,
**kwargs
**kwargs,
):
assert attention_mask is None
batch_size, seq_length = hidden_states.shape[:2]
if layer_past is not None:
@ -41,7 +427,7 @@ class WrappedFalconBlock(FalconDecoderLayer):
alibi=alibi,
layer_past=layer_past,
use_cache=use_cache,
**kwargs
**kwargs,
)
if use_cache:

@ -0,0 +1,128 @@
from typing import Optional, Tuple
import pytest
import torch
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, convert_block
from test_utils import MODEL_NAME
KVCache = Tuple[torch.Tensor, torch.Tensor]
class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,
*args,
attention_mask: Optional[torch.Tensor] = None,
alibi: Optional[torch.Tensor] = None,
layer_past: Optional[KVCache] = None,
use_cache: bool = False,
**kwargs,
):
batch_size, seq_length = hidden_states.shape[:2]
if layer_past is not None:
layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
past_length = 0 if layer_past is None else layer_past[0].shape[1]
seq_length_with_past = seq_length + past_length
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
if alibi is None and self.config.alibi:
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
outputs = super().forward(
hidden_states,
*args,
attention_mask=attention_mask,
alibi=alibi,
layer_past=layer_past,
use_cache=use_cache,
**kwargs,
)
if use_cache:
present_key_value = outputs[-1]
present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
outputs = outputs[:-1] + (present_key_value,)
return outputs
def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
key_states, value_states = key_value
key_states = key_states.permute(0, 2, 1)
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
if self.config.new_decoder_architecture:
key_states = self._expand_states(key_states)
value_states = self._expand_states(value_states)
return (key_states, value_states)
def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
key_states, value_states = key_value
if self.config.new_decoder_architecture:
key_states = self._collapse_states(key_states)
value_states = self._collapse_states(value_states)
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
key_states = key_states.permute(0, 2, 1)
return (key_states, value_states)
def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy
state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy
return state
def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
state = state[:, :, 0]
state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
return state
@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models")
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
@pytest.mark.forked
def test_falcon(device):
if device == "cuda:0" and not torch.cuda.is_available():
pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
tensor_parallel_devices = (device,)
dtype = torch.bfloat16
quant_type = QuantType.NONE
block = config.block_class(config).to(dtype)
block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
unopt_block = convert_block(
unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
)
unopt_block.load_state_dict(block.state_dict())
cache = unopt_cache = None
with torch.inference_mode():
for length in [10, 1, 1, 1]:
dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype)
block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)
unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length
assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length
assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length
Loading…
Cancel
Save