Optimize LLaMA for inference (#513)

* Optimize LLaMA for inference
* Fix model type detection in tests
Max Ryabinin authored 6 months ago, committed by GitHub
parent 25a0796b39
commit 03cbe90234
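
The core of the change is a small CUDA-graph helper, make_inference_graphed_callable, added in the new petals.utils.cuda_graphs module (diffed below): during single-token inference, small ops such as the rotary embedding and the two RMSNorm calls are captured into a CUDA graph once and replayed on later steps. A minimal usage sketch, assuming a CUDA device; the scale_and_add toy function and tensor shapes are illustrative and not part of the commit:

import torch

from petals.utils.cuda_graphs import make_inference_graphed_callable


def scale_and_add(x, y):
    # toy stand-in for a small elementwise op like apply_rotary_pos_emb
    return x * 2.0 + y


if torch.cuda.is_available():
    x = torch.randn(1, 1, 128, device="cuda")
    y = torch.randn(1, 1, 128, device="cuda")
    # warms up on a side stream, then captures one CUDA graph for these exact shapes
    graphed = make_inference_graphed_callable(scale_and_add, sample_args=(x, y))
    with torch.inference_mode():
        out = graphed(x, y)  # replays the graph; later calls must pass tensors of the same shapes

In the attention and decoder-layer code below, the graphed path is only taken when the query length is 1 (token-by-token decoding) on a CUDA device under inference mode, so the captured shapes stay fixed across steps.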

@@ -3,13 +3,219 @@ LLaMA intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
See commit history for authorship.
"""
import math
from typing import Optional, Tuple

import torch
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaConfig,
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaModel,
    LlamaRMSNorm,
    repeat_kv,
    rotate_half,
)

from petals.utils.cuda_graphs import make_inference_graphed_callable

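
# Same math as transformers' apply_rotary_pos_emb, but cos/sin arrive already sliced by the caller
# (no position_ids gather); this static form is what gets captured as a CUDA graph below.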
def apply_rotary_pos_emb(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class OptimizedLlamaAttention(LlamaAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._rotary_graph = None

    def _optimized_apply_rotary(self, query_states, key_states, cos, sin):
        if self._rotary_graph is None:
            self._rotary_graph = make_inference_graphed_callable(
                apply_rotary_pos_emb, sample_args=(query_states, key_states, cos, sin)
            )
        return self._rotary_graph(query_states, key_states, cos, sin)

class WrappedLlamaBlock(LlamaDecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        assert not output_attentions
        assert position_ids is None
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)
            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)
            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        cos = cos[:, :, kv_seq_len - q_len :]
        sin = sin[:, :, kv_seq_len - q_len :]

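        # For single-token decoding on GPU under inference mode, apply the rotary embedding by
        # replaying a lazily captured CUDA graph; otherwise fall back to the eager implementation.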
        if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
            query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
        else:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig):
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size
        self.self_attn = OptimizedLlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
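
        # Lazily captured CUDA graphs for the two RMSNorm calls (built on the first single-token inference step)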
        self.pre_attn_graph = None
        self.post_attn_graph = None

    def _optimized_input_layernorm(self, hidden_states):
        if self.pre_attn_graph is None:
            self.pre_attn_graph = make_inference_graphed_callable(
                self.input_layernorm.forward, sample_args=(hidden_states,)
            )
        return self.pre_attn_graph(hidden_states)

    def _optimized_output_layernorm(self, hidden_states):
        if self.post_attn_graph is None:
            self.post_attn_graph = make_inference_graphed_callable(
                self.post_attention_layernorm.forward, sample_args=(hidden_states,)
            )
        return self.post_attn_graph(hidden_states)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        residual = hidden_states

        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
            hidden_states = self._optimized_input_layernorm(hidden_states)
        else:
            hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states

        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
            hidden_states = self._optimized_output_layernorm(hidden_states)
        else:
            hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,

@@ -31,14 +237,7 @@ class WrappedLlamaBlock(LlamaDecoderLayer):
            seq_length_with_past = seq_length_with_past + past_key_values_length
            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)

        if position_ids is None:
            device = hidden_states.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()
        assert position_ids is None

        # embed positions
        if attention_mask is None:

@@ -0,0 +1,76 @@
import torch
from torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten


def make_inference_graphed_callable(callable: callable, sample_args, num_warmup_iters=3):
    """Similar to torch.cuda.make_graphed_callables, but takes only one function and does not build a graph for the backward pass"""
    assert not isinstance(callable, torch.nn.Module)
    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
        raise RuntimeError(
            "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
        )

    flatten_arg, _ = _tree_flatten(sample_args)
    flatten_sample_args = tuple(flatten_arg)
    assert all(
        isinstance(arg, torch.Tensor) for arg in flatten_arg
    ), "In the beta API, sample_args for each callable must contain only Tensors. Other types are not allowed."

    len_user_args = len(sample_args)
    static_input_surface = flatten_sample_args
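
    # The tensors above are baked into the capture: replay() reads inputs from (and writes outputs to)
    # these fixed buffers, so new arguments are copied into static_input_surface before each replay.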
    graph = torch.cuda.CUDAGraph()

    # Warmup
    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
    # from ending up in any captures.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(num_warmup_iters):
            outputs, _ = _tree_flatten(callable(*sample_args))
        del outputs
    torch.cuda.current_stream().wait_stream(s)

    # Capture forward graph
    with torch.cuda.graph(graph):
        outputs = callable(*sample_args)

    flatten_outputs, output_unflatten_spec = _tree_flatten(outputs)
    static_outputs = tuple(flatten_outputs)

    def make_graphed_function(
        graph,
        len_user_args,
        output_unflatten_spec,
        static_input_surface,
        static_outputs,
    ):
        def replay_graph(*inputs):
            # At this stage, only the user args may (potentially) be new tensors.
            for i in range(len_user_args):
                if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
                    static_input_surface[i].copy_(inputs[i])
            graph.replay()
            assert isinstance(static_outputs, tuple)
            return tuple(o.detach() for o in static_outputs)

        def functionalized(*user_args):
            # Runs the autograd function with inputs == all inputs to the graph that might require grad
            # (explicit user args + module parameters)
            # Assumes module params didn't change since capture.
            flatten_user_args, _ = _tree_flatten(user_args)
            out = replay_graph(*flatten_user_args)
            return _tree_unflatten(out, output_unflatten_spec)

        return functionalized

    # Put together the final graphed callable
    graphed = make_graphed_function(
        graph,
        len_user_args,
        output_unflatten_spec,
        static_input_surface,
        static_outputs,
    )
    return graphed

@@ -3,6 +3,7 @@ from typing import Optional, Tuple

import pytest
import torch
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, convert_block
@@ -94,10 +95,91 @@ class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
        return state


@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models")
class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        batch_size, seq_length, _ = hidden_states.shape

        seq_length_with_past = seq_length
        past_key_values_length = 0

        past_key_value = layer_past
        if past_key_value is not None:
            past_key_values_length = past_key_value[0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)

        if position_ids is None:
            device = hidden_states.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
            )
        attention_mask = LlamaModel._prepare_decoder_attention_mask(
            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )

        outputs = super().forward(
            hidden_states,
            *args,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            **kwargs,
        )

        if use_cache:
            present_key_value = outputs[-1]
            present_key_value = self._reorder_cache_from_llama_to_bloom(
                present_key_value, batch_size, seq_length_with_past
            )
            outputs = outputs[:-1] + (present_key_value,)

        return outputs

    def _reorder_cache_from_bloom_to_llama(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        key_states, value_states = key_value
        key_states = key_states.permute(0, 2, 1)
        key_states = key_states.view(
            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        value_states = value_states.view(*key_states.shape)
        return (key_states, value_states)

    def _reorder_cache_from_llama_to_bloom(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        key_states, value_states = key_value
        value_states = value_states.view(
            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        key_states = key_states.view(*value_states.shape)
        key_states = key_states.permute(0, 2, 1)
        return (key_states, value_states)


@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
@pytest.mark.forked
def test_falcon(device):
def test_optimized_block(device):
    if device == "cuda:0" and not torch.cuda.is_available():
        pytest.skip("CUDA tests can be run only in CUDA-enabled setups")

@@ -108,11 +190,17 @@ def test_falcon(device):
    quant_type = QuantType.NONE

    block = config.block_class(config).to(dtype)
    block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
    block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)

    if config.model_type == "falcon":
        unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
    elif config.model_type == "llama":
        unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype)
    else:
        pytest.skip(f"This test is not applicable to {config.model_type} models")

    unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
    unopt_block = convert_block(
        unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
        unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
    )
    unopt_block.load_state_dict(block.state_dict())