From abd547735f97196351336ed13b9f1b67468bbf1f Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 2 Sep 2023 22:57:18 +0400 Subject: [PATCH 1/5] Force use_cache=True (#496) --- setup.cfg | 2 +- src/petals/models/bloom/model.py | 3 +-- src/petals/models/llama/model.py | 3 +-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/setup.cfg b/setup.cfg index cf14434..c8dbc9a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,7 +40,7 @@ install_requires = transformers>=4.32.0,<5.0.0 # if you change this, please also change version assert in petals/__init__.py speedtest-cli==2.1.3 pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet - hivemind @ git+https://github.com/learning-at-home/hivemind + hivemind==1.1.10.post2 tensor_parallel==1.0.23 humanfriendly async-timeout>=4.0.2 diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py index cf83822..53e4a98 100644 --- a/src/petals/models/bloom/model.py +++ b/src/petals/models/bloom/model.py @@ -43,7 +43,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, + use_cache: Optional[bool] = None, # Not used here but needed for HF Transformers compatibility output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -63,7 +63,6 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): attention_mask is None or (attention_mask == 1).all() ), f"Custom attention masks are not supported, {attention_mask=}" assert head_mask is None, f"Custom head masks are not supported, {head_mask=}" - assert use_cache is None or use_cache, f"{use_cache=} is not supported" assert not output_attentions, f"{output_attentions=} is not supported" assert not output_hidden_states, f"{output_hidden_states=} is not supported" assert return_dict is None or return_dict, f"{return_dict=} is not supported" diff --git a/src/petals/models/llama/model.py b/src/petals/models/llama/model.py index a9dfcc1..cf7d150 100644 --- a/src/petals/models/llama/model.py +++ b/src/petals/models/llama/model.py @@ -43,7 +43,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[RemotePastKeyValues] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, + use_cache: Optional[bool] = None, # Not used here but needed for HF Transformers compatibility output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -65,7 +65,6 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): assert ( position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all() ), f"Non-consecutive position_ids are not supported, {position_ids=}" - assert use_cache is None or use_cache, f"{use_cache=} is not supported" assert not output_attentions, f"{output_attentions=} is not supported" assert not output_hidden_states, f"{output_hidden_states=} is not supported" assert return_dict is None or return_dict, f"{return_dict=} is not supported" From b4d822afb275deccd32b1b26bda46b80a3719467 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sun, 3 Sep 2023 01:16:00 +0400 Subject: [PATCH 2/5] Force use_cache=True in config only (#497) This reverts a part of #496 and instead overrides `use_cache` in `LlamaConfig`s only (so the correct value is visible by HF `.generate()` as well). --- src/petals/models/bloom/model.py | 3 ++- src/petals/models/llama/config.py | 1 + src/petals/models/llama/model.py | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py index 53e4a98..cf83822 100644 --- a/src/petals/models/bloom/model.py +++ b/src/petals/models/bloom/model.py @@ -43,7 +43,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, # Not used here but needed for HF Transformers compatibility + use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -63,6 +63,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): attention_mask is None or (attention_mask == 1).all() ), f"Custom attention masks are not supported, {attention_mask=}" assert head_mask is None, f"Custom head masks are not supported, {head_mask=}" + assert use_cache is None or use_cache, f"{use_cache=} is not supported" assert not output_attentions, f"{output_attentions=} is not supported" assert not output_hidden_states, f"{output_hidden_states=} is not supported" assert return_dict is None or return_dict, f"{return_dict=} is not supported" diff --git a/src/petals/models/llama/config.py b/src/petals/models/llama/config.py index 43a5843..ae71a4c 100644 --- a/src/petals/models/llama/config.py +++ b/src/petals/models/llama/config.py @@ -43,4 +43,5 @@ class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfi result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) config = result[0] if isinstance(result, tuple) else result config.pretraining_tp = 1 # This may give less accurate results but it doesn't matter if we use quantization + config.use_cache = True # use_cache=False leads to identical results but is slower and not supported by Petals return result diff --git a/src/petals/models/llama/model.py b/src/petals/models/llama/model.py index cf7d150..a9dfcc1 100644 --- a/src/petals/models/llama/model.py +++ b/src/petals/models/llama/model.py @@ -43,7 +43,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[RemotePastKeyValues] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, # Not used here but needed for HF Transformers compatibility + use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -65,6 +65,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): assert ( position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all() ), f"Non-consecutive position_ids are not supported, {position_ids=}" + assert use_cache is None or use_cache, f"{use_cache=} is not supported" assert not output_attentions, f"{output_attentions=} is not supported" assert not output_hidden_states, f"{output_hidden_states=} is not supported" assert return_dict is None or return_dict, f"{return_dict=} is not supported" From dd4a3230bced491c22fc44587ef72f83a676fc2d Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 4 Sep 2023 01:45:37 +0400 Subject: [PATCH 3/5] Add Falcon support (#499) This PR adds: - Support for models based on `transformers.FalconModel` (the in-library format for Falcon). Tested on Falcon-40B. - CI tests for Falcon-RW-1B. - `--throughput dry_run` option to evaluate throughput and exit right away (implemented by @mryab). Limitations: - Backward pass support is broken for now, will be fixed in #500. Co-authored-by: Max Ryabinin --- .github/workflows/run-tests.yaml | 7 +- src/petals/cli/run_server.py | 5 +- src/petals/models/__init__.py | 1 + src/petals/models/falcon/__init__.py | 15 +++ src/petals/models/falcon/block.py | 94 +++++++++++++++++ src/petals/models/falcon/config.py | 45 ++++++++ src/petals/models/falcon/model.py | 149 +++++++++++++++++++++++++++ src/petals/server/server.py | 16 +-- src/petals/utils/auto_config.py | 39 ++++++- 9 files changed, 356 insertions(+), 15 deletions(-) create mode 100644 src/petals/models/falcon/__init__.py create mode 100644 src/petals/models/falcon/block.py create mode 100644 src/petals/models/falcon/config.py create mode 100644 src/petals/models/falcon/model.py diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 663b400..05cebdd 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -14,11 +14,13 @@ jobs: - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' } - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' } - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' } + - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.8' } + - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.11' } - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' } - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' } fail-fast: false runs-on: ${{ matrix.os }}-latest - timeout-minutes: 15 + timeout-minutes: 20 steps: - name: Increase swap space if: ${{ matrix.os == 'ubuntu' }} @@ -93,6 +95,9 @@ jobs: # [Step 2] Run PyTest + # Share disk cache between Petals servers, clients, and HF Transformers + export TRANSFORMERS_CACHE=~/.cache/petals + # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93 export no_proxy=* export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 3728c16..94f5c2e 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -106,12 +106,13 @@ def main(): "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.") parser.add_argument('--throughput', - type=lambda value: value if value in ['auto', 'eval'] else float(value), + type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value), default='auto', help='Expected server throughput (a float measured in RPS). ' 'If set to "auto" (default), the script evaluates network and compute throughput ' 'on the first run and uses these estimates for future runs. ' - 'If set to "eval", the script re-evaluates the throughput and overrides the cache.') + 'If set to "eval", the script re-evaluates the throughput and overrides the cache. ' + 'If set to "dry_run", the script re-evaluates the throughput and exits.') parser.add_argument('--update_period', type=float, required=False, default=120, help='Server will report blocks to DHT once in this many seconds') parser.add_argument('--expiration', type=float, required=False, default=None, diff --git a/src/petals/models/__init__.py b/src/petals/models/__init__.py index acb4d38..f52a429 100644 --- a/src/petals/models/__init__.py +++ b/src/petals/models/__init__.py @@ -1,2 +1,3 @@ from petals.models.bloom import * +from petals.models.falcon import * from petals.models.llama import * diff --git a/src/petals/models/falcon/__init__.py b/src/petals/models/falcon/__init__.py new file mode 100644 index 0000000..019ca5d --- /dev/null +++ b/src/petals/models/falcon/__init__.py @@ -0,0 +1,15 @@ +from petals.models.falcon.block import WrappedFalconBlock +from petals.models.falcon.config import DistributedFalconConfig +from petals.models.falcon.model import ( + DistributedFalconForCausalLM, + DistributedFalconForSequenceClassification, + DistributedFalconModel, +) +from petals.utils.auto_config import register_model_classes + +register_model_classes( + config=DistributedFalconConfig, + model=DistributedFalconModel, + model_for_causal_lm=DistributedFalconForCausalLM, + model_for_sequence_classification=DistributedFalconForSequenceClassification, +) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py new file mode 100644 index 0000000..e677e06 --- /dev/null +++ b/src/petals/models/falcon/block.py @@ -0,0 +1,94 @@ +""" +Falcon intermediate layer +Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py +See commit history for authorship. +""" +from typing import Optional, Tuple + +import torch +from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class WrappedFalconBlock(FalconDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + alibi: Optional[torch.Tensor] = None, + layer_past: Optional[KVCache] = None, + use_cache: bool = False, + **kwargs + ): + batch_size, seq_length = hidden_states.shape[:2] + + if layer_past is not None: + layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past) + past_length = 0 if layer_past is None else layer_past[0].shape[1] + seq_length_with_past = seq_length + past_length + + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + if alibi is None and self.config.alibi: + alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) + attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length) + + outputs = super().forward( + hidden_states, + *args, + attention_mask=attention_mask, + alibi=alibi, + layer_past=layer_past, + use_cache=use_cache, + **kwargs + ) + + if use_cache: + present_key_value = outputs[-1] + present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value) + outputs = outputs[:-1] + (present_key_value,) + + return outputs + + def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache: + key_states, value_states = key_value + + key_states = key_states.permute(0, 2, 1) + assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] + + if self.config.new_decoder_architecture: + key_states = self._expand_states(key_states) + value_states = self._expand_states(value_states) + + return (key_states, value_states) + + def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache: + key_states, value_states = key_value + + if self.config.new_decoder_architecture: + key_states = self._collapse_states(key_states) + value_states = self._collapse_states(value_states) + + assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] + key_states = key_states.permute(0, 2, 1) + + return (key_states, value_states) + + def _expand_states(self, state: torch.Tensor) -> torch.Tensor: + batch_size_x_num_kv_heads, seq_len, head_dim = state.shape + batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads + + state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim) + state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy + state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy + return state + + def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: + batch_size_x_num_attn_heads, seq_len, head_dim = state.shape + batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads + + state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim) + state = state[:, :, 0] + state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim) + return state diff --git a/src/petals/models/falcon/config.py b/src/petals/models/falcon/config.py new file mode 100644 index 0000000..a1ae5e9 --- /dev/null +++ b/src/petals/models/falcon/config.py @@ -0,0 +1,45 @@ +import os +from typing import Optional, Union + +from hivemind import get_logger +from transformers.models.falcon import FalconConfig +from transformers.models.falcon.modeling_falcon import FalconAttention + +from petals.client.config import ClientConfig +from petals.client.lm_head import LMHeadConfig +from petals.client.ptune import PTuneConfig +from petals.models.falcon.block import WrappedFalconBlock +from petals.utils.auto_config import DefaultRevisionMixin + +logger = get_logger(__name__) + + +class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig): + block_class = WrappedFalconBlock + attn_class = FalconAttention + block_prefix = "transformer.h" + + @property + def num_key_value_groups(self) -> int: + if self.new_decoder_architecture: + return self.num_attention_heads // self.num_kv_heads + if self.multi_query: + return self.num_attention_heads + return 1 + + @classmethod + def from_pretrained( + cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs + ): + loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) + if loading_from_repo and dht_prefix is None: + dht_prefix = str(model_name_or_path) + dht_prefix = dht_prefix.split("/")[-1] # Use only repo name to merge blocks hosted by different accounts + dht_prefix = dht_prefix.replace(".", "-") + logger.info(f"Using DHT prefix: {dht_prefix}") + + result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) + config = result[0] if isinstance(result, tuple) else result + if config.pad_token_id is None: + config.pad_token_id = 0 + return result diff --git a/src/petals/models/falcon/model.py b/src/petals/models/falcon/model.py new file mode 100644 index 0000000..3a2a6b0 --- /dev/null +++ b/src/petals/models/falcon/model.py @@ -0,0 +1,149 @@ +from typing import Optional + +import hivemind +import torch +import torch.nn as nn +from hivemind.utils.logging import get_logger +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.falcon import ( + FalconForCausalLM, + FalconForSequenceClassification, + FalconModel, + FalconPreTrainedModel, +) + +from petals.client.from_pretrained import FromPretrainedMixin +from petals.client.lm_head import LMHead +from petals.client.ptune import PTuneMixin +from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues +from petals.client.remote_sequential import RemoteSequential +from petals.models.falcon.config import DistributedFalconConfig +from petals.utils.auto_config import DefaultRevisionMixin + +logger = get_logger(__name__) + + +class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel): + """FalconModel, but all transformer layers are hosted by the swarm""" + + _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."] + + config_class = DistributedFalconConfig + + def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None): + n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization + super().__init__(config) + assert len(self.h) == 0 + config.num_hidden_layers = n_layer + + self.h = RemoteSequential(config, dht=dht) + + self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm + self.init_prompts(config) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[RemotePastKeyValues] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # The causal mask will be added on the server-side + assert ( + attention_mask is None or (attention_mask == 1).all() + ), f"Custom attention masks are not supported, {attention_mask=}" + assert head_mask is None, f"Custom head masks are not supported, {head_mask=}" + assert use_cache is None or use_cache, f"{use_cache=} is not supported" + assert not output_attentions, f"{output_attentions=} is not supported" + assert not output_hidden_states, f"{output_hidden_states=} is not supported" + assert return_dict is None or return_dict, f"{return_dict=} is not supported" + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0: + batch_size = inputs_embeds.shape[0] + prompts, intermediate_prompts = self.get_prompt(batch_size) + inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) + else: + prompts = intermediate_prompts = None + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + output_shape = input_shape + (hidden_states.size(-1),) + + hidden_states = self.h( + hidden_states, + prompts=intermediate_prompts, + hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None, + ) + + # Remove prefix + if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + hidden_states = hidden_states[:, self.pre_seq_len :] + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=RemotePastKeyValues(), + hidden_states=None, + attentions=None, + ) + + @property + def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin + return nn.Identity() + + +class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM): + _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedFalconConfig + + def __init__(self, config: DistributedFalconConfig): + FalconPreTrainedModel.__init__(self, config) + self.transformer = DistributedFalconModel(config) + self.lm_head = LMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + +class DistributedFalconForSequenceClassification( + DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification +): + _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedFalconConfig + + def __init__(self, config: DistributedFalconConfig): + FalconPreTrainedModel.__init__(self, config) + self.num_labels = config.num_labels + + self.transformer = DistributedFalconModel(config) + self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() diff --git a/src/petals/server/server.py b/src/petals/server/server.py index ab646a5..fd9f766 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -5,6 +5,7 @@ import math import multiprocessing as mp import os import random +import sys import threading import time from typing import Dict, List, Optional, Sequence, Union @@ -186,10 +187,7 @@ class Server: check_device_balance(self.tensor_parallel_devices) if quant_type is None: - if device.type == "cuda": - quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8 - else: - quant_type = QuantType.NONE + quant_type = QuantType.NF4 if device.type == "cuda" else QuantType.NONE self.quant_type = quant_type logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format") @@ -234,8 +232,9 @@ class Server: self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB") - assert isinstance(throughput, float) or throughput in ["auto", "eval"] - if throughput in ["auto", "eval"]: + assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"] + if throughput in ["auto", "eval", "dry_run"]: + force_eval = throughput in ["eval", "dry_run"] throughput_info = get_server_throughput( converted_model_name_or_path, self.block_config, @@ -245,9 +244,12 @@ class Server: quant_type=quant_type, tensor_parallel_devices=self.tensor_parallel_devices, reachable_via_relay=reachable_via_relay, - force_eval=(throughput == "eval"), + force_eval=force_eval, cache_dir=cache_dir, ) + if throughput == "dry_run": + logger.info("Finished estimating throughput, exiting") + sys.exit(0) else: throughput_info = {"throughput": throughput} self.server_info = ServerInfo( diff --git a/src/petals/utils/auto_config.py b/src/petals/utils/auto_config.py index 70f37a3..0cec83d 100644 --- a/src/petals/utils/auto_config.py +++ b/src/petals/utils/auto_config.py @@ -1,12 +1,14 @@ import os -import re from dataclasses import dataclass from typing import Optional, Type, Union +from hivemind import get_logger from transformers import AutoConfig, PretrainedConfig, PreTrainedModel from petals.utils.hf_auth import always_needs_auth +logger = get_logger(__name__) + @dataclass class _ModelClasses: @@ -49,17 +51,44 @@ class _AutoDistributedBase: return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs) -class AutoDistributedConfig(_AutoDistributedBase): +class DefaultRevisionMixin: + """ + Petals only supports Falcon loaded in the new in-library format (transformers.FalconModel). + TII models were recently converted to this format but then reverted back due to compatibility issues. + We chose to support only the new format since HF staff promised to eventually convert these models + to the new format again, see https://huggingface.co/tiiuae/falcon-40b/discussions/90#64b4d23bf44fd957492f7602 + Until it happens, we override the default `main` revision for the TII repos with the commit + pointing out to the model in the in-library format. + """ + + DEFAULT_REVISIONS = { + "tiiuae/falcon-40b": "f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232", + "tiiuae/falcon-40b-instruct": "7475ff8cfc36ed9a962b658ae3c33391566a85a5", + "tiiuae/falcon-7b": "4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76", + "tiiuae/falcon-7b-instruct": "f8dac3fff96d5debd43edf56fb4e1abcfffbef28", + } + + @classmethod + def from_pretrained( + cls, model_name_or_path: Union[str, os.PathLike, None], *args, revision: Optional[str] = None, **kwargs + ): + if revision is None and model_name_or_path in cls.DEFAULT_REVISIONS: + revision = cls.DEFAULT_REVISIONS[model_name_or_path] + logger.info(f"Loading {model_name_or_path}, revision {revision}") + return super().from_pretrained(model_name_or_path, *args, revision=revision, **kwargs) + + +class AutoDistributedConfig(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "config" -class AutoDistributedModel(_AutoDistributedBase): +class AutoDistributedModel(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "model" -class AutoDistributedModelForCausalLM(_AutoDistributedBase): +class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "model_for_causal_lm" -class AutoDistributedModelForSequenceClassification(_AutoDistributedBase): +class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "model_for_sequence_classification" From d40eb6c7015c0de2914cb013601bdd47544d16ef Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 4 Sep 2023 12:25:29 +0400 Subject: [PATCH 4/5] Fix prompt tuning after #464 (#501) Unfortunately, running inference in models with `"ptune" in config.tuning_mode` was broken after #464. --- src/petals/client/remote_generation.py | 5 +++-- src/petals/models/bloom/model.py | 5 +++-- src/petals/models/falcon/model.py | 5 +++-- src/petals/models/llama/model.py | 5 +++-- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py index e392b4f..97a115a 100644 --- a/src/petals/client/remote_generation.py +++ b/src/petals/client/remote_generation.py @@ -87,10 +87,11 @@ class RemoteGenerationMixin(_SkipTokensMixin): max_new_tokens is None ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches" + session_max_length = self.transformer.config.pre_seq_len if max_length is not None: - session_max_length = max_length + session_max_length += max_length else: - session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens + session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens context_manager = self.inference_session(max_length=session_max_length) with context_manager as session: diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py index cf83822..784418f 100644 --- a/src/petals/models/bloom/model.py +++ b/src/petals/models/bloom/model.py @@ -71,7 +71,8 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) - if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0: + use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0 + if use_prompts: batch_size = inputs_embeds.shape[0] prompts, intermediate_prompts = self.get_prompt(batch_size) inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) @@ -88,7 +89,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): ) # Remove prefix - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + if use_prompts: hidden_states = hidden_states[:, self.pre_seq_len :] # Add last hidden state diff --git a/src/petals/models/falcon/model.py b/src/petals/models/falcon/model.py index 3a2a6b0..32c0b6f 100644 --- a/src/petals/models/falcon/model.py +++ b/src/petals/models/falcon/model.py @@ -77,7 +77,8 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) - if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0: + use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0 + if use_prompts: batch_size = inputs_embeds.shape[0] prompts, intermediate_prompts = self.get_prompt(batch_size) inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) @@ -94,7 +95,7 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix ) # Remove prefix - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + if use_prompts: hidden_states = hidden_states[:, self.pre_seq_len :] # Add last hidden state diff --git a/src/petals/models/llama/model.py b/src/petals/models/llama/model.py index a9dfcc1..3360f40 100644 --- a/src/petals/models/llama/model.py +++ b/src/petals/models/llama/model.py @@ -73,7 +73,8 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0: + use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0 + if use_prompts: batch_size = inputs_embeds.shape[0] prompts, intermediate_prompts = self.get_prompt(batch_size) inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) @@ -90,7 +91,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): ) # Remove prefix - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + if use_prompts: hidden_states = hidden_states[:, self.pre_seq_len :] # Add last hidden state From 1ebd88ae7b8238fe1778409f55fa88aa542a2541 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Mon, 4 Sep 2023 14:38:32 +0200 Subject: [PATCH 5/5] Optimize the Falcon block for inference (#500) This PR attempts to optimize the inference of Falcon models in the single-token setup by reducing the majority of Python overhead and making several assumptions about the setup. Specifically, * Layer normalization, QKV projection (with splitting) and rotary embeddings are executed through CUDA graphs, which reduces most overhead related to small kernel launche * If no sin/cos tensors are cached by the rotary embedding layer, we cache them for 8192 tokens (INFERENCE_MAX_LENGTH) during the first forward pass. In general, it should be beneficial to always run a max-length sequence before starting a block, but this is a question for another PR The PR also adds a small test to ensure that the results (without quantization) of the block before and after quantization indeed match. Lastly, the pull request makes the backward pass work (as discussed in https://github.com/bigscience-workshop/petals/pull/499) by making cached sin/cos for RotaryEmbedding into buffers and disabling the inference mode during their creation. --- src/petals/models/falcon/block.py | 394 +++++++++++++++++++++++++++++- tests/test_optimized_layers.py | 128 ++++++++++ 2 files changed, 518 insertions(+), 4 deletions(-) create mode 100644 tests/test_optimized_layers.py diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py index e677e06..a510aba 100644 --- a/src/petals/models/falcon/block.py +++ b/src/petals/models/falcon/block.py @@ -3,15 +3,399 @@ Falcon intermediate layer Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py See commit history for authorship. """ +import math +from functools import partial from typing import Optional, Tuple import torch -from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.falcon.modeling_falcon import ( + FalconAttention, + FalconConfig, + FalconDecoderLayer, + FalconLinear, + FalconMLP, + FalconModel, + LayerNorm, + build_alibi_tensor, + dropout_add, + rotate_half, +) KVCache = Tuple[torch.Tensor, torch.Tensor] +INFERENCE_MAX_LENGTH = 8192 -class WrappedFalconBlock(FalconDecoderLayer): +def apply_rotary(query, key, cos, sin): + return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) + + +class OptimizedFalconRotaryEmbedding(nn.Module): + def __init__(self, head_dim: int, base=10000): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.head_dim = head_dim + self.seq_len_cached = -1 + + self.cuda_graph = None + self.input_surface = None + self.static_outputs = None + + def _optimized_apply_rotary(self, query, key, cos, sin): + if self.cuda_graph is None: + self.cuda_graph = torch.cuda.CUDAGraph() + self.input_surface = (query, key, cos, sin) + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + apply_rotary(*self.input_surface) + torch.cuda.current_stream().wait_stream(s) + + with torch.cuda.graph(self.cuda_graph): + self.static_outputs = apply_rotary(*self.input_surface) + + inputs = (query, key, cos, sin) + for static_input, data in zip(self.input_surface, inputs): + static_input.copy_(data) + self.cuda_graph.replay() + return tuple(o.detach() for o in self.static_outputs) + + def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor: + total_length = seq_len + past_key_values_length + if self.seq_len_cached == -1: + # warm up the cache + total_length = max(INFERENCE_MAX_LENGTH, total_length) + + if total_length > self.seq_len_cached: + with torch.inference_mode(False): + self.seq_len_cached = total_length + t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(device) + + if dtype in [torch.float16, torch.bfloat16]: + emb = emb.float() + + self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False) + + return ( + self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype), + self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype), + ) + + def forward(self, query, key, past_key_values_length=0): + batch, seq_len, head_dim = query.shape + cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype) + if seq_len == 1 and torch.is_inference_mode_enabled() and query.device.type == "cuda": + return self._optimized_apply_rotary(query, key, cos, sin) + else: + return apply_rotary(query, key, cos, sin) + + +def split_heads( + fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + batch, seq_len, _ = fused_qkv.shape + qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim) + query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3) + key = torch.broadcast_to(key, query.shape) + value = torch.broadcast_to(value, query.shape) + + query, key, value = [x.flatten(2, 3) for x in (query, key, value)] + return query, key, value + + +class OptimizedFalconAttention(FalconAttention): + def __init__(self, config: FalconConfig): + nn.Module.__init__(self) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.split_size = self.hidden_size + self.hidden_dropout = config.hidden_dropout + + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError( + f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + f" {self.num_heads})." + ) + + self.maybe_rotary = OptimizedFalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k) + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.beta = self.inv_norm_factor + if config.new_decoder_architecture: + qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim + elif config.multi_query: + qkv_out_dim = self.hidden_size + 2 * self.head_dim + else: + qkv_out_dim = 3 * self.hidden_size + self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias) + self.new_decoder_architecture = config.new_decoder_architecture + self.multi_query = config.multi_query + self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias) + self.attention_dropout = nn.Dropout(config.attention_dropout) + self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1 + + if self.new_decoder_architecture: + self._split_heads = partial( + split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim + ) + self.split_graph = None + self.input_surface = None + self.static_outputs = None + + def _optimized_split_heads(self, fused_qkv): + if self.split_graph is None: + self.split_graph = torch.cuda.CUDAGraph() + self.input_surface = fused_qkv + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + self._split_heads(fused_qkv) + torch.cuda.current_stream().wait_stream(s) + + with torch.cuda.graph(self.split_graph): + self.static_outputs = self._split_heads(self.input_surface) + + self.input_surface.copy_(fused_qkv) + self.split_graph.replay() + return tuple(o.detach() for o in self.static_outputs) + + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + assert not output_attentions + + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + + if ( + self.new_decoder_architecture + and hidden_states.size(1) == 1 + and torch.is_inference_mode_enabled() + and hidden_states.device.type == "cuda" + ): + query_layer, key_layer, value_layer = self._optimized_split_heads(fused_qkv) + else: + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + num_kv_heads = self.num_heads + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape( + batch_size * num_kv_heads, + query_length, + self.head_dim, + ) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim) + + past_kv_length = 0 if layer_past is None else layer_past[0].shape[1] + query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length) + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, kv_length, head_dim] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, kv_length, _ = key_layer.shape + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim) + key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) + value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) + + attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype) + + if alibi is None: + attn_output = F.scaled_dot_product_attention( + query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False + ) + + attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) + attn_output = attn_output.permute(0, 2, 1, 3) + attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + + output_tensor = self.dense(attn_output) + + return output_tensor, present + else: + matmul_result = query_layer_ @ key_layer_.transpose(-1, -2) + + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16 or input_dtype == torch.bfloat16: + attention_scores = attention_scores.to(torch.float32) + # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by + # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically + # equivalent and more performant, but there might be a numerical difference. If you're reading this + # and you'd like to experiment and maybe file a PR, feel free! + attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) + attention_logits *= self.inv_norm_factor + attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype) + # [batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # change view [batch_size, num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) + + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1) + + # change view [batch_size, q_length, num_heads * head_dim] + context_layer = self._merge_heads(context_layer) + + output_tensor = self.dense(context_layer) + + if output_attentions: + return output_tensor, present, attention_probs + else: + return output_tensor, present + + +class OptimizedFalconDecoderLayer(FalconDecoderLayer): + def __init__(self, config: FalconConfig): + nn.Module.__init__(self) + hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.mlp = FalconMLP(config) + self.hidden_dropout = config.hidden_dropout + self.config = config + + self.self_attention = OptimizedFalconAttention(config) + + if self.config.alibi or not config.new_decoder_architecture: + if config.new_decoder_architecture: + # The layer norm before self-attention + self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + # The layer norm before the MLP + self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + else: + self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + if not config.parallel_attn: + self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + else: + self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.ln_graph = None + self.static_input = None + self.static_outputs = None + + def _optimized_apply_ln(self, hidden_states): + if self.ln_graph is None: + self.ln_graph = torch.cuda.CUDAGraph() + self.static_input = hidden_states + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + self.ln_attn(hidden_states) + self.ln_mlp(hidden_states) + torch.cuda.current_stream().wait_stream(s) + + with torch.cuda.graph(self.ln_graph): + ln_attn_output = self.ln_attn(hidden_states) + ln_mlp_output = self.ln_mlp(hidden_states) + self.static_outputs = (ln_attn_output, ln_mlp_output) + + self.static_input.copy_(hidden_states) + self.ln_graph.replay() + return tuple(o.detach() for o in self.static_outputs) + + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + residual = hidden_states + + if self.config.new_decoder_architecture: + if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda": + attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states) + else: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + else: + attention_layernorm_out = self.input_layernorm(hidden_states) + + attn_outputs = self.self_attention( + attention_layernorm_out, + layer_past=layer_past, + attention_mask=attention_mask, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + attention_output = attn_outputs[0] + + if not self.config.new_decoder_architecture: + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual = dropout_add( + attention_output, residual, self.config.attention_dropout, training=self.training + ) + mlp_layernorm_out = self.post_attention_layernorm(residual) + + outputs = attn_outputs[1:] + + mlp_output = self.mlp(mlp_layernorm_out) + + if self.config.new_decoder_architecture or self.config.parallel_attn: + mlp_output += attention_output + + output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + +class WrappedFalconBlock(OptimizedFalconDecoderLayer): def forward( self, hidden_states: torch.Tensor, @@ -20,8 +404,10 @@ class WrappedFalconBlock(FalconDecoderLayer): alibi: Optional[torch.Tensor] = None, layer_past: Optional[KVCache] = None, use_cache: bool = False, - **kwargs + **kwargs, ): + assert attention_mask is None + batch_size, seq_length = hidden_states.shape[:2] if layer_past is not None: @@ -41,7 +427,7 @@ class WrappedFalconBlock(FalconDecoderLayer): alibi=alibi, layer_past=layer_past, use_cache=use_cache, - **kwargs + **kwargs, ) if use_cache: diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py new file mode 100644 index 0000000..5baa1a2 --- /dev/null +++ b/tests/test_optimized_layers.py @@ -0,0 +1,128 @@ +from typing import Optional, Tuple + +import pytest +import torch +from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor + +from petals.utils.auto_config import AutoDistributedConfig +from petals.utils.convert_block import QuantType, convert_block +from test_utils import MODEL_NAME + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class UnoptimizedWrappedFalconBlock(FalconDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + alibi: Optional[torch.Tensor] = None, + layer_past: Optional[KVCache] = None, + use_cache: bool = False, + **kwargs, + ): + batch_size, seq_length = hidden_states.shape[:2] + + if layer_past is not None: + layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past) + past_length = 0 if layer_past is None else layer_past[0].shape[1] + seq_length_with_past = seq_length + past_length + + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + if alibi is None and self.config.alibi: + alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) + attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length) + + outputs = super().forward( + hidden_states, + *args, + attention_mask=attention_mask, + alibi=alibi, + layer_past=layer_past, + use_cache=use_cache, + **kwargs, + ) + + if use_cache: + present_key_value = outputs[-1] + present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value) + outputs = outputs[:-1] + (present_key_value,) + + return outputs + + def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache: + key_states, value_states = key_value + + key_states = key_states.permute(0, 2, 1) + assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] + + if self.config.new_decoder_architecture: + key_states = self._expand_states(key_states) + value_states = self._expand_states(value_states) + + return (key_states, value_states) + + def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache: + key_states, value_states = key_value + + if self.config.new_decoder_architecture: + key_states = self._collapse_states(key_states) + value_states = self._collapse_states(value_states) + + assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] + key_states = key_states.permute(0, 2, 1) + + return (key_states, value_states) + + def _expand_states(self, state: torch.Tensor) -> torch.Tensor: + batch_size_x_num_kv_heads, seq_len, head_dim = state.shape + batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads + + state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim) + state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy + state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy + return state + + def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: + batch_size_x_num_attn_heads, seq_len, head_dim = state.shape + batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads + + state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim) + state = state[:, :, 0] + state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim) + return state + + +@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models") +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +@pytest.mark.forked +def test_falcon(device): + if device == "cuda:0" and not torch.cuda.is_available(): + pytest.skip("CUDA tests can be run only in CUDA-enabled setups") + + config = AutoDistributedConfig.from_pretrained(MODEL_NAME) + + tensor_parallel_devices = (device,) + dtype = torch.bfloat16 + quant_type = QuantType.NONE + + block = config.block_class(config).to(dtype) + block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) + + unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype) + unopt_block = convert_block( + unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True + ) + + unopt_block.load_state_dict(block.state_dict()) + cache = unopt_cache = None + + with torch.inference_mode(): + for length in [10, 1, 1, 1]: + dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype) + block_output, cache = block(dummy_input, layer_past=cache, use_cache=True) + unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True) + assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length + assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length + assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length