You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
petals/src/petals/models/llama/block.py

206 lines
8.2 KiB
Python

"""
LLaMA intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
See commit history for authorship.
"""
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
repeat_kv,
apply_rotary_pos_emb,
)
from petals.triton import attention_triton_wrapper, rbe_triton_wrapper, rmsnorm_triton_wrapper
class OptimizedLlamaRMSNorm(LlamaRMSNorm):
    """LlamaRMSNorm that dispatches to a Triton kernel during GPU inference.

    When ``torch.inference_mode()`` is active and the input lives on a CUDA device, the
    normalization is delegated to ``rmsnorm_triton_wrapper``. In every other case
    (autograd enabled, or non-CUDA tensors) the reference ``LlamaRMSNorm.forward`` is used.
    """

    def forward(self, hidden_states):
        """Normalize ``hidden_states`` with the learned ``self.weight`` scale.

        Args:
            hidden_states: activations to normalize.

        Returns:
            The normalized tensor, from the Triton kernel or the parent implementation.
        """
        # Triton kernels can only be launched on GPU tensors; without the device guard a
        # CPU forward pass under inference_mode would fail inside the kernel launch.
        # NOTE(review): the constructor's ``eps`` is not forwarded to the Triton path —
        # confirm rmsnorm_triton_wrapper's epsilon matches config.rms_norm_eps.
        if torch.is_inference_mode_enabled() and hidden_states.is_cuda:
            return rmsnorm_triton_wrapper(hidden_states, self.weight)
        return super().forward(hidden_states)
class OptimizedLlamaAttention(LlamaAttention):
    """LLaMA self-attention with a fused query/key/value projection.

    A single ``qkv_proj`` linear layer produces the concatenated query, key and value
    projections, which are then split using ``qkv_sizes``. Grouped-query attention is
    handled by ``repeat_kv`` when ``num_key_value_heads < num_heads``; the KV cache keeps
    the un-repeated ``num_key_value_heads`` heads.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # Fused projection laid out as [queries | keys | values] along the output features.
        # NOTE(review): the parent constructor presumably still creates q_proj/k_proj/v_proj;
        # confirm that checkpoint weights are merged into qkv_proj elsewhere.
        self.qkv_proj = nn.Linear(
            self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False
        )
        # Widths of the query, key and value chunks in qkv_proj's output (used by torch.split).
        self.qkv_sizes = [
            self.num_heads * self.head_dim,
            self.num_key_value_heads * self.head_dim,
            self.num_key_value_heads * self.head_dim,
        ]
        # Attention scores are divided by sqrt(head_dim) before the softmax.
        self.attn_norm_constant = math.sqrt(self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute causal self-attention for ``hidden_states``.

        Args:
            hidden_states: input of shape ``(bsz, q_len, hidden_size)``.
            attention_mask: optional additive mask of shape ``(bsz, 1, q_len, kv_seq_len)``.
            position_ids: position indices forwarded to ``apply_rotary_pos_emb``.
            past_key_value: optional ``(key, value)`` cache, each shaped
                ``(bsz, num_key_value_heads, past_len, head_dim)``.
            output_attentions: if true, also return the post-softmax attention probabilities.
            use_cache: if true, return the updated ``(key, value)`` cache.
            **kwargs: extra keywords some transformers decoder layers forward to attention
                (e.g. ``padding_mask``); accepted for compatibility and ignored.

        Returns:
            ``(attn_output, attn_weights, past_key_value)``. ``attn_output`` has shape
            ``(bsz, q_len, hidden_size)``. ``attn_weights`` is
            ``(bsz, num_heads, q_len, kv_seq_len)`` when ``output_attentions`` else ``None``.
            ``past_key_value`` is the concatenated, un-repeated cache when ``use_cache`` else ``None``.

        Raises:
            AssertionError: if ``config.pretraining_tp != 1``.
            ValueError: if the attention weights, mask or output have unexpected shapes.
        """
        bsz, q_len, _ = hidden_states.size()

        assert (
            self.config.pretraining_tp == 1
        ), "OptimizedLlamaAttention assumes that config.pretraining_tp is equal to 1"

        # One fused matmul, then split into query / key / value along the feature dim.
        query_states, key_states, value_states = torch.split(self.qkv_proj(hidden_states), self.qkv_sizes, dim=2)
        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Total key length = cached tokens + new tokens; rotary_emb is asked for that many positions.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        # The cache is taken before repeat_kv so it stores only num_key_value_heads heads.
        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / self.attn_norm_constant

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        # Honor output_attentions instead of always dropping the probabilities.
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
    """LLaMA decoder layer built from the optimized attention and RMSNorm modules.

    ``LlamaDecoderLayer.__init__`` is deliberately skipped: only ``nn.Module`` is
    initialized, and every submodule is created here (presumably so the stock attention
    module is never built). The inherited ``forward`` is reused unchanged.
    """

    def __init__(self, config: LlamaConfig):
        nn.Module.__init__(self)
        width = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.hidden_size = width
        # Registration order (attention, MLP, input norm, post-attention norm) is kept as-is,
        # so parameter / state_dict ordering does not change.
        self.self_attn = OptimizedLlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = OptimizedLlamaRMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm = OptimizedLlamaRMSNorm(width, eps=norm_eps)
