From dd4a3230bced491c22fc44587ef72f83a676fc2d Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 4 Sep 2023 01:45:37 +0400 Subject: [PATCH] Add Falcon support (#499) This PR adds: - Support for models based on `transformers.FalconModel` (the in-library format for Falcon). Tested on Falcon-40B. - CI tests for Falcon-RW-1B. - `--throughput dry_run` option to evaluate throughput and exit right away (implemented by @mryab). Limitations: - Backward pass support is broken for now, will be fixed in #500. Co-authored-by: Max Ryabinin --- .github/workflows/run-tests.yaml | 7 +- src/petals/cli/run_server.py | 5 +- src/petals/models/__init__.py | 1 + src/petals/models/falcon/__init__.py | 15 +++ src/petals/models/falcon/block.py | 94 +++++++++++++++++ src/petals/models/falcon/config.py | 45 ++++++++ src/petals/models/falcon/model.py | 149 +++++++++++++++++++++++++++ src/petals/server/server.py | 16 +-- src/petals/utils/auto_config.py | 39 ++++++- 9 files changed, 356 insertions(+), 15 deletions(-) create mode 100644 src/petals/models/falcon/__init__.py create mode 100644 src/petals/models/falcon/block.py create mode 100644 src/petals/models/falcon/config.py create mode 100644 src/petals/models/falcon/model.py diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 663b400..05cebdd 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -14,11 +14,13 @@ jobs: - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' } - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' } - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' } + - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.8' } + - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.11' } - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' } - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' } fail-fast: false runs-on: ${{ matrix.os }}-latest - timeout-minutes: 15 + timeout-minutes: 20 steps: - name: Increase swap space if: ${{ matrix.os == 'ubuntu' }} @@ -93,6 +95,9 @@ jobs: # [Step 2] Run PyTest + # Share disk cache between Petals servers, clients, and HF Transformers + export TRANSFORMERS_CACHE=~/.cache/petals + # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93 export no_proxy=* export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 3728c16..94f5c2e 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -106,12 +106,13 @@ def main(): "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.") parser.add_argument('--throughput', - type=lambda value: value if value in ['auto', 'eval'] else float(value), + type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value), default='auto', help='Expected server throughput (a float measured in RPS). ' 'If set to "auto" (default), the script evaluates network and compute throughput ' 'on the first run and uses these estimates for future runs. ' - 'If set to "eval", the script re-evaluates the throughput and overrides the cache.') + 'If set to "eval", the script re-evaluates the throughput and overrides the cache. 
' + 'If set to "dry_run", the script re-evaluates the throughput and exits.') parser.add_argument('--update_period', type=float, required=False, default=120, help='Server will report blocks to DHT once in this many seconds') parser.add_argument('--expiration', type=float, required=False, default=None, diff --git a/src/petals/models/__init__.py b/src/petals/models/__init__.py index acb4d38..f52a429 100644 --- a/src/petals/models/__init__.py +++ b/src/petals/models/__init__.py @@ -1,2 +1,3 @@ from petals.models.bloom import * +from petals.models.falcon import * from petals.models.llama import * diff --git a/src/petals/models/falcon/__init__.py b/src/petals/models/falcon/__init__.py new file mode 100644 index 0000000..019ca5d --- /dev/null +++ b/src/petals/models/falcon/__init__.py @@ -0,0 +1,15 @@ +from petals.models.falcon.block import WrappedFalconBlock +from petals.models.falcon.config import DistributedFalconConfig +from petals.models.falcon.model import ( + DistributedFalconForCausalLM, + DistributedFalconForSequenceClassification, + DistributedFalconModel, +) +from petals.utils.auto_config import register_model_classes + +register_model_classes( + config=DistributedFalconConfig, + model=DistributedFalconModel, + model_for_causal_lm=DistributedFalconForCausalLM, + model_for_sequence_classification=DistributedFalconForSequenceClassification, +) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py new file mode 100644 index 0000000..e677e06 --- /dev/null +++ b/src/petals/models/falcon/block.py @@ -0,0 +1,94 @@ +""" +Falcon intermediate layer +Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py +See commit history for authorship. +""" +from typing import Optional, Tuple + +import torch +from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class WrappedFalconBlock(FalconDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + alibi: Optional[torch.Tensor] = None, + layer_past: Optional[KVCache] = None, + use_cache: bool = False, + **kwargs + ): + batch_size, seq_length = hidden_states.shape[:2] + + if layer_past is not None: + layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past) + past_length = 0 if layer_past is None else layer_past[0].shape[1] + seq_length_with_past = seq_length + past_length + + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + if alibi is None and self.config.alibi: + alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) + attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length) + + outputs = super().forward( + hidden_states, + *args, + attention_mask=attention_mask, + alibi=alibi, + layer_past=layer_past, + use_cache=use_cache, + **kwargs + ) + + if use_cache: + present_key_value = outputs[-1] + present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value) + outputs = outputs[:-1] + (present_key_value,) + + return outputs + + def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache: + key_states, value_states = key_value + + key_states = key_states.permute(0, 2, 1) + assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] + + if self.config.new_decoder_architecture: + 
key_states = self._expand_states(key_states) + value_states = self._expand_states(value_states) + + return (key_states, value_states) + + def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache: + key_states, value_states = key_value + + if self.config.new_decoder_architecture: + key_states = self._collapse_states(key_states) + value_states = self._collapse_states(value_states) + + assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] + key_states = key_states.permute(0, 2, 1) + + return (key_states, value_states) + + def _expand_states(self, state: torch.Tensor) -> torch.Tensor: + batch_size_x_num_kv_heads, seq_len, head_dim = state.shape + batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads + + state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim) + state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy + state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy + return state + + def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: + batch_size_x_num_attn_heads, seq_len, head_dim = state.shape + batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads + + state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim) + state = state[:, :, 0] + state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim) + return state diff --git a/src/petals/models/falcon/config.py b/src/petals/models/falcon/config.py new file mode 100644 index 0000000..a1ae5e9 --- /dev/null +++ b/src/petals/models/falcon/config.py @@ -0,0 +1,45 @@ +import os +from typing import Optional, Union + +from hivemind import get_logger +from transformers.models.falcon import FalconConfig +from transformers.models.falcon.modeling_falcon import FalconAttention + +from petals.client.config import ClientConfig +from petals.client.lm_head import LMHeadConfig +from petals.client.ptune import PTuneConfig +from petals.models.falcon.block import WrappedFalconBlock +from petals.utils.auto_config import DefaultRevisionMixin + +logger = get_logger(__name__) + + +class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig): + block_class = WrappedFalconBlock + attn_class = FalconAttention + block_prefix = "transformer.h" + + @property + def num_key_value_groups(self) -> int: + if self.new_decoder_architecture: + return self.num_attention_heads // self.num_kv_heads + if self.multi_query: + return self.num_attention_heads + return 1 + + @classmethod + def from_pretrained( + cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs + ): + loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) + if loading_from_repo and dht_prefix is None: + dht_prefix = str(model_name_or_path) + dht_prefix = dht_prefix.split("/")[-1] # Use only repo name to merge blocks hosted by different accounts + dht_prefix = dht_prefix.replace(".", "-") + logger.info(f"Using DHT prefix: {dht_prefix}") + + result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) + config = result[0] if isinstance(result, tuple) else result + if config.pad_token_id is None: + config.pad_token_id = 0 + return result diff --git a/src/petals/models/falcon/model.py b/src/petals/models/falcon/model.py new file mode 100644 index 0000000..3a2a6b0 --- 
/dev/null +++ b/src/petals/models/falcon/model.py @@ -0,0 +1,149 @@ +from typing import Optional + +import hivemind +import torch +import torch.nn as nn +from hivemind.utils.logging import get_logger +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.falcon import ( + FalconForCausalLM, + FalconForSequenceClassification, + FalconModel, + FalconPreTrainedModel, +) + +from petals.client.from_pretrained import FromPretrainedMixin +from petals.client.lm_head import LMHead +from petals.client.ptune import PTuneMixin +from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues +from petals.client.remote_sequential import RemoteSequential +from petals.models.falcon.config import DistributedFalconConfig +from petals.utils.auto_config import DefaultRevisionMixin + +logger = get_logger(__name__) + + +class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel): + """FalconModel, but all transformer layers are hosted by the swarm""" + + _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."] + + config_class = DistributedFalconConfig + + def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None): + n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization + super().__init__(config) + assert len(self.h) == 0 + config.num_hidden_layers = n_layer + + self.h = RemoteSequential(config, dht=dht) + + self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm + self.init_prompts(config) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[RemotePastKeyValues] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # The causal mask will be added on the server-side + assert ( + attention_mask is None or (attention_mask == 1).all() + ), f"Custom attention masks are not supported, {attention_mask=}" + assert head_mask is None, f"Custom head masks are not supported, {head_mask=}" + assert use_cache is None or use_cache, f"{use_cache=} is not supported" + assert not output_attentions, f"{output_attentions=} is not supported" + assert not output_hidden_states, f"{output_hidden_states=} is not supported" + assert return_dict is None or return_dict, f"{return_dict=} is not supported" + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0: + batch_size = inputs_embeds.shape[0] + prompts, intermediate_prompts = self.get_prompt(batch_size) + inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) + else: + prompts = intermediate_prompts = None + + hidden_states = 
self.word_embeddings_layernorm(inputs_embeds) + output_shape = input_shape + (hidden_states.size(-1),) + + hidden_states = self.h( + hidden_states, + prompts=intermediate_prompts, + hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None, + ) + + # Remove prefix + if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + hidden_states = hidden_states[:, self.pre_seq_len :] + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=RemotePastKeyValues(), + hidden_states=None, + attentions=None, + ) + + @property + def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin + return nn.Identity() + + +class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM): + _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedFalconConfig + + def __init__(self, config: DistributedFalconConfig): + FalconPreTrainedModel.__init__(self, config) + self.transformer = DistributedFalconModel(config) + self.lm_head = LMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + +class DistributedFalconForSequenceClassification( + DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification +): + _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedFalconConfig + + def __init__(self, config: DistributedFalconConfig): + FalconPreTrainedModel.__init__(self, config) + self.num_labels = config.num_labels + + self.transformer = DistributedFalconModel(config) + self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() diff --git a/src/petals/server/server.py b/src/petals/server/server.py index ab646a5..fd9f766 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -5,6 +5,7 @@ import math import multiprocessing as mp import os import random +import sys import threading import time from typing import Dict, List, Optional, Sequence, Union @@ -186,10 +187,7 @@ class Server: check_device_balance(self.tensor_parallel_devices) if quant_type is None: - if device.type == "cuda": - quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8 - else: - quant_type = QuantType.NONE + quant_type = QuantType.NF4 if device.type == "cuda" else QuantType.NONE self.quant_type = quant_type logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format") @@ -234,8 +232,9 @@ class Server: self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB") - assert isinstance(throughput, float) or throughput in ["auto", "eval"] - if throughput in ["auto", "eval"]: + assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"] + if throughput in ["auto", "eval", "dry_run"]: + force_eval = throughput in ["eval", "dry_run"] throughput_info = 
get_server_throughput( converted_model_name_or_path, self.block_config, @@ -245,9 +244,12 @@ class Server: quant_type=quant_type, tensor_parallel_devices=self.tensor_parallel_devices, reachable_via_relay=reachable_via_relay, - force_eval=(throughput == "eval"), + force_eval=force_eval, cache_dir=cache_dir, ) + if throughput == "dry_run": + logger.info("Finished estimating throughput, exiting") + sys.exit(0) else: throughput_info = {"throughput": throughput} self.server_info = ServerInfo( diff --git a/src/petals/utils/auto_config.py b/src/petals/utils/auto_config.py index 70f37a3..0cec83d 100644 --- a/src/petals/utils/auto_config.py +++ b/src/petals/utils/auto_config.py @@ -1,12 +1,14 @@ import os -import re from dataclasses import dataclass from typing import Optional, Type, Union +from hivemind import get_logger from transformers import AutoConfig, PretrainedConfig, PreTrainedModel from petals.utils.hf_auth import always_needs_auth +logger = get_logger(__name__) + @dataclass class _ModelClasses: @@ -49,17 +51,44 @@ class _AutoDistributedBase: return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs) -class AutoDistributedConfig(_AutoDistributedBase): +class DefaultRevisionMixin: + """ + Petals only supports Falcon loaded in the new in-library format (transformers.FalconModel). + TII models were recently converted to this format but then reverted back due to compatibility issues. + We chose to support only the new format since HF staff promised to eventually convert these models + to the new format again, see https://huggingface.co/tiiuae/falcon-40b/discussions/90#64b4d23bf44fd957492f7602 + Until it happens, we override the default `main` revision for the TII repos with the commit + pointing out to the model in the in-library format. + """ + + DEFAULT_REVISIONS = { + "tiiuae/falcon-40b": "f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232", + "tiiuae/falcon-40b-instruct": "7475ff8cfc36ed9a962b658ae3c33391566a85a5", + "tiiuae/falcon-7b": "4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76", + "tiiuae/falcon-7b-instruct": "f8dac3fff96d5debd43edf56fb4e1abcfffbef28", + } + + @classmethod + def from_pretrained( + cls, model_name_or_path: Union[str, os.PathLike, None], *args, revision: Optional[str] = None, **kwargs + ): + if revision is None and model_name_or_path in cls.DEFAULT_REVISIONS: + revision = cls.DEFAULT_REVISIONS[model_name_or_path] + logger.info(f"Loading {model_name_or_path}, revision {revision}") + return super().from_pretrained(model_name_or_path, *args, revision=revision, **kwargs) + + +class AutoDistributedConfig(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "config" -class AutoDistributedModel(_AutoDistributedBase): +class AutoDistributedModel(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "model" -class AutoDistributedModelForCausalLM(_AutoDistributedBase): +class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "model_for_causal_lm" -class AutoDistributedModelForSequenceClassification(_AutoDistributedBase): +class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "model_for_sequence_classification"
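
For reference, a minimal client-side sketch of the Falcon support added here, assuming `petals` with this patch is installed and a swarm hosting the model's blocks is reachable (the model name, dtype, prompt, and generation settings are illustrative, not part of the patch):

```python
# Client-side usage sketch (illustrative; assumes a reachable swarm hosting this model).
import torch
from transformers import AutoTokenizer

from petals import AutoDistributedModelForCausalLM

# For TII repos, DefaultRevisionMixin pins the revision to the commit in the
# in-library `transformers.FalconModel` format when no revision is passed explicitly.
model_name = "tiiuae/falcon-40b"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

inputs = tokenizer("Distributed inference with Falcon:", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))

# On the server side, the new throughput mode can be exercised with something like
#   python -m petals.cli.run_server tiiuae/falcon-40b --throughput dry_run
# which re-estimates throughput and exits immediately instead of joining the swarm.
```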
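
The grouped-KV cache reordering in `WrappedFalconBlock` (`_expand_states` / `_collapse_states`) can also be sanity-checked in isolation. Below is a standalone sketch with illustrative head counts (not taken from any real checkpoint) that verifies the expand/collapse round trip used when converting between the Bloom-style and Falcon-style cache layouts:

```python
# Standalone sketch of the expand/collapse used for Falcon's new decoder architecture.
# Head counts are illustrative; in the patch they come from DistributedFalconConfig.
import torch

num_kv_heads, num_attention_heads = 2, 8
num_key_value_groups = num_attention_heads // num_kv_heads  # mirrors DistributedFalconConfig.num_key_value_groups
batch_size, seq_len, head_dim = 3, 5, 4

def expand_states(state: torch.Tensor) -> torch.Tensor:
    # [batch * num_kv_heads, seq, dim] -> [batch * num_attention_heads, seq, dim]
    bs = state.shape[0] // num_kv_heads
    state = state.view(bs, num_kv_heads, 1, seq_len, head_dim)
    state = state.expand(-1, -1, num_key_value_groups, -1, -1)  # no copy
    return state.reshape(bs * num_attention_heads, seq_len, head_dim)  # involves a copy

def collapse_states(state: torch.Tensor) -> torch.Tensor:
    # Inverse direction: keep one replica per KV head.
    bs = state.shape[0] // num_attention_heads
    state = state.reshape(bs, num_kv_heads, num_key_value_groups, seq_len, head_dim)
    return state[:, :, 0].reshape(bs * num_kv_heads, seq_len, head_dim)

kv = torch.randn(batch_size * num_kv_heads, seq_len, head_dim)
assert torch.equal(collapse_states(expand_states(kv)), kv)  # the round trip is lossless
```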