Add Falcon support (#499)
This PR adds:

- Support for models based on `transformers.FalconModel` (the in-library format for Falcon). Tested on Falcon-40B.
- CI tests for Falcon-RW-1B.
- `--throughput dry_run` option to evaluate throughput and exit right away (implemented by @mryab).

Limitations:

- Backward pass support is broken for now, will be fixed in #500.

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
parent b3952e16ee
commit 5f36029534
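For context, here is a minimal client-side sketch of what this change enables. It is a hedged example, not part of the diff: the `AutoDistributedModelForCausalLM` entry point comes from the existing Petals API, and `tiiuae/falcon-rw-1b` is assumed to match the checkpoint used in the new CI test.

```python
# Hedged sketch: loading a Falcon checkpoint through Petals (assumes a swarm
# is already serving this repo; model name mirrors the CI test assumption).
import torch
from transformers import AutoTokenizer

from petals import AutoDistributedModelForCausalLM

model_name = "tiiuae/falcon-rw-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

inputs = tokenizer("The Falcon is", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0]))
```

On the server side, the new throughput option described above would be passed as, e.g., `python -m petals.cli.run_server tiiuae/falcon-rw-1b --throughput dry_run` (the module path is the existing Petals server CLI, not something introduced here).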
@@ -1,2 +1,3 @@
from petals.models.bloom import *
from petals.models.falcon import *
from petals.models.llama import *
@@ -0,0 +1,15 @@
from petals.models.falcon.block import WrappedFalconBlock
from petals.models.falcon.config import DistributedFalconConfig
from petals.models.falcon.model import (
    DistributedFalconForCausalLM,
    DistributedFalconForSequenceClassification,
    DistributedFalconModel,
)
from petals.utils.auto_config import register_model_classes

register_model_classes(
    config=DistributedFalconConfig,
    model=DistributedFalconModel,
    model_for_causal_lm=DistributedFalconForCausalLM,
    model_for_sequence_classification=DistributedFalconForSequenceClassification,
)
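The registration above is what lets the generic auto classes resolve Falcon checkpoints. A hedged sketch (the `AutoDistributedConfig` name comes from the existing `petals.utils.auto_config` machinery, not from this hunk):

```python
# Sketch: once register_model_classes() has run on import, a FalconConfig-based
# repo should resolve to DistributedFalconConfig through the auto-config machinery.
from petals import AutoDistributedConfig

config = AutoDistributedConfig.from_pretrained("tiiuae/falcon-rw-1b")
print(type(config).__name__)  # expected: DistributedFalconConfig
```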
@@ -0,0 +1,94 @@
"""
Falcon intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
See commit history for authorship.
"""
from typing import Optional, Tuple

import torch
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor

KVCache = Tuple[torch.Tensor, torch.Tensor]


class WrappedFalconBlock(FalconDecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[KVCache] = None,
        use_cache: bool = False,
        **kwargs
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        if layer_past is not None:
            layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
        past_length = 0 if layer_past is None else layer_past[0].shape[1]
        seq_length_with_past = seq_length + past_length

        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        if alibi is None and self.config.alibi:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)

        outputs = super().forward(
            hidden_states,
            *args,
            attention_mask=attention_mask,
            alibi=alibi,
            layer_past=layer_past,
            use_cache=use_cache,
            **kwargs
        )

        if use_cache:
            present_key_value = outputs[-1]
            present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
            outputs = outputs[:-1] + (present_key_value,)

        return outputs

    def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
        key_states, value_states = key_value

        key_states = key_states.permute(0, 2, 1)
        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]

        if self.config.new_decoder_architecture:
            key_states = self._expand_states(key_states)
            value_states = self._expand_states(value_states)

        return (key_states, value_states)

    def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
        key_states, value_states = key_value

        if self.config.new_decoder_architecture:
            key_states = self._collapse_states(key_states)
            value_states = self._collapse_states(value_states)

        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
        key_states = key_states.permute(0, 2, 1)

        return (key_states, value_states)

    def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
        batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
        batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads

        state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
        state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1)  # No copy
        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)  # Involves a copy
        return state

    def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
        batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
        batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads

        state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
        state = state[:, :, 0]
        state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
        return state
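The two reordering helpers translate between the BLOOM-style cache layout used elsewhere in Petals and Falcon's native layout; for the new decoder architecture (grouped-query attention) they also replicate or deduplicate KV heads. A standalone sketch of the expand/collapse round-trip, with illustrative sizes that are not tied to any specific checkpoint:

```python
# Sketch of the _expand_states / _collapse_states round-trip used by
# WrappedFalconBlock for new_decoder_architecture checkpoints (illustrative sizes).
import torch

batch_size, num_kv_heads, num_key_value_groups, seq_len, head_dim = 2, 8, 16, 5, 64
num_attention_heads = num_kv_heads * num_key_value_groups

# Falcon keeps one tensor per KV head: [batch_size * num_kv_heads, seq_len, head_dim]
state = torch.randn(batch_size * num_kv_heads, seq_len, head_dim)

# Expand: replicate each KV head across its query-head group (expand is a no-copy
# broadcast; the final reshape materializes the copy, as noted in the block above)
expanded = (
    state.view(batch_size, num_kv_heads, 1, seq_len, head_dim)
    .expand(-1, -1, num_key_value_groups, -1, -1)
    .reshape(batch_size * num_attention_heads, seq_len, head_dim)
)

# Collapse: keep one replica per KV head, dropping the redundant copies
collapsed = (
    expanded.view(batch_size, num_kv_heads, num_key_value_groups, seq_len, head_dim)[:, :, 0]
    .reshape(batch_size * num_kv_heads, seq_len, head_dim)
)

assert torch.equal(collapsed, state)  # the round-trip is lossless
```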
@@ -0,0 +1,45 @@
import os
from typing import Optional, Union

from hivemind import get_logger
from transformers.models.falcon import FalconConfig
from transformers.models.falcon.modeling_falcon import FalconAttention

from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.models.falcon.block import WrappedFalconBlock
from petals.utils.auto_config import DefaultRevisionMixin

logger = get_logger(__name__)


class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
    block_class = WrappedFalconBlock
    attn_class = FalconAttention
    block_prefix = "transformer.h"

    @property
    def num_key_value_groups(self) -> int:
        if self.new_decoder_architecture:
            return self.num_attention_heads // self.num_kv_heads
        if self.multi_query:
            return self.num_attention_heads
        return 1

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            dht_prefix = str(model_name_or_path)
            dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
            dht_prefix = dht_prefix.replace(".", "-")
            logger.info(f"Using DHT prefix: {dht_prefix}")

        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
        config = result[0] if isinstance(result, tuple) else result
        if config.pad_token_id is None:
            config.pad_token_id = 0
        return result
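The `num_key_value_groups` property distinguishes Falcon's three attention variants: the new decoder architecture (grouped-query attention), multi-query attention, and classic multi-head attention. A hedged sketch of the same arithmetic, with illustrative head counts:

```python
# Sketch of the num_key_value_groups arithmetic for the three Falcon attention
# variants; head counts below are illustrative, not tied to a specific checkpoint.
def num_key_value_groups(num_attention_heads: int, num_kv_heads: int,
                         new_decoder_architecture: bool, multi_query: bool) -> int:
    if new_decoder_architecture:  # grouped-query attention: each KV head serves a group of query heads
        return num_attention_heads // num_kv_heads
    if multi_query:  # multi-query attention: a single KV head serves every query head
        return num_attention_heads
    return 1  # classic multi-head attention: one KV head per query head


print(num_key_value_groups(128, 8, True, False))   # GQA-style config  -> 16
print(num_key_value_groups(71, 1, False, True))    # MQA-style config  -> 71
print(num_key_value_groups(32, 32, False, False))  # MHA-style config  -> 1
```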
@@ -0,0 +1,149 @@
from typing import Optional

import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.falcon import (
    FalconForCausalLM,
    FalconForSequenceClassification,
    FalconModel,
    FalconPreTrainedModel,
)

from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.falcon.config import DistributedFalconConfig
from petals.utils.auto_config import DefaultRevisionMixin

logger = get_logger(__name__)


class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel):
    """FalconModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."]

    config_class = DistributedFalconConfig

    def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None):
        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
        super().__init__(config)
        assert len(self.h) == 0
        config.num_hidden_layers = n_layer

        self.h = RemoteSequential(config, dht=dht)

        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
        self.init_prompts(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[RemotePastKeyValues] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # The causal mask will be added on the server-side
        assert (
            attention_mask is None or (attention_mask == 1).all()
        ), f"Custom attention masks are not supported, {attention_mask=}"
        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
        assert not output_attentions, f"{output_attentions=} is not supported"
        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
        assert return_dict is None or return_dict, f"{return_dict=} is not supported"

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
        else:
            prompts = intermediate_prompts = None

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        hidden_states = self.h(
            hidden_states,
            prompts=intermediate_prompts,
            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
        )

        # Remove prefix
        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=RemotePastKeyValues(),
            hidden_states=None,
            attentions=None,
        )

    @property
    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
        return nn.Identity()


class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedFalconConfig

    def __init__(self, config: DistributedFalconConfig):
        FalconPreTrainedModel.__init__(self, config)
        self.transformer = DistributedFalconModel(config)
        self.lm_head = LMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head


class DistributedFalconForSequenceClassification(
    DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification
):
    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected

    config_class = DistributedFalconConfig

    def __init__(self, config: DistributedFalconConfig):
        FalconPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.transformer = DistributedFalconModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
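To tie the pieces together, a hedged usage sketch of the causal-LM class defined above: the embeddings, layer norm, and LM head stay local, while `transformer.h` is a `RemoteSequential` whose blocks run on the swarm. The model name is assumed to match the CI test checkpoint, and the forward pass here is inference-only since backward is not expected to work yet.

```python
# Sketch: local forward pass through a distributed Falcon model (assumes a swarm
# serving this checkpoint is reachable).
import torch
from transformers import AutoTokenizer

from petals.models.falcon import DistributedFalconForCausalLM

model_name = "tiiuae/falcon-rw-1b"  # assumption: same checkpoint as the new CI test
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = DistributedFalconForCausalLM.from_pretrained(model_name)

print(type(model.transformer.h).__name__)  # RemoteSequential: blocks hosted by the swarm

inputs = tokenizer("Petals runs Falcon", return_tensors="pt")["input_ids"]
with torch.inference_mode():
    logits = model(inputs).logits
print(logits.shape)  # [1, seq_len, vocab_size]
```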