From 03cbe90234ccd4e3cf749d9370f53bea2a1dcb67 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Tue, 14 Nov 2023 18:14:19 +0100 Subject: [PATCH] Optimize LLaMA for inference (#513) * Optimize LLaMa for inference * Fix model type detection in tests --- src/petals/models/llama/block.py | 219 +++++++++++++++++++++++++++++-- src/petals/utils/cuda_graphs.py | 76 +++++++++++ tests/test_optimized_layers.py | 98 +++++++++++++- 3 files changed, 378 insertions(+), 15 deletions(-) create mode 100644 src/petals/utils/cuda_graphs.py diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py index 55f659a..a8d433d 100644 --- a/src/petals/models/llama/block.py +++ b/src/petals/models/llama/block.py @@ -3,13 +3,219 @@ LLaMA intermediate layer Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py See commit history for authorship. """ +import math from typing import Optional, Tuple import torch -from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaConfig, + LlamaDecoderLayer, + LlamaMLP, + LlamaModel, + LlamaRMSNorm, + repeat_kv, + rotate_half, +) +from petals.utils.cuda_graphs import make_inference_graphed_callable + + +def apply_rotary_pos_emb(q, k, cos, sin): + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class OptimizedLlamaAttention(LlamaAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._rotary_graph = None + + def _optimized_apply_rotary(self, query_states, key_states, cos, sin): + if self._rotary_graph is None: + self._rotary_graph = make_inference_graphed_callable( + apply_rotary_pos_emb, sample_args=(query_states, key_states, cos, sin) + ) + return self._rotary_graph(query_states, key_states, cos, sin) -class WrappedLlamaBlock(LlamaDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + assert not output_attentions + assert position_ids is None + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, 
self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[:, :, kv_seq_len - q_len :] + sin = sin[:, :, kv_seq_len - q_len :] + + if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda": + query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class OptimizedLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + self.self_attn = OptimizedLlamaAttention(config=config) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.pre_attn_graph = None + self.post_attn_graph = None + + def _optimized_input_layernorm(self, hidden_states): + if self.pre_attn_graph is None: + self.pre_attn_graph = make_inference_graphed_callable( + self.input_layernorm.forward, sample_args=(hidden_states,) + ) + return self.pre_attn_graph(hidden_states) + + def _optimized_output_layernorm(self, hidden_states): + if self.post_attn_graph is None: + self.post_attn_graph = make_inference_graphed_callable( + self.post_attention_layernorm.forward, sample_args=(hidden_states,) + ) + return self.post_attn_graph(hidden_states) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, 
torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda": + hidden_states = self._optimized_input_layernorm(hidden_states) + else: + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + + if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda": + hidden_states = self._optimized_output_layernorm(hidden_states) + else: + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class WrappedLlamaBlock(OptimizedLlamaDecoderLayer): def forward( self, hidden_states: torch.Tensor, @@ -31,14 +237,7 @@ class WrappedLlamaBlock(LlamaDecoderLayer): seq_length_with_past = seq_length_with_past + past_key_values_length past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length) - if position_ids is None: - device = hidden_states.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + assert position_ids is None # embed positions if attention_mask is None: diff --git a/src/petals/utils/cuda_graphs.py b/src/petals/utils/cuda_graphs.py new file mode 100644 index 0000000..216ecf1 --- /dev/null +++ b/src/petals/utils/cuda_graphs.py @@ -0,0 +1,76 @@ +import torch +from torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten + + +def make_inference_graphed_callable(callable: callable, sample_args, num_warmup_iters=3): + """Similar to torch.cuda.make_graphed_callables, but takes only one function and does not build a graph for the backward pass""" + assert not isinstance(callable, torch.nn.Module) + if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled(): + raise RuntimeError( + "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`." 
+ ) + + flatten_arg, _ = _tree_flatten(sample_args) + flatten_sample_args = tuple(flatten_arg) + assert all( + isinstance(arg, torch.Tensor) for arg in flatten_arg + ), "In the beta API, sample_args for each callable must contain only Tensors. Other types are not allowed." + + len_user_args = len(sample_args) + static_input_surface = flatten_sample_args + + graph = torch.cuda.CUDAGraph() + + # Warmup + # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work + # from ending up in any captures. + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(num_warmup_iters): + outputs, _ = _tree_flatten(callable(*sample_args)) + del outputs + torch.cuda.current_stream().wait_stream(s) + + # Capture forward graph + with torch.cuda.graph(graph): + outputs = callable(*sample_args) + + flatten_outputs, output_unflatten_spec = _tree_flatten(outputs) + static_outputs = tuple(flatten_outputs) + + def make_graphed_function( + graph, + len_user_args, + output_unflatten_spec, + static_input_surface, + static_outputs, + ): + def replay_graph(*inputs): + # At this stage, only the user args may (potentially) be new tensors. + for i in range(len_user_args): + if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): + static_input_surface[i].copy_(inputs[i]) + graph.replay() + assert isinstance(static_outputs, tuple) + return tuple(o.detach() for o in static_outputs) + + def functionalized(*user_args): + # Runs the autograd function with inputs == all inputs to the graph that might require grad + # (explicit user args + module parameters) + # Assumes module params didn't change since capture. + flatten_user_args, _ = _tree_flatten(user_args) + out = replay_graph(*flatten_user_args) + return _tree_unflatten(out, output_unflatten_spec) + + return functionalized + + # Put together the final graphed callable + graphed = make_graphed_function( + graph, + len_user_args, + output_unflatten_spec, + static_input_surface, + static_outputs, + ) + return graphed diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py index 5baa1a2..84cbfff 100644 --- a/tests/test_optimized_layers.py +++ b/tests/test_optimized_layers.py @@ -3,6 +3,7 @@ from typing import Optional, Tuple import pytest import torch from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, convert_block @@ -94,10 +95,91 @@ class UnoptimizedWrappedFalconBlock(FalconDecoderLayer): return state -@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models") +class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + batch_size, seq_length, _ = hidden_states.shape + + seq_length_with_past = seq_length + past_key_values_length = 0 + + past_key_value = layer_past + if past_key_value is not None: + past_key_values_length = past_key_value[0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + past_key_value = 
self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length) + + if position_ids is None: + device = hidden_states.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = LlamaModel._prepare_decoder_attention_mask( + None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + + outputs = super().forward( + hidden_states, + *args, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + **kwargs, + ) + + if use_cache: + present_key_value = outputs[-1] + present_key_value = self._reorder_cache_from_llama_to_bloom( + present_key_value, batch_size, seq_length_with_past + ) + outputs = outputs[:-1] + (present_key_value,) + + return outputs + + def _reorder_cache_from_bloom_to_llama( + self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int + ) -> Tuple[torch.Tensor]: + key_states, value_states = key_value + key_states = key_states.permute(0, 2, 1) + key_states = key_states.view( + batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim + ) + value_states = value_states.view(*key_states.shape) + return (key_states, value_states) + + def _reorder_cache_from_llama_to_bloom( + self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int + ) -> Tuple[torch.Tensor]: + key_states, value_states = key_value + value_states = value_states.view( + batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim + ) + key_states = key_states.view(*value_states.shape) + key_states = key_states.permute(0, 2, 1) + return (key_states, value_states) + + @pytest.mark.parametrize("device", ["cpu", "cuda:0"]) @pytest.mark.forked -def test_falcon(device): +def test_optimized_block(device): if device == "cuda:0" and not torch.cuda.is_available(): pytest.skip("CUDA tests can be run only in CUDA-enabled setups") @@ -108,11 +190,17 @@ def test_falcon(device): quant_type = QuantType.NONE block = config.block_class(config).to(dtype) - block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) + block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) + + if config.model_type == "falcon": + unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype) + elif config.model_type == "llama": + unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype) + else: + pytest.skip(f"This test is not applicable to {config.model_type} models") - unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype) unopt_block = convert_block( - unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True + unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True ) unopt_block.load_state_dict(block.state_dict())
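
Usage note (illustrative, not part of the patch): the sketch below shows how the new petals.utils.cuda_graphs.make_inference_graphed_callable helper is intended to be used, mirroring how block.py wraps apply_rotary_pos_emb and the layernorms for single-token (q_len == 1) decoding. It is a minimal sketch under the assumption of a CUDA-enabled setup; scale_and_shift is a hypothetical stand-in for a cheap elementwise op, and the tensor shapes are illustrative only.

    import torch

    from petals.utils.cuda_graphs import make_inference_graphed_callable


    def scale_and_shift(x, scale, shift):
        # Hypothetical stand-in for a cheap elementwise op such as the rotary embedding.
        return x * scale + shift


    if torch.cuda.is_available():
        # Shapes mimic a single decoding step: (batch, heads, q_len == 1, head_dim).
        x = torch.randn(1, 32, 1, 128, device="cuda")
        scale = torch.randn_like(x)
        shift = torch.randn_like(x)

        # Capture happens here: the callable is warmed up on a side stream and
        # recorded once into a CUDA graph with these exact shapes and dtypes.
        graphed = make_inference_graphed_callable(scale_and_shift, sample_args=(x, scale, shift))

        with torch.inference_mode():
            # Later calls must keep the captured shapes; new inputs are copied into
            # the static buffers and the recorded kernels are replayed.
            out = graphed(x, scale, shift)
            assert torch.allclose(out, scale_and_shift(x, scale, shift), atol=1e-6)

Graph replay only pays off for shape-stable, launch-bound work, which is why the patched block.py takes the graphed path only when hidden_states has a single query token, torch.is_inference_mode_enabled() is true, and the tensors live on a CUDA device; all other cases fall back to the eager implementation.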