pull/498/merge
Max Ryabinin 7 months ago committed by GitHub
commit 947bf2387b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,13 +3,118 @@ LLaMA intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
See commit history for authorship.
"""
import math
from typing import Optional, Tuple
import torch
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
apply_rotary_pos_emb,
repeat_kv,
)
class WrappedLlamaBlock(LlamaDecoderLayer):
class OptimizedLlamaAttention(LlamaAttention):
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.qkv_proj = nn.Linear(
self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False
)
self.qkv_sizes = [
self.num_heads * self.head_dim,
self.num_key_value_heads * self.head_dim,
self.num_key_value_heads * self.head_dim,
]
self.attn_norm_constant = math.sqrt(self.head_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
assert (
self.config.pretraining_tp == 1
), "OptimizedLlamaAttention assumes that config.pretraining_tp is equal to 1"
assert not output_attentions, "output_attentions=True is not supported"
query_states, key_states, value_states = torch.split(self.qkv_proj(hidden_states), self.qkv_sizes, dim=2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / self.attn_norm_constant
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LlamaConfig):
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
self.self_attn = OptimizedLlamaAttention(config=config)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,

@ -65,14 +65,21 @@ def load_pretrained_block(
# dummy load, check that keys match
report = block.load_state_dict(state_dict, strict=False)
if "self_attn.qkv_proj.weight" in report.missing_keys:
report.missing_keys.remove("self_attn.qkv_proj.weight") # will be filled later
assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
for param_name, _ in block.named_parameters():
assert param_name in state_dict, f"{param_name} not in state dict"
param = state_dict[param_name]
if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
param = param.to(torch_dtype)
set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
if param_name != "self_attn.qkv_proj.weight":
assert param_name in state_dict, f"{param_name} not in state dict"
param = state_dict[param_name]
if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
param = param.to(torch_dtype)
set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
else:
cur_block = getattr(block, param_name)
dummy_value = torch.empty_like(cur_block, device="cpu")
set_module_tensor_to_device(block, param_name, "cpu", dummy_value)
logger.info(f"Loaded {model_name} block {block_index}")
logger.debug(f"Details: {report}")

@ -1,6 +1,5 @@
import fcntl
import json
import math
import multiprocessing as mp
import os
import time

@ -50,6 +50,19 @@ def convert_block(
if freeze:
block.requires_grad_(False)
if hasattr(block, "self_attn") and hasattr(block.self_attn, "qkv_proj"):
offset = 0
for data in [
block.self_attn.q_proj.weight.data,
block.self_attn.k_proj.weight.data,
block.self_attn.v_proj.weight.data,
]:
block.self_attn.qkv_proj.weight.data[offset : offset + data.size(0)].copy_(data)
offset += data.size(0)
del block.self_attn.q_proj
del block.self_attn.k_proj
del block.self_attn.v_proj
block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
if quant_type != QuantType.NONE:

Loading…
Cancel
Save