Merge remote-tracking branch 'origin/main' into forward_kwargs
commit
6256995bb1
@ -1,2 +1,3 @@
|
|||||||
from petals.models.bloom import *
|
from petals.models.bloom import *
|
||||||
|
from petals.models.falcon import *
|
||||||
from petals.models.llama import *
|
from petals.models.llama import *
|
||||||
|
@ -0,0 +1,15 @@
|
|||||||
|
from petals.models.falcon.block import WrappedFalconBlock
|
||||||
|
from petals.models.falcon.config import DistributedFalconConfig
|
||||||
|
from petals.models.falcon.model import (
|
||||||
|
DistributedFalconForCausalLM,
|
||||||
|
DistributedFalconForSequenceClassification,
|
||||||
|
DistributedFalconModel,
|
||||||
|
)
|
||||||
|
from petals.utils.auto_config import register_model_classes
|
||||||
|
|
||||||
|
register_model_classes(
|
||||||
|
config=DistributedFalconConfig,
|
||||||
|
model=DistributedFalconModel,
|
||||||
|
model_for_causal_lm=DistributedFalconForCausalLM,
|
||||||
|
model_for_sequence_classification=DistributedFalconForSequenceClassification,
|
||||||
|
)
|
@ -0,0 +1,480 @@
|
|||||||
|
"""
|
||||||
|
Falcon intermediate layer
|
||||||
|
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
|
||||||
|
See commit history for authorship.
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
from functools import partial
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers.models.falcon.modeling_falcon import (
|
||||||
|
FalconAttention,
|
||||||
|
FalconConfig,
|
||||||
|
FalconDecoderLayer,
|
||||||
|
FalconLinear,
|
||||||
|
FalconMLP,
|
||||||
|
FalconModel,
|
||||||
|
LayerNorm,
|
||||||
|
build_alibi_tensor,
|
||||||
|
dropout_add,
|
||||||
|
rotate_half,
|
||||||
|
)
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
INFERENCE_MAX_LENGTH = 8192
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary(query, key, cos, sin):
|
||||||
|
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizedFalconRotaryEmbedding(nn.Module):
|
||||||
|
def __init__(self, head_dim: int, base=10000):
|
||||||
|
super().__init__()
|
||||||
|
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.seq_len_cached = -1
|
||||||
|
|
||||||
|
self.cuda_graph = None
|
||||||
|
self.input_surface = None
|
||||||
|
self.static_outputs = None
|
||||||
|
|
||||||
|
def _optimized_apply_rotary(self, query, key, cos, sin):
|
||||||
|
if self.cuda_graph is None:
|
||||||
|
self.cuda_graph = torch.cuda.CUDAGraph()
|
||||||
|
self.input_surface = (query, key, cos, sin)
|
||||||
|
|
||||||
|
s = torch.cuda.Stream()
|
||||||
|
s.wait_stream(torch.cuda.current_stream())
|
||||||
|
with torch.cuda.stream(s):
|
||||||
|
for _ in range(3):
|
||||||
|
apply_rotary(*self.input_surface)
|
||||||
|
torch.cuda.current_stream().wait_stream(s)
|
||||||
|
|
||||||
|
with torch.cuda.graph(self.cuda_graph):
|
||||||
|
self.static_outputs = apply_rotary(*self.input_surface)
|
||||||
|
|
||||||
|
inputs = (query, key, cos, sin)
|
||||||
|
for static_input, data in zip(self.input_surface, inputs):
|
||||||
|
static_input.copy_(data)
|
||||||
|
self.cuda_graph.replay()
|
||||||
|
return tuple(o.detach() for o in self.static_outputs)
|
||||||
|
|
||||||
|
def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
|
||||||
|
total_length = seq_len + past_key_values_length
|
||||||
|
if self.seq_len_cached == -1:
|
||||||
|
# warm up the cache
|
||||||
|
total_length = max(INFERENCE_MAX_LENGTH, total_length)
|
||||||
|
|
||||||
|
if total_length > self.seq_len_cached:
|
||||||
|
with torch.inference_mode(False):
|
||||||
|
self.seq_len_cached = total_length
|
||||||
|
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
|
||||||
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1).to(device)
|
||||||
|
|
||||||
|
if dtype in [torch.float16, torch.bfloat16]:
|
||||||
|
emb = emb.float()
|
||||||
|
|
||||||
|
self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False)
|
||||||
|
self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False)
|
||||||
|
|
||||||
|
return (
|
||||||
|
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
|
||||||
|
self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, query, key, past_key_values_length=0):
|
||||||
|
batch, seq_len, head_dim = query.shape
|
||||||
|
cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
|
||||||
|
if seq_len == 1 and torch.is_inference_mode_enabled() and query.device.type == "cuda":
|
||||||
|
return self._optimized_apply_rotary(query, key, cos, sin)
|
||||||
|
else:
|
||||||
|
return apply_rotary(query, key, cos, sin)
|
||||||
|
|
||||||
|
|
||||||
|
def split_heads(
|
||||||
|
fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
batch, seq_len, _ = fused_qkv.shape
|
||||||
|
qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)
|
||||||
|
query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3)
|
||||||
|
key = torch.broadcast_to(key, query.shape)
|
||||||
|
value = torch.broadcast_to(value, query.shape)
|
||||||
|
|
||||||
|
query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
|
||||||
|
return query, key, value
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizedFalconAttention(FalconAttention):
|
||||||
|
def __init__(self, config: FalconConfig):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
|
self.split_size = self.hidden_size
|
||||||
|
self.hidden_dropout = config.hidden_dropout
|
||||||
|
|
||||||
|
if self.head_dim * self.num_heads != self.hidden_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
|
||||||
|
f" {self.num_heads})."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.maybe_rotary = OptimizedFalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
|
||||||
|
|
||||||
|
# Layer-wise attention scaling
|
||||||
|
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
||||||
|
self.beta = self.inv_norm_factor
|
||||||
|
if config.new_decoder_architecture:
|
||||||
|
qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
|
||||||
|
elif config.multi_query:
|
||||||
|
qkv_out_dim = self.hidden_size + 2 * self.head_dim
|
||||||
|
else:
|
||||||
|
qkv_out_dim = 3 * self.hidden_size
|
||||||
|
self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
|
||||||
|
self.new_decoder_architecture = config.new_decoder_architecture
|
||||||
|
self.multi_query = config.multi_query
|
||||||
|
self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
|
||||||
|
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
||||||
|
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
|
||||||
|
|
||||||
|
if self.new_decoder_architecture:
|
||||||
|
self._split_heads = partial(
|
||||||
|
split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
|
||||||
|
)
|
||||||
|
self.split_graph = None
|
||||||
|
self.input_surface = None
|
||||||
|
self.static_outputs = None
|
||||||
|
|
||||||
|
def _optimized_split_heads(self, fused_qkv):
|
||||||
|
if self.split_graph is None:
|
||||||
|
self.split_graph = torch.cuda.CUDAGraph()
|
||||||
|
self.input_surface = fused_qkv
|
||||||
|
|
||||||
|
s = torch.cuda.Stream()
|
||||||
|
s.wait_stream(torch.cuda.current_stream())
|
||||||
|
with torch.cuda.stream(s):
|
||||||
|
for _ in range(3):
|
||||||
|
self._split_heads(fused_qkv)
|
||||||
|
torch.cuda.current_stream().wait_stream(s)
|
||||||
|
|
||||||
|
with torch.cuda.graph(self.split_graph):
|
||||||
|
self.static_outputs = self._split_heads(self.input_surface)
|
||||||
|
|
||||||
|
self.input_surface.copy_(fused_qkv)
|
||||||
|
self.split_graph.replay()
|
||||||
|
return tuple(o.detach() for o in self.static_outputs)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
alibi: Optional[torch.Tensor],
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
|
assert not output_attentions
|
||||||
|
|
||||||
|
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.new_decoder_architecture
|
||||||
|
and hidden_states.size(1) == 1
|
||||||
|
and torch.is_inference_mode_enabled()
|
||||||
|
and hidden_states.device.type == "cuda"
|
||||||
|
):
|
||||||
|
query_layer, key_layer, value_layer = self._optimized_split_heads(fused_qkv)
|
||||||
|
else:
|
||||||
|
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
||||||
|
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
||||||
|
|
||||||
|
num_kv_heads = self.num_heads
|
||||||
|
batch_size, query_length, _, _ = query_layer.shape
|
||||||
|
|
||||||
|
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
|
||||||
|
key_layer = key_layer.transpose(1, 2).reshape(
|
||||||
|
batch_size * num_kv_heads,
|
||||||
|
query_length,
|
||||||
|
self.head_dim,
|
||||||
|
)
|
||||||
|
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
|
||||||
|
|
||||||
|
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
|
||||||
|
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
|
||||||
|
|
||||||
|
if layer_past is not None:
|
||||||
|
past_key, past_value = layer_past
|
||||||
|
# concatenate along seq_length dimension:
|
||||||
|
# - key: [batch_size * self.num_heads, kv_length, head_dim]
|
||||||
|
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
||||||
|
key_layer = torch.cat((past_key, key_layer), dim=1)
|
||||||
|
value_layer = torch.cat((past_value, value_layer), dim=1)
|
||||||
|
|
||||||
|
_, kv_length, _ = key_layer.shape
|
||||||
|
if use_cache:
|
||||||
|
present = (key_layer, value_layer)
|
||||||
|
else:
|
||||||
|
present = None
|
||||||
|
|
||||||
|
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
||||||
|
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
|
||||||
|
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
|
||||||
|
|
||||||
|
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
|
||||||
|
|
||||||
|
if alibi is None:
|
||||||
|
attn_output = F.scaled_dot_product_attention(
|
||||||
|
query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
|
||||||
|
attn_output = attn_output.permute(0, 2, 1, 3)
|
||||||
|
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
|
||||||
|
|
||||||
|
output_tensor = self.dense(attn_output)
|
||||||
|
|
||||||
|
return output_tensor, present
|
||||||
|
else:
|
||||||
|
matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
|
||||||
|
|
||||||
|
# change view to [batch_size, num_heads, q_length, kv_length]
|
||||||
|
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
|
||||||
|
|
||||||
|
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
|
||||||
|
input_dtype = attention_scores.dtype
|
||||||
|
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
|
||||||
|
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
|
||||||
|
attention_scores = attention_scores.to(torch.float32)
|
||||||
|
# Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
|
||||||
|
# adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
|
||||||
|
# equivalent and more performant, but there might be a numerical difference. If you're reading this
|
||||||
|
# and you'd like to experiment and maybe file a PR, feel free!
|
||||||
|
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
|
||||||
|
attention_logits *= self.inv_norm_factor
|
||||||
|
attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
|
||||||
|
# [batch_size, num_heads, q_length, kv_length]
|
||||||
|
attention_probs = self.attention_dropout(attention_probs)
|
||||||
|
|
||||||
|
if head_mask is not None:
|
||||||
|
attention_probs = attention_probs * head_mask
|
||||||
|
|
||||||
|
# change view [batch_size, num_heads, q_length, kv_length]
|
||||||
|
attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
|
||||||
|
|
||||||
|
# matmul: [batch_size * num_heads, q_length, head_dim]
|
||||||
|
context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
|
||||||
|
|
||||||
|
# change view [batch_size, q_length, num_heads * head_dim]
|
||||||
|
context_layer = self._merge_heads(context_layer)
|
||||||
|
|
||||||
|
output_tensor = self.dense(context_layer)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
return output_tensor, present, attention_probs
|
||||||
|
else:
|
||||||
|
return output_tensor, present
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizedFalconDecoderLayer(FalconDecoderLayer):
|
||||||
|
def __init__(self, config: FalconConfig):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
hidden_size = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
|
||||||
|
self.mlp = FalconMLP(config)
|
||||||
|
self.hidden_dropout = config.hidden_dropout
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.self_attention = OptimizedFalconAttention(config)
|
||||||
|
|
||||||
|
if self.config.alibi or not config.new_decoder_architecture:
|
||||||
|
if config.new_decoder_architecture:
|
||||||
|
# The layer norm before self-attention
|
||||||
|
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
# The layer norm before the MLP
|
||||||
|
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
else:
|
||||||
|
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
if not config.parallel_attn:
|
||||||
|
self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
|
self.ln_graph = None
|
||||||
|
self.static_input = None
|
||||||
|
self.static_outputs = None
|
||||||
|
|
||||||
|
def _optimized_apply_ln(self, hidden_states):
|
||||||
|
if self.ln_graph is None:
|
||||||
|
self.ln_graph = torch.cuda.CUDAGraph()
|
||||||
|
self.static_input = hidden_states
|
||||||
|
|
||||||
|
s = torch.cuda.Stream()
|
||||||
|
s.wait_stream(torch.cuda.current_stream())
|
||||||
|
with torch.cuda.stream(s):
|
||||||
|
for _ in range(3):
|
||||||
|
self.ln_attn(hidden_states)
|
||||||
|
self.ln_mlp(hidden_states)
|
||||||
|
torch.cuda.current_stream().wait_stream(s)
|
||||||
|
|
||||||
|
with torch.cuda.graph(self.ln_graph):
|
||||||
|
ln_attn_output = self.ln_attn(hidden_states)
|
||||||
|
ln_mlp_output = self.ln_mlp(hidden_states)
|
||||||
|
self.static_outputs = (ln_attn_output, ln_mlp_output)
|
||||||
|
|
||||||
|
self.static_input.copy_(hidden_states)
|
||||||
|
self.ln_graph.replay()
|
||||||
|
return tuple(o.detach() for o in self.static_outputs)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
alibi: Optional[torch.Tensor],
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
):
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
if self.config.new_decoder_architecture:
|
||||||
|
if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
|
||||||
|
attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)
|
||||||
|
else:
|
||||||
|
attention_layernorm_out = self.ln_attn(hidden_states)
|
||||||
|
mlp_layernorm_out = self.ln_mlp(hidden_states)
|
||||||
|
else:
|
||||||
|
attention_layernorm_out = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
attn_outputs = self.self_attention(
|
||||||
|
attention_layernorm_out,
|
||||||
|
layer_past=layer_past,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
alibi=alibi,
|
||||||
|
head_mask=head_mask,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
attention_output = attn_outputs[0]
|
||||||
|
|
||||||
|
if not self.config.new_decoder_architecture:
|
||||||
|
if self.config.parallel_attn:
|
||||||
|
mlp_layernorm_out = attention_layernorm_out
|
||||||
|
else:
|
||||||
|
residual = dropout_add(
|
||||||
|
attention_output, residual, self.config.attention_dropout, training=self.training
|
||||||
|
)
|
||||||
|
mlp_layernorm_out = self.post_attention_layernorm(residual)
|
||||||
|
|
||||||
|
outputs = attn_outputs[1:]
|
||||||
|
|
||||||
|
mlp_output = self.mlp(mlp_layernorm_out)
|
||||||
|
|
||||||
|
if self.config.new_decoder_architecture or self.config.parallel_attn:
|
||||||
|
mlp_output += attention_output
|
||||||
|
|
||||||
|
output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
outputs = (output,) + outputs
|
||||||
|
else:
|
||||||
|
outputs = (output,) + outputs[1:]
|
||||||
|
|
||||||
|
return outputs # hidden_states, present, attentions
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedFalconBlock(OptimizedFalconDecoderLayer):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
*args,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
alibi: Optional[torch.Tensor] = None,
|
||||||
|
layer_past: Optional[KVCache] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
assert attention_mask is None
|
||||||
|
|
||||||
|
batch_size, seq_length = hidden_states.shape[:2]
|
||||||
|
|
||||||
|
if layer_past is not None:
|
||||||
|
layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
|
||||||
|
past_length = 0 if layer_past is None else layer_past[0].shape[1]
|
||||||
|
seq_length_with_past = seq_length + past_length
|
||||||
|
|
||||||
|
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
||||||
|
if alibi is None and self.config.alibi:
|
||||||
|
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
|
||||||
|
attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
|
||||||
|
|
||||||
|
outputs = super().forward(
|
||||||
|
hidden_states,
|
||||||
|
*args,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
alibi=alibi,
|
||||||
|
layer_past=layer_past,
|
||||||
|
use_cache=use_cache,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
present_key_value = outputs[-1]
|
||||||
|
present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
|
||||||
|
outputs = outputs[:-1] + (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
|
||||||
|
key_states, value_states = key_value
|
||||||
|
|
||||||
|
key_states = key_states.permute(0, 2, 1)
|
||||||
|
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
|
||||||
|
|
||||||
|
if self.config.new_decoder_architecture:
|
||||||
|
key_states = self._expand_states(key_states)
|
||||||
|
value_states = self._expand_states(value_states)
|
||||||
|
|
||||||
|
return (key_states, value_states)
|
||||||
|
|
||||||
|
def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
|
||||||
|
key_states, value_states = key_value
|
||||||
|
|
||||||
|
if self.config.new_decoder_architecture:
|
||||||
|
key_states = self._collapse_states(key_states)
|
||||||
|
value_states = self._collapse_states(value_states)
|
||||||
|
|
||||||
|
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
|
||||||
|
key_states = key_states.permute(0, 2, 1)
|
||||||
|
|
||||||
|
return (key_states, value_states)
|
||||||
|
|
||||||
|
def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
|
||||||
|
batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
|
||||||
|
|
||||||
|
state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
|
||||||
|
state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy
|
||||||
|
state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy
|
||||||
|
return state
|
||||||
|
|
||||||
|
def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
|
||||||
|
batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
|
||||||
|
|
||||||
|
state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
|
||||||
|
state = state[:, :, 0]
|
||||||
|
state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
|
||||||
|
return state
|
@ -0,0 +1,45 @@
|
|||||||
|
import os
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from hivemind import get_logger
|
||||||
|
from transformers.models.falcon import FalconConfig
|
||||||
|
from transformers.models.falcon.modeling_falcon import FalconAttention
|
||||||
|
|
||||||
|
from petals.client.config import ClientConfig
|
||||||
|
from petals.client.lm_head import LMHeadConfig
|
||||||
|
from petals.client.ptune import PTuneConfig
|
||||||
|
from petals.models.falcon.block import WrappedFalconBlock
|
||||||
|
from petals.utils.auto_config import DefaultRevisionMixin
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
|
||||||
|
block_class = WrappedFalconBlock
|
||||||
|
attn_class = FalconAttention
|
||||||
|
block_prefix = "transformer.h"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_key_value_groups(self) -> int:
|
||||||
|
if self.new_decoder_architecture:
|
||||||
|
return self.num_attention_heads // self.num_kv_heads
|
||||||
|
if self.multi_query:
|
||||||
|
return self.num_attention_heads
|
||||||
|
return 1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
|
||||||
|
):
|
||||||
|
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
|
||||||
|
if loading_from_repo and dht_prefix is None:
|
||||||
|
dht_prefix = str(model_name_or_path)
|
||||||
|
dht_prefix = dht_prefix.split("/")[-1] # Use only repo name to merge blocks hosted by different accounts
|
||||||
|
dht_prefix = dht_prefix.replace(".", "-")
|
||||||
|
logger.info(f"Using DHT prefix: {dht_prefix}")
|
||||||
|
|
||||||
|
result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
|
||||||
|
config = result[0] if isinstance(result, tuple) else result
|
||||||
|
if config.pad_token_id is None:
|
||||||
|
config.pad_token_id = 0
|
||||||
|
return result
|
@ -0,0 +1,150 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import hivemind
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from hivemind.utils.logging import get_logger
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||||
|
from transformers.models.falcon import (
|
||||||
|
FalconForCausalLM,
|
||||||
|
FalconForSequenceClassification,
|
||||||
|
FalconModel,
|
||||||
|
FalconPreTrainedModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
from petals.client.from_pretrained import FromPretrainedMixin
|
||||||
|
from petals.client.lm_head import LMHead
|
||||||
|
from petals.client.ptune import PTuneMixin
|
||||||
|
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
|
||||||
|
from petals.client.remote_sequential import RemoteSequential
|
||||||
|
from petals.models.falcon.config import DistributedFalconConfig
|
||||||
|
from petals.utils.auto_config import DefaultRevisionMixin
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel):
|
||||||
|
"""FalconModel, but all transformer layers are hosted by the swarm"""
|
||||||
|
|
||||||
|
_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
|
||||||
|
_keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."]
|
||||||
|
|
||||||
|
config_class = DistributedFalconConfig
|
||||||
|
|
||||||
|
def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None):
|
||||||
|
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
|
||||||
|
super().__init__(config)
|
||||||
|
assert len(self.h) == 0
|
||||||
|
config.num_hidden_layers = n_layer
|
||||||
|
|
||||||
|
self.h = RemoteSequential(config, dht=dht)
|
||||||
|
|
||||||
|
self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm
|
||||||
|
self.init_prompts(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[RemotePastKeyValues] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
head_mask: Optional[torch.LongTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
# The causal mask will be added on the server-side
|
||||||
|
assert (
|
||||||
|
attention_mask is None or (attention_mask == 1).all()
|
||||||
|
), f"Custom attention masks are not supported, {attention_mask=}"
|
||||||
|
assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
|
||||||
|
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
|
||||||
|
assert not output_attentions, f"{output_attentions=} is not supported"
|
||||||
|
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
|
||||||
|
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
|
||||||
|
if use_prompts:
|
||||||
|
batch_size = inputs_embeds.shape[0]
|
||||||
|
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
||||||
|
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
||||||
|
else:
|
||||||
|
prompts = intermediate_prompts = None
|
||||||
|
|
||||||
|
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||||
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
||||||
|
hidden_states = self.h(
|
||||||
|
hidden_states,
|
||||||
|
prompts=intermediate_prompts,
|
||||||
|
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove prefix
|
||||||
|
if use_prompts:
|
||||||
|
hidden_states = hidden_states[:, self.pre_seq_len :]
|
||||||
|
|
||||||
|
# Add last hidden state
|
||||||
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
hidden_states = hidden_states.view(output_shape)
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=RemotePastKeyValues(),
|
||||||
|
hidden_states=None,
|
||||||
|
attentions=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin
|
||||||
|
return nn.Identity()
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
|
||||||
|
_keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
|
||||||
|
_keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
|
||||||
|
|
||||||
|
config_class = DistributedFalconConfig
|
||||||
|
|
||||||
|
def __init__(self, config: DistributedFalconConfig):
|
||||||
|
FalconPreTrainedModel.__init__(self, config)
|
||||||
|
self.transformer = DistributedFalconModel(config)
|
||||||
|
self.lm_head = LMHead(config)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedFalconForSequenceClassification(
|
||||||
|
DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification
|
||||||
|
):
|
||||||
|
_keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
|
||||||
|
_keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
|
||||||
|
|
||||||
|
config_class = DistributedFalconConfig
|
||||||
|
|
||||||
|
def __init__(self, config: DistributedFalconConfig):
|
||||||
|
FalconPreTrainedModel.__init__(self, config)
|
||||||
|
self.num_labels = config.num_labels
|
||||||
|
|
||||||
|
self.transformer = DistributedFalconModel(config)
|
||||||
|
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
@ -0,0 +1,128 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
|
||||||
|
|
||||||
|
from petals.utils.auto_config import AutoDistributedConfig
|
||||||
|
from petals.utils.convert_block import QuantType, convert_block
|
||||||
|
from test_utils import MODEL_NAME
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
*args,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
alibi: Optional[torch.Tensor] = None,
|
||||||
|
layer_past: Optional[KVCache] = None,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
batch_size, seq_length = hidden_states.shape[:2]
|
||||||
|
|
||||||
|
if layer_past is not None:
|
||||||
|
layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
|
||||||
|
past_length = 0 if layer_past is None else layer_past[0].shape[1]
|
||||||
|
seq_length_with_past = seq_length + past_length
|
||||||
|
|
||||||
|
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
||||||
|
if alibi is None and self.config.alibi:
|
||||||
|
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
|
||||||
|
attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
|
||||||
|
|
||||||
|
outputs = super().forward(
|
||||||
|
hidden_states,
|
||||||
|
*args,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
alibi=alibi,
|
||||||
|
layer_past=layer_past,
|
||||||
|
use_cache=use_cache,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
present_key_value = outputs[-1]
|
||||||
|
present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
|
||||||
|
outputs = outputs[:-1] + (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
|
||||||
|
key_states, value_states = key_value
|
||||||
|
|
||||||
|
key_states = key_states.permute(0, 2, 1)
|
||||||
|
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
|
||||||
|
|
||||||
|
if self.config.new_decoder_architecture:
|
||||||
|
key_states = self._expand_states(key_states)
|
||||||
|
value_states = self._expand_states(value_states)
|
||||||
|
|
||||||
|
return (key_states, value_states)
|
||||||
|
|
||||||
|
def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
|
||||||
|
key_states, value_states = key_value
|
||||||
|
|
||||||
|
if self.config.new_decoder_architecture:
|
||||||
|
key_states = self._collapse_states(key_states)
|
||||||
|
value_states = self._collapse_states(value_states)
|
||||||
|
|
||||||
|
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
|
||||||
|
key_states = key_states.permute(0, 2, 1)
|
||||||
|
|
||||||
|
return (key_states, value_states)
|
||||||
|
|
||||||
|
def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
|
||||||
|
batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
|
||||||
|
|
||||||
|
state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
|
||||||
|
state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy
|
||||||
|
state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy
|
||||||
|
return state
|
||||||
|
|
||||||
|
def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
|
||||||
|
batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
|
||||||
|
|
||||||
|
state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
|
||||||
|
state = state[:, :, 0]
|
||||||
|
state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models")
|
||||||
|
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
|
||||||
|
@pytest.mark.forked
|
||||||
|
def test_falcon(device):
|
||||||
|
if device == "cuda:0" and not torch.cuda.is_available():
|
||||||
|
pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
|
||||||
|
|
||||||
|
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
|
||||||
|
|
||||||
|
tensor_parallel_devices = (device,)
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
quant_type = QuantType.NONE
|
||||||
|
|
||||||
|
block = config.block_class(config).to(dtype)
|
||||||
|
block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
|
||||||
|
|
||||||
|
unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
|
||||||
|
unopt_block = convert_block(
|
||||||
|
unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
|
||||||
|
)
|
||||||
|
|
||||||
|
unopt_block.load_state_dict(block.state_dict())
|
||||||
|
cache = unopt_cache = None
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
for length in [10, 1, 1, 1]:
|
||||||
|
dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype)
|
||||||
|
block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)
|
||||||
|
unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
|
||||||
|
assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length
|
||||||
|
assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length
|
||||||
|
assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length
|
Loading…
Reference in New Issue