class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
    """Decoder layer whose KV cache is exchanged in a Bloom-style layout.

    The ``layer_past`` it receives and the cache it returns use keys shaped
    ``(batch * num_kv_heads, head_dim, seq)`` and values shaped
    ``(batch * num_kv_heads, seq, head_dim)``. Internally both are converted to the LLaMA
    layout ``(batch, num_kv_heads, seq, head_dim)`` by the ``_reorder_cache_*`` helpers.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run the decoder layer, converting the KV cache to and from the Bloom layout.

        Args:
            hidden_states: input of shape ``(batch, seq_length, hidden)``.
            attention_mask: optional 2D mask of shape ``(batch, past_length + seq_length)``.
                When omitted, all positions are attended. It is then expanded by
                ``LlamaModel._prepare_decoder_attention_mask`` into the
                ``(batch, 1, seq_length, past_length + seq_length)`` mask the attention expects.
            position_ids: optional position indices; default ``past_length .. past_length + seq_length - 1``.
            layer_past: optional ``(key, value)`` cache in the Bloom layout.
            use_cache: if true, the last output element is the updated cache in the Bloom layout.
            *args, **kwargs: forwarded to ``LlamaDecoderLayer.forward``.

        Returns:
            The tuple produced by ``LlamaDecoderLayer.forward``. When ``use_cache`` is true,
            its final (cache) element is converted to the Bloom layout.
        """
        batch_size, seq_length, _ = hidden_states.shape

        seq_length_with_past = seq_length
        past_key_values_length = 0

        past_key_value = layer_past
        if past_key_value is not None:
            # Bloom-layout keys are (batch * num_kv_heads, head_dim, seq): past length is dim 2.
            past_key_values_length = past_key_value[0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)

        if position_ids is None:
            device = hidden_states.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # Build the attention mask: default to "attend everywhere", then expand to 4D.
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
            )
        # Called unbound with None as `self`; this relies on the helper not reading `self`
        # in the transformers version this file targets — re-check when upgrading.
        attention_mask = LlamaModel._prepare_decoder_attention_mask(
            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )

        outputs = super().forward(
            hidden_states,
            *args,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            **kwargs,
        )

        if use_cache:
            present_key_value = outputs[-1]
            present_key_value = self._reorder_cache_from_llama_to_bloom(
                present_key_value, batch_size, seq_length_with_past
            )
            outputs = outputs[:-1] + (present_key_value,)

        return outputs

    def _reorder_cache_from_bloom_to_llama(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        """Convert a Bloom-layout ``(key, value)`` cache to the LLaMA layout.

        Keys ``(batch * num_kv_heads, head_dim, seq)`` and values
        ``(batch * num_kv_heads, seq, head_dim)`` both become
        ``(batch, num_kv_heads, seq, head_dim)``.
        """
        key_states, value_states = key_value
        key_states = key_states.permute(0, 2, 1)
        # reshape instead of view: it returns a view whenever the strides allow one and only
        # copies otherwise, so caches with unusual strides no longer raise.
        key_states = key_states.reshape(
            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        value_states = value_states.reshape(*key_states.shape)
        return (key_states, value_states)

    def _reorder_cache_from_llama_to_bloom(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        """Convert a LLaMA-layout ``(key, value)`` cache back to the Bloom layout.

        Both inputs are ``(batch, num_kv_heads, seq, head_dim)``. The results are keys
        ``(batch * num_kv_heads, head_dim, seq)`` and values ``(batch * num_kv_heads, seq, head_dim)``.
        """
        key_states, value_states = key_value
        # Without a past cache these tensors are transposed slices of the fused qkv output.
        # For batch > 1 their batch and head dims cannot be merged by .view (RuntimeError),
        # so reshape is used; it is still a zero-copy view whenever that is possible.
        value_states = value_states.reshape(
            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        key_states = key_states.reshape(*value_states.shape)
        key_states = key_states.permute(0, 2, 1)
        return (key_states, value_states)