diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 663b400..05cebdd 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -14,11 +14,13 @@ jobs: - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' } - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' } - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' } + - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.8' } + - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.11' } - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' } - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' } fail-fast: false runs-on: ${{ matrix.os }}-latest - timeout-minutes: 15 + timeout-minutes: 20 steps: - name: Increase swap space if: ${{ matrix.os == 'ubuntu' }} @@ -93,6 +95,9 @@ jobs: # [Step 2] Run PyTest + # Share disk cache between Petals servers, clients, and HF Transformers + export TRANSFORMERS_CACHE=~/.cache/petals + # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93 export no_proxy=* export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES diff --git a/setup.cfg b/setup.cfg index cf14434..c8dbc9a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,7 +40,7 @@ install_requires = transformers>=4.32.0,<5.0.0 # if you change this, please also change version assert in petals/__init__.py speedtest-cli==2.1.3 pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet - hivemind @ git+https://github.com/learning-at-home/hivemind + hivemind==1.1.10.post2 tensor_parallel==1.0.23 humanfriendly async-timeout>=4.0.2 diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 3728c16..94f5c2e 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -106,12 +106,13 @@ def main(): "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.") parser.add_argument('--throughput', - type=lambda value: value if value in ['auto', 'eval'] else float(value), + type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value), default='auto', help='Expected server throughput (a float measured in RPS). ' 'If set to "auto" (default), the script evaluates network and compute throughput ' 'on the first run and uses these estimates for future runs. ' - 'If set to "eval", the script re-evaluates the throughput and overrides the cache.') + 'If set to "eval", the script re-evaluates the throughput and overrides the cache. ' + 'If set to "dry_run", the script re-evaluates the throughput and exits.') parser.add_argument('--update_period', type=float, required=False, default=120, help='Server will report blocks to DHT once in this many seconds') parser.add_argument('--expiration', type=float, required=False, default=None, diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py index e392b4f..97a115a 100644 --- a/src/petals/client/remote_generation.py +++ b/src/petals/client/remote_generation.py @@ -87,10 +87,11 @@ class RemoteGenerationMixin(_SkipTokensMixin): max_new_tokens is None ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches" + session_max_length = self.transformer.config.pre_seq_len if max_length is not None: - session_max_length = max_length + session_max_length += max_length else: - session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens + session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens context_manager = self.inference_session(max_length=session_max_length) with context_manager as session: diff --git a/src/petals/models/__init__.py b/src/petals/models/__init__.py index acb4d38..f52a429 100644 --- a/src/petals/models/__init__.py +++ b/src/petals/models/__init__.py @@ -1,2 +1,3 @@ from petals.models.bloom import * +from petals.models.falcon import * from petals.models.llama import * diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py index cf83822..784418f 100644 --- a/src/petals/models/bloom/model.py +++ b/src/petals/models/bloom/model.py @@ -71,7 +71,8 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) - if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0: + use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0 + if use_prompts: batch_size = inputs_embeds.shape[0] prompts, intermediate_prompts = self.get_prompt(batch_size) inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) @@ -88,7 +89,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): ) # Remove prefix - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + if use_prompts: hidden_states = hidden_states[:, self.pre_seq_len :] # Add last hidden state diff --git a/src/petals/models/falcon/__init__.py b/src/petals/models/falcon/__init__.py new file mode 100644 index 0000000..019ca5d --- /dev/null +++ b/src/petals/models/falcon/__init__.py @@ -0,0 +1,15 @@ +from petals.models.falcon.block import WrappedFalconBlock +from petals.models.falcon.config import DistributedFalconConfig +from petals.models.falcon.model import ( + DistributedFalconForCausalLM, + DistributedFalconForSequenceClassification, + DistributedFalconModel, +) +from petals.utils.auto_config import register_model_classes + +register_model_classes( + config=DistributedFalconConfig, + model=DistributedFalconModel, + model_for_causal_lm=DistributedFalconForCausalLM, + model_for_sequence_classification=DistributedFalconForSequenceClassification, +) diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py new file mode 100644 index 0000000..a510aba --- /dev/null +++ b/src/petals/models/falcon/block.py @@ -0,0 +1,480 @@ +""" +Falcon intermediate layer +Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py +See commit history for authorship. +""" +import math +from functools import partial +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.falcon.modeling_falcon import ( + FalconAttention, + FalconConfig, + FalconDecoderLayer, + FalconLinear, + FalconMLP, + FalconModel, + LayerNorm, + build_alibi_tensor, + dropout_add, + rotate_half, +) + +KVCache = Tuple[torch.Tensor, torch.Tensor] +INFERENCE_MAX_LENGTH = 8192 + + +def apply_rotary(query, key, cos, sin): + return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) + + +class OptimizedFalconRotaryEmbedding(nn.Module): + def __init__(self, head_dim: int, base=10000): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.head_dim = head_dim + self.seq_len_cached = -1 + + self.cuda_graph = None + self.input_surface = None + self.static_outputs = None + + def _optimized_apply_rotary(self, query, key, cos, sin): + if self.cuda_graph is None: + self.cuda_graph = torch.cuda.CUDAGraph() + self.input_surface = (query, key, cos, sin) + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + apply_rotary(*self.input_surface) + torch.cuda.current_stream().wait_stream(s) + + with torch.cuda.graph(self.cuda_graph): + self.static_outputs = apply_rotary(*self.input_surface) + + inputs = (query, key, cos, sin) + for static_input, data in zip(self.input_surface, inputs): + static_input.copy_(data) + self.cuda_graph.replay() + return tuple(o.detach() for o in self.static_outputs) + + def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor: + total_length = seq_len + past_key_values_length + if self.seq_len_cached == -1: + # warm up the cache + total_length = max(INFERENCE_MAX_LENGTH, total_length) + + if total_length > self.seq_len_cached: + with torch.inference_mode(False): + self.seq_len_cached = total_length + t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(device) + + if dtype in [torch.float16, torch.bfloat16]: + emb = emb.float() + + self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False) + + return ( + self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype), + self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype), + ) + + def forward(self, query, key, past_key_values_length=0): + batch, seq_len, head_dim = query.shape + cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype) + if seq_len == 1 and torch.is_inference_mode_enabled() and query.device.type == "cuda": + return self._optimized_apply_rotary(query, key, cos, sin) + else: + return apply_rotary(query, key, cos, sin) + + +def split_heads( + fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + batch, seq_len, _ = fused_qkv.shape + qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim) + query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3) + key = torch.broadcast_to(key, query.shape) + value = torch.broadcast_to(value, query.shape) + + query, key, value = [x.flatten(2, 3) for x in (query, key, value)] + return query, key, value + + +class OptimizedFalconAttention(FalconAttention): + def __init__(self, config: FalconConfig): + nn.Module.__init__(self) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.split_size = self.hidden_size + self.hidden_dropout = config.hidden_dropout + + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError( + f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + f" {self.num_heads})." + ) + + self.maybe_rotary = OptimizedFalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k) + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.beta = self.inv_norm_factor + if config.new_decoder_architecture: + qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim + elif config.multi_query: + qkv_out_dim = self.hidden_size + 2 * self.head_dim + else: + qkv_out_dim = 3 * self.hidden_size + self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias) + self.new_decoder_architecture = config.new_decoder_architecture + self.multi_query = config.multi_query + self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias) + self.attention_dropout = nn.Dropout(config.attention_dropout) + self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1 + + if self.new_decoder_architecture: + self._split_heads = partial( + split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim + ) + self.split_graph = None + self.input_surface = None + self.static_outputs = None + + def _optimized_split_heads(self, fused_qkv): + if self.split_graph is None: + self.split_graph = torch.cuda.CUDAGraph() + self.input_surface = fused_qkv + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + self._split_heads(fused_qkv) + torch.cuda.current_stream().wait_stream(s) + + with torch.cuda.graph(self.split_graph): + self.static_outputs = self._split_heads(self.input_surface) + + self.input_surface.copy_(fused_qkv) + self.split_graph.replay() + return tuple(o.detach() for o in self.static_outputs) + + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + assert not output_attentions + + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + + if ( + self.new_decoder_architecture + and hidden_states.size(1) == 1 + and torch.is_inference_mode_enabled() + and hidden_states.device.type == "cuda" + ): + query_layer, key_layer, value_layer = self._optimized_split_heads(fused_qkv) + else: + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + num_kv_heads = self.num_heads + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape( + batch_size * num_kv_heads, + query_length, + self.head_dim, + ) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim) + + past_kv_length = 0 if layer_past is None else layer_past[0].shape[1] + query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length) + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, kv_length, head_dim] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, kv_length, _ = key_layer.shape + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim) + key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) + value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) + + attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype) + + if alibi is None: + attn_output = F.scaled_dot_product_attention( + query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False + ) + + attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) + attn_output = attn_output.permute(0, 2, 1, 3) + attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + + output_tensor = self.dense(attn_output) + + return output_tensor, present + else: + matmul_result = query_layer_ @ key_layer_.transpose(-1, -2) + + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16 or input_dtype == torch.bfloat16: + attention_scores = attention_scores.to(torch.float32) + # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by + # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically + # equivalent and more performant, but there might be a numerical difference. If you're reading this + # and you'd like to experiment and maybe file a PR, feel free! + attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) + attention_logits *= self.inv_norm_factor + attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype) + # [batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # change view [batch_size, num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length) + + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1) + + # change view [batch_size, q_length, num_heads * head_dim] + context_layer = self._merge_heads(context_layer) + + output_tensor = self.dense(context_layer) + + if output_attentions: + return output_tensor, present, attention_probs + else: + return output_tensor, present + + +class OptimizedFalconDecoderLayer(FalconDecoderLayer): + def __init__(self, config: FalconConfig): + nn.Module.__init__(self) + hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.mlp = FalconMLP(config) + self.hidden_dropout = config.hidden_dropout + self.config = config + + self.self_attention = OptimizedFalconAttention(config) + + if self.config.alibi or not config.new_decoder_architecture: + if config.new_decoder_architecture: + # The layer norm before self-attention + self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + # The layer norm before the MLP + self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + else: + self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + if not config.parallel_attn: + self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + else: + self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.ln_graph = None + self.static_input = None + self.static_outputs = None + + def _optimized_apply_ln(self, hidden_states): + if self.ln_graph is None: + self.ln_graph = torch.cuda.CUDAGraph() + self.static_input = hidden_states + + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + self.ln_attn(hidden_states) + self.ln_mlp(hidden_states) + torch.cuda.current_stream().wait_stream(s) + + with torch.cuda.graph(self.ln_graph): + ln_attn_output = self.ln_attn(hidden_states) + ln_mlp_output = self.ln_mlp(hidden_states) + self.static_outputs = (ln_attn_output, ln_mlp_output) + + self.static_input.copy_(hidden_states) + self.ln_graph.replay() + return tuple(o.detach() for o in self.static_outputs) + + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + residual = hidden_states + + if self.config.new_decoder_architecture: + if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda": + attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states) + else: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + else: + attention_layernorm_out = self.input_layernorm(hidden_states) + + attn_outputs = self.self_attention( + attention_layernorm_out, + layer_past=layer_past, + attention_mask=attention_mask, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + attention_output = attn_outputs[0] + + if not self.config.new_decoder_architecture: + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual = dropout_add( + attention_output, residual, self.config.attention_dropout, training=self.training + ) + mlp_layernorm_out = self.post_attention_layernorm(residual) + + outputs = attn_outputs[1:] + + mlp_output = self.mlp(mlp_layernorm_out) + + if self.config.new_decoder_architecture or self.config.parallel_attn: + mlp_output += attention_output + + output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + +class WrappedFalconBlock(OptimizedFalconDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + alibi: Optional[torch.Tensor] = None, + layer_past: Optional[KVCache] = None, + use_cache: bool = False, + **kwargs, + ): + assert attention_mask is None + + batch_size, seq_length = hidden_states.shape[:2] + + if layer_past is not None: + layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past) + past_length = 0 if layer_past is None else layer_past[0].shape[1] + seq_length_with_past = seq_length + past_length + + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + if alibi is None and self.config.alibi: + alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) + attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length) + + outputs = super().forward( + hidden_states, + *args, + attention_mask=attention_mask, + alibi=alibi, + layer_past=layer_past, + use_cache=use_cache, + **kwargs, + ) + + if use_cache: + present_key_value = outputs[-1] + present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value) + outputs = outputs[:-1] + (present_key_value,) + + return outputs + + def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache: + key_states, value_states = key_value + + key_states = key_states.permute(0, 2, 1) + assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] + + if self.config.new_decoder_architecture: + key_states = self._expand_states(key_states) + value_states = self._expand_states(value_states) + + return (key_states, value_states) + + def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache: + key_states, value_states = key_value + + if self.config.new_decoder_architecture: + key_states = self._collapse_states(key_states) + value_states = self._collapse_states(value_states) + + assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] + key_states = key_states.permute(0, 2, 1) + + return (key_states, value_states) + + def _expand_states(self, state: torch.Tensor) -> torch.Tensor: + batch_size_x_num_kv_heads, seq_len, head_dim = state.shape + batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads + + state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim) + state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy + state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy + return state + + def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: + batch_size_x_num_attn_heads, seq_len, head_dim = state.shape + batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads + + state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim) + state = state[:, :, 0] + state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim) + return state diff --git a/src/petals/models/falcon/config.py b/src/petals/models/falcon/config.py new file mode 100644 index 0000000..a1ae5e9 --- /dev/null +++ b/src/petals/models/falcon/config.py @@ -0,0 +1,45 @@ +import os +from typing import Optional, Union + +from hivemind import get_logger +from transformers.models.falcon import FalconConfig +from transformers.models.falcon.modeling_falcon import FalconAttention + +from petals.client.config import ClientConfig +from petals.client.lm_head import LMHeadConfig +from petals.client.ptune import PTuneConfig +from petals.models.falcon.block import WrappedFalconBlock +from petals.utils.auto_config import DefaultRevisionMixin + +logger = get_logger(__name__) + + +class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig): + block_class = WrappedFalconBlock + attn_class = FalconAttention + block_prefix = "transformer.h" + + @property + def num_key_value_groups(self) -> int: + if self.new_decoder_architecture: + return self.num_attention_heads // self.num_kv_heads + if self.multi_query: + return self.num_attention_heads + return 1 + + @classmethod + def from_pretrained( + cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs + ): + loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path) + if loading_from_repo and dht_prefix is None: + dht_prefix = str(model_name_or_path) + dht_prefix = dht_prefix.split("/")[-1] # Use only repo name to merge blocks hosted by different accounts + dht_prefix = dht_prefix.replace(".", "-") + logger.info(f"Using DHT prefix: {dht_prefix}") + + result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) + config = result[0] if isinstance(result, tuple) else result + if config.pad_token_id is None: + config.pad_token_id = 0 + return result diff --git a/src/petals/models/falcon/model.py b/src/petals/models/falcon/model.py new file mode 100644 index 0000000..32c0b6f --- /dev/null +++ b/src/petals/models/falcon/model.py @@ -0,0 +1,150 @@ +from typing import Optional + +import hivemind +import torch +import torch.nn as nn +from hivemind.utils.logging import get_logger +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.falcon import ( + FalconForCausalLM, + FalconForSequenceClassification, + FalconModel, + FalconPreTrainedModel, +) + +from petals.client.from_pretrained import FromPretrainedMixin +from petals.client.lm_head import LMHead +from petals.client.ptune import PTuneMixin +from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues +from petals.client.remote_sequential import RemoteSequential +from petals.models.falcon.config import DistributedFalconConfig +from petals.utils.auto_config import DefaultRevisionMixin + +logger = get_logger(__name__) + + +class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel): + """FalconModel, but all transformer layers are hosted by the swarm""" + + _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."] + + config_class = DistributedFalconConfig + + def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None): + n_layer, config.num_hidden_layers = config.num_hidden_layers, 0 # Prevent initialization + super().__init__(config) + assert len(self.h) == 0 + config.num_hidden_layers = n_layer + + self.h = RemoteSequential(config, dht=dht) + + self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm + self.init_prompts(config) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[RemotePastKeyValues] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # The causal mask will be added on the server-side + assert ( + attention_mask is None or (attention_mask == 1).all() + ), f"Custom attention masks are not supported, {attention_mask=}" + assert head_mask is None, f"Custom head masks are not supported, {head_mask=}" + assert use_cache is None or use_cache, f"{use_cache=} is not supported" + assert not output_attentions, f"{output_attentions=} is not supported" + assert not output_hidden_states, f"{output_hidden_states=} is not supported" + assert return_dict is None or return_dict, f"{return_dict=} is not supported" + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0 + if use_prompts: + batch_size = inputs_embeds.shape[0] + prompts, intermediate_prompts = self.get_prompt(batch_size) + inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) + else: + prompts = intermediate_prompts = None + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + output_shape = input_shape + (hidden_states.size(-1),) + + hidden_states = self.h( + hidden_states, + prompts=intermediate_prompts, + hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None, + ) + + # Remove prefix + if use_prompts: + hidden_states = hidden_states[:, self.pre_seq_len :] + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=RemotePastKeyValues(), + hidden_states=None, + attentions=None, + ) + + @property + def word_embeddings_layernorm(self) -> nn.Module: # For compatibility with RemoteGenerationMixin + return nn.Identity() + + +class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM): + _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedFalconConfig + + def __init__(self, config: DistributedFalconConfig): + FalconPreTrainedModel.__init__(self, config) + self.transformer = DistributedFalconModel(config) + self.lm_head = LMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + +class DistributedFalconForSequenceClassification( + DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification +): + _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing + _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected + + config_class = DistributedFalconConfig + + def __init__(self, config: DistributedFalconConfig): + FalconPreTrainedModel.__init__(self, config) + self.num_labels = config.num_labels + + self.transformer = DistributedFalconModel(config) + self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() diff --git a/src/petals/models/llama/config.py b/src/petals/models/llama/config.py index 43a5843..ae71a4c 100644 --- a/src/petals/models/llama/config.py +++ b/src/petals/models/llama/config.py @@ -43,4 +43,5 @@ class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfi result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs) config = result[0] if isinstance(result, tuple) else result config.pretraining_tp = 1 # This may give less accurate results but it doesn't matter if we use quantization + config.use_cache = True # use_cache=False leads to identical results but is slower and not supported by Petals return result diff --git a/src/petals/models/llama/model.py b/src/petals/models/llama/model.py index a9dfcc1..3360f40 100644 --- a/src/petals/models/llama/model.py +++ b/src/petals/models/llama/model.py @@ -73,7 +73,8 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0: + use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0 + if use_prompts: batch_size = inputs_embeds.shape[0] prompts, intermediate_prompts = self.get_prompt(batch_size) inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) @@ -90,7 +91,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): ) # Remove prefix - if self.config.tuning_mode and "ptune" in self.config.tuning_mode: + if use_prompts: hidden_states = hidden_states[:, self.pre_seq_len :] # Add last hidden state diff --git a/src/petals/server/server.py b/src/petals/server/server.py index ab646a5..fd9f766 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -5,6 +5,7 @@ import math import multiprocessing as mp import os import random +import sys import threading import time from typing import Dict, List, Optional, Sequence, Union @@ -186,10 +187,7 @@ class Server: check_device_balance(self.tensor_parallel_devices) if quant_type is None: - if device.type == "cuda": - quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8 - else: - quant_type = QuantType.NONE + quant_type = QuantType.NF4 if device.type == "cuda" else QuantType.NONE self.quant_type = quant_type logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format") @@ -234,8 +232,9 @@ class Server: self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB") - assert isinstance(throughput, float) or throughput in ["auto", "eval"] - if throughput in ["auto", "eval"]: + assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"] + if throughput in ["auto", "eval", "dry_run"]: + force_eval = throughput in ["eval", "dry_run"] throughput_info = get_server_throughput( converted_model_name_or_path, self.block_config, @@ -245,9 +244,12 @@ class Server: quant_type=quant_type, tensor_parallel_devices=self.tensor_parallel_devices, reachable_via_relay=reachable_via_relay, - force_eval=(throughput == "eval"), + force_eval=force_eval, cache_dir=cache_dir, ) + if throughput == "dry_run": + logger.info("Finished estimating throughput, exiting") + sys.exit(0) else: throughput_info = {"throughput": throughput} self.server_info = ServerInfo( diff --git a/src/petals/utils/auto_config.py b/src/petals/utils/auto_config.py index 70f37a3..0cec83d 100644 --- a/src/petals/utils/auto_config.py +++ b/src/petals/utils/auto_config.py @@ -1,12 +1,14 @@ import os -import re from dataclasses import dataclass from typing import Optional, Type, Union +from hivemind import get_logger from transformers import AutoConfig, PretrainedConfig, PreTrainedModel from petals.utils.hf_auth import always_needs_auth +logger = get_logger(__name__) + @dataclass class _ModelClasses: @@ -49,17 +51,44 @@ class _AutoDistributedBase: return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs) -class AutoDistributedConfig(_AutoDistributedBase): +class DefaultRevisionMixin: + """ + Petals only supports Falcon loaded in the new in-library format (transformers.FalconModel). + TII models were recently converted to this format but then reverted back due to compatibility issues. + We chose to support only the new format since HF staff promised to eventually convert these models + to the new format again, see https://huggingface.co/tiiuae/falcon-40b/discussions/90#64b4d23bf44fd957492f7602 + Until it happens, we override the default `main` revision for the TII repos with the commit + pointing out to the model in the in-library format. + """ + + DEFAULT_REVISIONS = { + "tiiuae/falcon-40b": "f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232", + "tiiuae/falcon-40b-instruct": "7475ff8cfc36ed9a962b658ae3c33391566a85a5", + "tiiuae/falcon-7b": "4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76", + "tiiuae/falcon-7b-instruct": "f8dac3fff96d5debd43edf56fb4e1abcfffbef28", + } + + @classmethod + def from_pretrained( + cls, model_name_or_path: Union[str, os.PathLike, None], *args, revision: Optional[str] = None, **kwargs + ): + if revision is None and model_name_or_path in cls.DEFAULT_REVISIONS: + revision = cls.DEFAULT_REVISIONS[model_name_or_path] + logger.info(f"Loading {model_name_or_path}, revision {revision}") + return super().from_pretrained(model_name_or_path, *args, revision=revision, **kwargs) + + +class AutoDistributedConfig(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "config" -class AutoDistributedModel(_AutoDistributedBase): +class AutoDistributedModel(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "model" -class AutoDistributedModelForCausalLM(_AutoDistributedBase): +class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "model_for_causal_lm" -class AutoDistributedModelForSequenceClassification(_AutoDistributedBase): +class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase): _mapping_field = "model_for_sequence_classification" diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py new file mode 100644 index 0000000..5baa1a2 --- /dev/null +++ b/tests/test_optimized_layers.py @@ -0,0 +1,128 @@ +from typing import Optional, Tuple + +import pytest +import torch +from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor + +from petals.utils.auto_config import AutoDistributedConfig +from petals.utils.convert_block import QuantType, convert_block +from test_utils import MODEL_NAME + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class UnoptimizedWrappedFalconBlock(FalconDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + *args, + attention_mask: Optional[torch.Tensor] = None, + alibi: Optional[torch.Tensor] = None, + layer_past: Optional[KVCache] = None, + use_cache: bool = False, + **kwargs, + ): + batch_size, seq_length = hidden_states.shape[:2] + + if layer_past is not None: + layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past) + past_length = 0 if layer_past is None else layer_past[0].shape[1] + seq_length_with_past = seq_length + past_length + + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + if alibi is None and self.config.alibi: + alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype) + attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length) + + outputs = super().forward( + hidden_states, + *args, + attention_mask=attention_mask, + alibi=alibi, + layer_past=layer_past, + use_cache=use_cache, + **kwargs, + ) + + if use_cache: + present_key_value = outputs[-1] + present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value) + outputs = outputs[:-1] + (present_key_value,) + + return outputs + + def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache: + key_states, value_states = key_value + + key_states = key_states.permute(0, 2, 1) + assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] + + if self.config.new_decoder_architecture: + key_states = self._expand_states(key_states) + value_states = self._expand_states(value_states) + + return (key_states, value_states) + + def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache: + key_states, value_states = key_value + + if self.config.new_decoder_architecture: + key_states = self._collapse_states(key_states) + value_states = self._collapse_states(value_states) + + assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim] + key_states = key_states.permute(0, 2, 1) + + return (key_states, value_states) + + def _expand_states(self, state: torch.Tensor) -> torch.Tensor: + batch_size_x_num_kv_heads, seq_len, head_dim = state.shape + batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads + + state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim) + state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1) # No copy + state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim) # Involves a copy + return state + + def _collapse_states(self, state: torch.Tensor) -> torch.Tensor: + batch_size_x_num_attn_heads, seq_len, head_dim = state.shape + batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads + + state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim) + state = state[:, :, 0] + state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim) + return state + + +@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models") +@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +@pytest.mark.forked +def test_falcon(device): + if device == "cuda:0" and not torch.cuda.is_available(): + pytest.skip("CUDA tests can be run only in CUDA-enabled setups") + + config = AutoDistributedConfig.from_pretrained(MODEL_NAME) + + tensor_parallel_devices = (device,) + dtype = torch.bfloat16 + quant_type = QuantType.NONE + + block = config.block_class(config).to(dtype) + block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) + + unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype) + unopt_block = convert_block( + unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True + ) + + unopt_block.load_state_dict(block.state_dict()) + cache = unopt_cache = None + + with torch.inference_mode(): + for length in [10, 1, 1, 1]: + dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype) + block_output, cache = block(dummy_input, layer_past=cache, use_cache=True) + unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True) + assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length + assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length + assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length