""" LLaMA intermediate layer Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py See commit history for authorship. """ from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F import math from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaConfig, LlamaDecoderLayer, LlamaMLP, LlamaModel, LlamaRMSNorm, repeat_kv, apply_rotary_pos_emb, ) from petals.triton import attention_triton_wrapper, rbe_triton_wrapper, rmsnorm_triton_wrapper class OptimizedLlamaRMSNorm(LlamaRMSNorm): def forward(self, hidden_states): if torch.is_inference_mode_enabled(): return rmsnorm_triton_wrapper(hidden_states, self.weight) return super().forward(hidden_states) class OptimizedLlamaAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.qkv_proj = nn.Linear( self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False ) self.qkv_sizes = [ self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim, ] self.attn_norm_constant = math.sqrt(self.head_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() assert ( self.config.pretraining_tp == 1 ), "OptimizedLlamaAttention assumes that config.pretraining_tp is equal to 1" query_states, key_states, value_states = torch.split(self.qkv_proj(hidden_states), self.qkv_sizes, dim=2) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / self.attn_norm_constant if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value class OptimizedLlamaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: LlamaConfig): nn.Module.__init__(self) self.hidden_size = config.hidden_size self.self_attn = OptimizedLlamaAttention(config=config) self.mlp = LlamaMLP(config) self.input_layernorm = OptimizedLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = OptimizedLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) class WrappedLlamaBlock(OptimizedLlamaDecoderLayer): def forward( self, hidden_states: torch.Tensor, *args, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, layer_past: Optional[Tuple[torch.Tensor]] = None, use_cache: bool = False, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: batch_size, seq_length, _ = hidden_states.shape seq_length_with_past = seq_length past_key_values_length = 0 past_key_value = layer_past if past_key_value is not None: past_key_values_length = past_key_value[0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length) if position_ids is None: device = hidden_states.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() # embed positions if attention_mask is None: attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device ) attention_mask = LlamaModel._prepare_decoder_attention_mask( None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length ) outputs = super().forward( hidden_states, *args, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache, **kwargs, ) if use_cache: present_key_value = outputs[-1] present_key_value = self._reorder_cache_from_llama_to_bloom( present_key_value, batch_size, seq_length_with_past ) outputs = outputs[:-1] + (present_key_value,) return outputs def _reorder_cache_from_bloom_to_llama( self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int ) -> Tuple[torch.Tensor]: key_states, value_states = key_value key_states = key_states.permute(0, 2, 1) key_states = key_states.view( batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim ) value_states = value_states.view(*key_states.shape) return (key_states, value_states) def _reorder_cache_from_llama_to_bloom( self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int ) -> Tuple[torch.Tensor]: key_states, value_states = key_value value_states = value_states.view( batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim ) key_states = key_states.view(*value_states.shape) key_states = key_states.permute(0, 2, 1) return (key_states, value_states)