Add Falcon support (#499)

This PR adds:

- Support for models based on `transformers.FalconModel` (the in-library format for Falcon). Tested on Falcon-40B (a client-side usage sketch follows this list).
- CI tests for Falcon-RW-1B.
- `--throughput dry_run` option to evaluate throughput and exit right away (implemented by @mryab).
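A minimal client-side sketch of what the Falcon support enables (assuming some swarm is already serving the model and that `AutoDistributedModelForCausalLM` is importable from the `petals` package root, as in current Petals releases; the tokenizer and `generate()` call follow the standard Transformers API):

```python
# Hedged sketch: running a Falcon model over Petals from the client side.
# Assumes a swarm is serving tiiuae/falcon-40b; revision pinning for TII repos
# is handled by the DefaultRevisionMixin introduced in this PR.
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "tiiuae/falcon-40b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("A quick test of Falcon on Petals:", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0]))
```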

Limitations:

- The backward pass is currently broken; it will be fixed in #500.

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Alexander Borzunov (committed via GitHub)
parent b4d822afb2
commit dd4a3230bc

@@ -14,11 +14,13 @@ jobs:
- { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
- { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.8' }
- { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
fail-fast: false
runs-on: ${{ matrix.os }}-latest
timeout-minutes: 15
timeout-minutes: 20
steps:
- name: Increase swap space
if: ${{ matrix.os == 'ubuntu' }}
@@ -93,6 +95,9 @@ jobs:
# [Step 2] Run PyTest
# Share disk cache between Petals servers, clients, and HF Transformers
export TRANSFORMERS_CACHE=~/.cache/petals
# Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
export no_proxy=*
export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES

@@ -106,12 +106,13 @@ def main():
"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
parser.add_argument('--throughput',
type=lambda value: value if value in ['auto', 'eval'] else float(value),
type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value),
default='auto',
help='Expected server throughput (a float measured in RPS). '
'If set to "auto" (default), the script evaluates network and compute throughput '
'on the first run and uses these estimates for future runs. '
'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
'If set to "eval", the script re-evaluates the throughput and overrides the cache. '
'If set to "dry_run", the script re-evaluates the throughput and exits.')
parser.add_argument('--update_period', type=float, required=False, default=120,
help='Server will report blocks to DHT once in this many seconds')
parser.add_argument('--expiration', type=float, required=False, default=None,
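For reference, the `type=` lambda above passes the three keywords through unchanged and parses anything else as a float, so unknown strings are rejected by argparse. A standalone sketch of the same behavior (`parse_throughput` is an illustrative name; the actual code keeps this inline as a lambda):

```python
# Standalone sketch of the --throughput argument parser shown above.
def parse_throughput(value: str):
    return value if value in ["auto", "eval", "dry_run"] else float(value)

assert parse_throughput("auto") == "auto"
assert parse_throughput("dry_run") == "dry_run"
assert parse_throughput("2.5") == 2.5  # an explicit throughput estimate in RPS
# parse_throughput("fast") raises ValueError, which argparse reports as an invalid value
```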

@@ -1,2 +1,3 @@
from petals.models.bloom import *
from petals.models.falcon import *
from petals.models.llama import *

@@ -0,0 +1,15 @@
from petals.models.falcon.block import WrappedFalconBlock
from petals.models.falcon.config import DistributedFalconConfig
from petals.models.falcon.model import (
DistributedFalconForCausalLM,
DistributedFalconForSequenceClassification,
DistributedFalconModel,
)
from petals.utils.auto_config import register_model_classes
register_model_classes(
config=DistributedFalconConfig,
model=DistributedFalconModel,
model_for_causal_lm=DistributedFalconForCausalLM,
model_for_sequence_classification=DistributedFalconForSequenceClassification,
)

@@ -0,0 +1,94 @@
"""
Falcon intermediate layer
Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
See commit history for authorship.
"""
from typing import Optional, Tuple
import torch
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
KVCache = Tuple[torch.Tensor, torch.Tensor]
class WrappedFalconBlock(FalconDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,
*args,
attention_mask: Optional[torch.Tensor] = None,
alibi: Optional[torch.Tensor] = None,
layer_past: Optional[KVCache] = None,
use_cache: bool = False,
**kwargs
):
batch_size, seq_length = hidden_states.shape[:2]
if layer_past is not None:
layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
past_length = 0 if layer_past is None else layer_past[0].shape[1]
seq_length_with_past = seq_length + past_length
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
if alibi is None and self.config.alibi:
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
outputs = super().forward(
hidden_states,
*args,
attention_mask=attention_mask,
alibi=alibi,
layer_past=layer_past,
use_cache=use_cache,
**kwargs
)
if use_cache:
present_key_value = outputs[-1]
present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
outputs = outputs[:-1] + (present_key_value,)
return outputs
def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
key_states, value_states = key_value
key_states = key_states.permute(0, 2, 1)
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
if self.config.new_decoder_architecture:
key_states = self._expand_states(key_states)
value_states = self._expand_states(value_states)
return (key_states, value_states)
def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
key_states, value_states = key_value
if self.config.new_decoder_architecture:
key_states = self._collapse_states(key_states)
value_states = self._collapse_states(value_states)
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
key_states = key_states.permute(0, 2, 1)
return (key_states, value_states)
def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy
state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy
return state
def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
state = state[:, :, 0]
state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
return state
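A shape-only sketch of the cache reordering implemented above, using made-up dimensions (the real values depend on the Falcon checkpoint): Bloom-format keys are stored as `[batch * num_kv_heads, head_dim, seq_len]` and values as `[batch * num_kv_heads, seq_len, head_dim]`, while Falcon expects both in the value layout and, for the new decoder architecture, broadcast to every attention head:

```python
# Hedged sketch of the Bloom <-> Falcon cache reordering above (hypothetical sizes).
import torch

batch, num_kv_heads, num_attention_heads, seq_len, head_dim = 2, 8, 64, 5, 32
num_key_value_groups = num_attention_heads // num_kv_heads  # 8 query heads per KV head

bloom_key = torch.randn(batch * num_kv_heads, head_dim, seq_len)

# Bloom -> Falcon: permute the key, then repeat each KV head for every query group
falcon_key = bloom_key.permute(0, 2, 1)  # -> [batch * num_kv_heads, seq_len, head_dim]
falcon_key = (
    falcon_key.view(batch, num_kv_heads, 1, seq_len, head_dim)
    .expand(-1, -1, num_key_value_groups, -1, -1)              # broadcast, no copy yet
    .reshape(batch * num_attention_heads, seq_len, head_dim)   # the copy happens here
)
assert falcon_key.shape == (batch * num_attention_heads, seq_len, head_dim)

# Falcon -> Bloom: keep one copy per KV head, then permute the key back
restored_key = (
    falcon_key.view(batch, num_kv_heads, num_key_value_groups, seq_len, head_dim)[:, :, 0]
    .reshape(batch * num_kv_heads, seq_len, head_dim)
    .permute(0, 2, 1)
)
assert torch.equal(restored_key, bloom_key)
```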

@@ -0,0 +1,45 @@
import os
from typing import Optional, Union
from hivemind import get_logger
from transformers.models.falcon import FalconConfig
from transformers.models.falcon.modeling_falcon import FalconAttention
from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.models.falcon.block import WrappedFalconBlock
from petals.utils.auto_config import DefaultRevisionMixin
logger = get_logger(__name__)
class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
block_class = WrappedFalconBlock
attn_class = FalconAttention
block_prefix = "transformer.h"
@property
def num_key_value_groups(self) -> int:
if self.new_decoder_architecture:
return self.num_attention_heads // self.num_kv_heads
if self.multi_query:
return self.num_attention_heads
return 1
@classmethod
def from_pretrained(
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
):
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
if loading_from_repo and dht_prefix is None:
dht_prefix = str(model_name_or_path)
dht_prefix = dht_prefix.split("/")[-1] # Use only repo name to merge blocks hosted by different accounts
dht_prefix = dht_prefix.replace(".", "-")
logger.info(f"Using DHT prefix: {dht_prefix}")
result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
config = result[0] if isinstance(result, tuple) else result
if config.pad_token_id is None:
config.pad_token_id = 0
return result
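A small sketch of the DHT prefix derivation above (`derive_dht_prefix` is a hypothetical helper name; the real code inlines these steps in `from_pretrained`). Only the repo name is kept so that blocks of the same model hosted under different accounts merge into one swarm:

```python
# Illustrative sketch of the dht_prefix derivation in DistributedFalconConfig.from_pretrained.
def derive_dht_prefix(model_name_or_path: str) -> str:
    dht_prefix = str(model_name_or_path)
    dht_prefix = dht_prefix.split("/")[-1]      # keep only the repo name, drop the account
    dht_prefix = dht_prefix.replace(".", "-")   # avoid "." in prefixes (likely a UID delimiter)
    return dht_prefix

assert derive_dht_prefix("tiiuae/falcon-40b") == "falcon-40b"
assert derive_dht_prefix("petals-team/falcon-rw-1b") == "falcon-rw-1b"
assert derive_dht_prefix("user/falcon-7b-v0.1") == "falcon-7b-v0-1"  # hypothetical repo name
```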

@@ -0,0 +1,149 @@
from typing import Optional
import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.falcon import (
FalconForCausalLM,
FalconForSequenceClassification,
FalconModel,
FalconPreTrainedModel,
)
from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.falcon.config import DistributedFalconConfig
from petals.utils.auto_config import DefaultRevisionMixin
logger = get_logger(__name__)
class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel):
"""FalconModel, but all transformer layers are hosted by the swarm"""
_keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."]
config_class = DistributedFalconConfig
def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None):
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization
super().__init__(config)
assert len(self.h) == 0
config.num_hidden_layers = n_layer
self.h = RemoteSequential(config, dht=dht)
self.requires_grad_(False) # Forbid accumulating grads for embeddings and layernorm
self.init_prompts(config)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[RemotePastKeyValues] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# The causal mask will be added on the server-side
assert (
attention_mask is None or (attention_mask == 1).all()
), f"Custom attention masks are not supported, {attention_mask=}"
assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
assert not output_attentions, f"{output_attentions=} is not supported"
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
batch_size = inputs_embeds.shape[0]
prompts, intermediate_prompts = self.get_prompt(batch_size)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
else:
prompts = intermediate_prompts = None
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
output_shape = input_shape + (hidden_states.size(-1),)
hidden_states = self.h(
hidden_states,
prompts=intermediate_prompts,
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
)
# Remove prefix
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
hidden_states = hidden_states[:, self.pre_seq_len :]
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=RemotePastKeyValues(),
hidden_states=None,
attentions=None,
)
@property
def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin
return nn.Identity()
class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
_keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
config_class = DistributedFalconConfig
def __init__(self, config: DistributedFalconConfig):
FalconPreTrainedModel.__init__(self, config)
self.transformer = DistributedFalconModel(config)
self.lm_head = LMHead(config)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.lm_head
class DistributedFalconForSequenceClassification(
DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification
):
_keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
config_class = DistributedFalconConfig
def __init__(self, config: DistributedFalconConfig):
FalconPreTrainedModel.__init__(self, config)
self.num_labels = config.num_labels
self.transformer = DistributedFalconModel(config)
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()

@@ -5,6 +5,7 @@ import math
import multiprocessing as mp
import os
import random
import sys
import threading
import time
from typing import Dict, List, Optional, Sequence, Union
@@ -186,10 +187,7 @@ class Server:
check_device_balance(self.tensor_parallel_devices)
if quant_type is None:
if device.type == "cuda":
quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8
else:
quant_type = QuantType.NONE
quant_type = QuantType.NF4 if device.type == "cuda" else QuantType.NONE
self.quant_type = quant_type
logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
@@ -234,8 +232,9 @@ class Server:
self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
if throughput in ["auto", "eval"]:
assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
if throughput in ["auto", "eval", "dry_run"]:
force_eval = throughput in ["eval", "dry_run"]
throughput_info = get_server_throughput(
converted_model_name_or_path,
self.block_config,
@@ -245,9 +244,12 @@
quant_type=quant_type,
tensor_parallel_devices=self.tensor_parallel_devices,
reachable_via_relay=reachable_via_relay,
force_eval=(throughput == "eval"),
force_eval=force_eval,
cache_dir=cache_dir,
)
if throughput == "dry_run":
logger.info("Finished estimating throughput, exiting")
sys.exit(0)
else:
throughput_info = {"throughput": throughput}
self.server_info = ServerInfo(
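A condensed, self-contained sketch of the dry_run control flow added above (`resolve_throughput` and `estimate_fn` are illustrative stand-ins; in the PR this logic lives inside `Server.__init__` and calls `get_server_throughput`):

```python
# Illustrative sketch of how --throughput auto/eval/dry_run is dispatched above.
import sys

def resolve_throughput(throughput, estimate_fn):
    # estimate_fn stands in for get_server_throughput(); it returns {"throughput": <float>, ...}
    assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
    if throughput in ["auto", "eval", "dry_run"]:
        # "eval" and "dry_run" both ignore the cached estimate and re-measure
        throughput_info = estimate_fn(force_eval=throughput in ["eval", "dry_run"])
        if throughput == "dry_run":
            print("Finished estimating throughput, exiting")
            sys.exit(0)
    else:
        throughput_info = {"throughput": throughput}
    return throughput_info
```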

@@ -1,12 +1,14 @@
import os
import re
from dataclasses import dataclass
from typing import Optional, Type, Union
from hivemind import get_logger
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from petals.utils.hf_auth import always_needs_auth
logger = get_logger(__name__)
@dataclass
class _ModelClasses:
@@ -49,17 +51,44 @@ class _AutoDistributedBase:
return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs)
class AutoDistributedConfig(_AutoDistributedBase):
class DefaultRevisionMixin:
"""
Petals only supports Falcon loaded in the new in-library format (transformers.FalconModel).
TII models were recently converted to this format but then reverted due to compatibility issues.
We chose to support only the new format, since HF staff promised to eventually convert these models
to the new format again; see https://huggingface.co/tiiuae/falcon-40b/discussions/90#64b4d23bf44fd957492f7602
Until that happens, we override the default `main` revision for the TII repos with the commit
pointing to the model in the in-library format.
"""
DEFAULT_REVISIONS = {
"tiiuae/falcon-40b": "f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232",
"tiiuae/falcon-40b-instruct": "7475ff8cfc36ed9a962b658ae3c33391566a85a5",
"tiiuae/falcon-7b": "4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76",
"tiiuae/falcon-7b-instruct": "f8dac3fff96d5debd43edf56fb4e1abcfffbef28",
}
@classmethod
def from_pretrained(
cls, model_name_or_path: Union[str, os.PathLike, None], *args, revision: Optional[str] = None, **kwargs
):
if revision is None and model_name_or_path in cls.DEFAULT_REVISIONS:
revision = cls.DEFAULT_REVISIONS[model_name_or_path]
logger.info(f"Loading {model_name_or_path}, revision {revision}")
return super().from_pretrained(model_name_or_path, *args, revision=revision, **kwargs)
class AutoDistributedConfig(DefaultRevisionMixin, _AutoDistributedBase):
_mapping_field = "config"
class AutoDistributedModel(_AutoDistributedBase):
class AutoDistributedModel(DefaultRevisionMixin, _AutoDistributedBase):
_mapping_field = "model"
class AutoDistributedModelForCausalLM(_AutoDistributedBase):
class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase):
_mapping_field = "model_for_causal_lm"
class AutoDistributedModelForSequenceClassification(_AutoDistributedBase):
class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):
_mapping_field = "model_for_sequence_classification"
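A self-contained sketch of the revision pinning above, with a dummy loader standing in for the Hugging Face `from_pretrained` machinery (all names besides the repo IDs are illustrative; no network access is involved):

```python
# Sketch of DefaultRevisionMixin's behavior; DummyLoader and PinnedRevisionMixin are
# stand-ins that replicate the mixin's logic for illustration only.
from typing import Optional


class DummyLoader:
    @classmethod
    def from_pretrained(cls, model_name_or_path, *args, revision: Optional[str] = None, **kwargs):
        return {"model": model_name_or_path, "revision": revision}


class PinnedRevisionMixin:
    DEFAULT_REVISIONS = {"tiiuae/falcon-40b": "f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232"}

    @classmethod
    def from_pretrained(cls, model_name_or_path, *args, revision: Optional[str] = None, **kwargs):
        if revision is None and model_name_or_path in cls.DEFAULT_REVISIONS:
            revision = cls.DEFAULT_REVISIONS[model_name_or_path]
        return super().from_pretrained(model_name_or_path, *args, revision=revision, **kwargs)


class AutoLoader(PinnedRevisionMixin, DummyLoader):
    pass


# The pinned commit is used when no revision is given, but an explicit one still wins:
assert AutoLoader.from_pretrained("tiiuae/falcon-40b")["revision"].startswith("f1ba7d32")
assert AutoLoader.from_pretrained("tiiuae/falcon-40b", revision="main")["revision"] == "main"
assert AutoLoader.from_pretrained("bigscience/bloom-560m")["revision"] is None
```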
