"""
LLaMA intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
See commit history for authorship.
"""
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaConfig,
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaModel,
    LlamaRMSNorm,
    repeat_kv,
    rotate_half,
)

from petals.utils.cuda_graphs import make_inference_graphed_callable


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to query and key tensors.

    Computes ``x * cos + rotate_half(x) * sin`` for both ``q`` and ``k``. In this module, ``cos``/``sin`` are
    unsqueezed at dim 1 (the heads axis of ``q``/``k``) before this is called.

    Returns:
        Tuple ``(q_embed, k_embed)`` with the rotated query and key tensors.
    """
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class OptimizedLlamaAttention(LlamaAttention):
    """LlamaAttention with an eager attention computation and an optionally graphed rotary embedding.

    For single-token steps (``q_len == 1``) under ``torch.inference_mode`` on a CUDA device, the rotary embedding
    is applied through a callable built by ``make_inference_graphed_callable``; otherwise it is applied eagerly.
    The key/value cache is a ``(key, value)`` tuple, each shaped ``(batch, num_key_value_heads, seq_len, head_dim)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Built lazily on the first graphed call (see _optimized_apply_rotary).
        self._rotary_graph = None

    def _optimized_apply_rotary(self, query_states, key_states, cos, sin):
        """Apply rotary embeddings through a lazily built inference-graphed callable."""
        if self._rotary_graph is None:
            # NOTE(review): the callable is built once from the first call's tensors and reused afterwards;
            # later calls presumably must match those shapes/dtypes/devices -- confirm against
            # make_inference_graphed_callable.
            self._rotary_graph = make_inference_graphed_callable(
                apply_rotary_pos_emb, sample_args=(query_states, key_states, cos, sin)
            )
        return self._rotary_graph(query_states, key_states, cos, sin)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute self-attention over ``hidden_states``.

        Args:
            hidden_states: 3-D input ``(batch, q_len, hidden)``.
            attention_mask: optional mask that is added to the raw attention scores before the softmax.
            position_ids: rotary positions; if ``None``, they are ``past_len, ..., past_len + q_len - 1`` where
                ``past_len`` is the cached sequence length (0 without a cache).
            past_key_value: optional ``(key, value)`` cache, each ``(batch, num_key_value_heads, past_len, head_dim)``.
            output_attentions: must be ``False``; attention weights are never returned.
            use_cache: whether to return the updated ``(key, value)`` cache.
            cache_position: accepted for interface compatibility; not used here.

        Returns:
            ``(attn_output, None, present_key_value)``; ``present_key_value`` is ``None`` unless ``use_cache``.
        """
        assert not output_attentions
        if position_ids is None:
            # Positions continue right after the cached tokens (cache seq axis is dim 2).
            past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
            position_ids = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            ).unsqueeze(0)

        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            # Slice the projection weights into pretraining_tp chunks and concatenate the partial projections.
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim); these are non-contiguous views.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
        # Insert the heads axis so cos/sin line up with (bsz, heads, q_len, head_dim).
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

        if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
            query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
        else:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        # Without a past cache these may still be the non-contiguous transposed views from above.
        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
    """LlamaDecoderLayer that uses OptimizedLlamaAttention and optionally graphed RMSNorms.

    For single-token steps under ``torch.inference_mode`` on a CUDA device, both RMSNorms run through callables
    built by ``make_inference_graphed_callable``; otherwise they run eagerly.
    """

    def __init__(self, config: LlamaConfig):
        # Only nn.Module.__init__ runs (LlamaDecoderLayer.__init__ is bypassed); the submodules are built here
        # with OptimizedLlamaAttention in place of the stock attention.
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size
        self.self_attn = OptimizedLlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Graphed layernorm callables, built lazily on first use.
        self.pre_attn_graph = None
        self.post_attn_graph = None

    def _optimized_input_layernorm(self, hidden_states):
        """Run ``input_layernorm`` through a lazily built inference-graphed callable."""
        if self.pre_attn_graph is None:
            # NOTE(review): built once from the first call's tensor; later inputs presumably must match it.
            self.pre_attn_graph = make_inference_graphed_callable(
                self.input_layernorm.forward, sample_args=(hidden_states,)
            )
        return self.pre_attn_graph(hidden_states)

    def _optimized_output_layernorm(self, hidden_states):
        """Run ``post_attention_layernorm`` through a lazily built inference-graphed callable."""
        if self.post_attn_graph is None:
            # NOTE(review): built once from the first call's tensor; later inputs presumably must match it.
            self.post_attn_graph = make_inference_graphed_callable(
                self.post_attention_layernorm.forward, sample_args=(hidden_states,)
            )
        return self.post_attn_graph(hidden_states)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Must be False: OptimizedLlamaAttention asserts it and never returns attention weights.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up
                decoding (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states

        Note:
            Any extra ``**kwargs`` are forwarded to ``OptimizedLlamaAttention.forward``, which accepts no extra
            keyword arguments, so passing any will raise ``TypeError``.

        Returns:
            ``(hidden_states,)``, followed by the attention weights if ``output_attentions`` and the present
            key/value cache if ``use_cache``.
        """
        residual = hidden_states

        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
            hidden_states = self._optimized_input_layernorm(hidden_states)
        else:
            hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states

        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
            hidden_states = self._optimized_output_layernorm(hidden_states)
        else:
            hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
    """OptimizedLlamaDecoderLayer whose key/value cache is exchanged in the Bloom layout.

    Bloom layout (as produced by ``_reorder_cache_from_llama_to_bloom``):
        key:   ``(batch * num_key_value_heads, head_dim, seq_len)``
        value: ``(batch * num_key_value_heads, seq_len, head_dim)``
    Llama layout (used inside the attention): both ``(batch, num_key_value_heads, seq_len, head_dim)``.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run the decoder layer with a Bloom-layout cache.

        Args:
            hidden_states: 3-D input ``(batch, seq_len, hidden)``.
            attention_mask: optional mask passed to ``_prepare_4d_causal_attention_mask``; if ``None``, an all-ones
                boolean mask of shape ``(batch, past_len + seq_len)`` is used.
            position_ids: must be ``None``; positions are derived from the cache length inside the attention.
            layer_past: optional ``(key, value)`` cache in Bloom layout.
            use_cache: if ``True``, the last output element is the updated cache in Bloom layout.

        Returns:
            The tuple from ``OptimizedLlamaDecoderLayer.forward``, with the cache (when ``use_cache``) converted
            to Bloom layout.
        """
        batch_size, seq_length, _ = hidden_states.shape

        seq_length_with_past = seq_length
        past_key_values_length = 0

        past_key_value = layer_past
        if past_key_value is not None:
            # Bloom-layout key is (batch * kv_heads, head_dim, seq), so dim 2 is the cached length.
            past_key_values_length = past_key_value[0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)

        assert position_ids is None

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
            )
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask=attention_mask,
            input_shape=(batch_size, seq_length),
            inputs_embeds=hidden_states,
            past_key_values_length=past_key_values_length,
        )

        outputs = super().forward(
            hidden_states,
            *args,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            **kwargs,
        )

        if use_cache:
            present_key_value = outputs[-1]
            present_key_value = self._reorder_cache_from_llama_to_bloom(
                present_key_value, batch_size, seq_length_with_past
            )
            outputs = outputs[:-1] + (present_key_value,)

        return outputs

    def _reorder_cache_from_bloom_to_llama(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        """Convert a Bloom-layout ``(key, value)`` cache into the Llama layout used by the attention."""
        key_states, value_states = key_value
        # (batch * kv_heads, head_dim, seq) -> (batch * kv_heads, seq, head_dim) -> (batch, kv_heads, seq, head_dim).
        # Splitting dim 0 is always expressible as a view, even after the permute.
        key_states = key_states.permute(0, 2, 1)
        key_states = key_states.view(
            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        # (batch * kv_heads, seq, head_dim) -> (batch, kv_heads, seq, head_dim)
        value_states = value_states.view(*key_states.shape)
        return (key_states, value_states)

    def _reorder_cache_from_llama_to_bloom(
        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
    ) -> Tuple[torch.Tensor]:
        """Convert a Llama-layout ``(key, value)`` cache back into the Bloom layout."""
        key_states, value_states = key_value
        # Use reshape rather than view here: without a past cache, the attention returns the
        # (batch, kv_heads, seq, head_dim) tensors produced by `.transpose(1, 2)` without a torch.cat, so they
        # may be non-contiguous, and merging their first two dims with `.view` can raise for seq_length > 1.
        # reshape returns a view whenever one is possible, so the usual contiguous case is unchanged.
        value_states = value_states.reshape(
            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
        )
        key_states = key_states.reshape(*value_states.shape)
        # (batch * kv_heads, seq, head_dim) -> (batch * kv_heads, head_dim, seq)
        key_states = key_states.permute(0, 2, 1)
        return (key_states, value_states)