diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml
index 74b731d..1e76235 100644
--- a/.github/workflows/run-tests.yaml
+++ b/.github/workflows/run-tests.yaml
@@ -14,8 +14,6 @@ jobs:
           - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
-          - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.8' }
-          - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.11' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
       fail-fast: false
diff --git a/setup.cfg b/setup.cfg
index ef35f84..b06dd5c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -34,10 +34,10 @@ python_requires = >=3.8
 install_requires =
     torch>=1.12
     bitsandbytes==0.41.1
-    accelerate>=0.22.0
+    accelerate>=0.27.2
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers>=4.32.0,<4.35.0  # if you change this, please also change version assert in petals/__init__.py
+    transformers==4.37.1  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
     hivemind==1.1.10.post2
diff --git a/src/petals/__init__.py b/src/petals/__init__.py
index 1af8bf9..fd38936 100644
--- a/src/petals/__init__.py
+++ b/src/petals/__init__.py
@@ -17,13 +17,13 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "2.3.0.dev1"
+__version__ = "2.3.0.dev2"
 
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
+        version.parse("4.37.1") <= version.parse(transformers.__version__) < version.parse("4.38.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.38.0"
 
 
 def _override_bfloat16_mode_default():
diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py
index 34d24c7..0938df2 100644
--- a/src/petals/client/inference_session.py
+++ b/src/petals/client/inference_session.py
@@ -211,6 +211,7 @@ class InferenceSession:
         self._position = 0
         self._max_length = max_length
         self.output_ids = None
+        self.past_key_values = None
 
     @property
     def num_blocks(self) -> int:
diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py
index 97a115a..0249c0b 100644
--- a/src/petals/client/remote_generation.py
+++ b/src/petals/client/remote_generation.py
@@ -1,11 +1,13 @@
 import contextlib
 import dataclasses
 from contextvars import ContextVar
-from typing import ContextManager, List, Optional
+from typing import Any, ContextManager, Dict, List, Optional, Tuple
 
 import torch
 import transformers
 from hivemind.utils.logging import get_logger
+from torch import Tensor
+from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation.utils import ModelOutput
 
 from petals.client.inference_session import InferenceSession
@@ -15,15 +17,29 @@ from petals.utils.misc import DUMMY, docstring_from
 
 logger = get_logger(__name__)
 
 
-@dataclasses.dataclass(frozen=True)
-class RemotePastKeyValues:
-    """A mock class representing the fact that `past_key_values` do exist but are stored on remote servers."""
+class RemotePastKeyValues(Cache):
+    """only keeps the number of seen tokens. pretends to be a legit cache"""
 
-    hypo_ids: Optional[torch.LongTensor] = None
+    def __init__(self) -> None:
+        super().__init__()
+        self.seen_tokens = 0
+        self.hypo_ids: Optional[torch.LongTensor] = None
 
     def __getitem__(self, _index: int) -> List[torch.Tensor]:
         return [DUMMY]  # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
 
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        return self.seen_tokens
+
+    def get_max_length(self) -> Optional[int]:
+        return None
+
+    def update_seen(self, new_seen: int) -> None:
+        self.seen_tokens += new_seen
+
+    def reorder_cache(self, beam_idx):
+        pass
+
 
 _skipped_tokens = ContextVar("skipped_tokens", default=0)
 
@@ -113,6 +129,11 @@ class RemoteGenerationMixin(_SkipTokensMixin):
                 # but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
                 _skipped_tokens.set(max(0, n_prev_tokens - 1))
 
+            if self._supports_cache_class and "past_key_values" not in kwargs:
+                past_key_values = RemotePastKeyValues()
+                past_key_values.update_seen(session.position)
+                kwargs["past_key_values"] = past_key_values
+
             result = super().generate(inputs, *args, **kwargs)
 
             sequences = result.sequences if isinstance(result, ModelOutput) else result
diff --git a/src/petals/models/bloom/block.py b/src/petals/models/bloom/block.py
index f246bd8..86fc4aa 100644
--- a/src/petals/models/bloom/block.py
+++ b/src/petals/models/bloom/block.py
@@ -6,6 +6,7 @@ See commit history for authorship.
 from typing import Optional, Tuple
 
 import torch
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
 
 
@@ -26,7 +27,13 @@ class WrappedBloomBlock(BloomBlock):
             attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
+            past_key_values_length=past_length,
+        )
+        attention_mask = attention_mask.bool()
         return super().forward(
             hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )
diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py
index 784418f..67d2f35 100644
--- a/src/petals/models/bloom/model.py
+++ b/src/petals/models/bloom/model.py
@@ -4,6 +4,7 @@ import hivemind
 import torch
 import torch.nn as nn
 from hivemind.utils.logging import get_logger
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
 
@@ -92,12 +93,16 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
+        if past_key_values is None:
+            past_key_values = RemotePastKeyValues()
+        past_key_values.update_seen(hidden_states.size(1))
+
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
-            past_key_values=RemotePastKeyValues(),
+            past_key_values=past_key_values,
             hidden_states=None,
             attentions=None,
         )
@@ -107,6 +112,7 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
     _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_missing += [r"^lm_head\."]  # Missing since they are shared with input embeddings
     _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
+    _supports_cache_class = True
 
     config_class = DistributedBloomConfig
 
@@ -118,6 +124,58 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
         # Initialize weights and apply final processing
         self.post_init()
 
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ) -> dict:
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    def _temporary_reorder_cache(self, past_key_values, beam_idx):
+        return past_key_values
+
     def get_output_embeddings(self):
         return self.lm_head
 
diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py
index a8d433d..6f539a8 100644
--- a/src/petals/models/llama/block.py
+++ b/src/petals/models/llama/block.py
@@ -9,6 +9,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,
@@ -84,8 +85,8 @@ class OptimizedLlamaAttention(LlamaAttention):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[:, :, kv_seq_len - q_len :]
-        sin = sin[:, :, kv_seq_len - q_len :]
+        cos = cos[kv_seq_len - q_len :]
+        sin = sin[kv_seq_len - q_len :]
 
         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
@@ -244,8 +245,11 @@ class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = LlamaModel._prepare_decoder_attention_mask(
-            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
+            past_key_values_length=past_key_values_length,
         )
 
         outputs = super().forward(
diff --git a/src/petals/models/llama/model.py b/src/petals/models/llama/model.py
index 3360f40..611bb2b 100644
--- a/src/petals/models/llama/model.py
+++ b/src/petals/models/llama/model.py
@@ -90,6 +90,10 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
             hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
         )
 
+        if past_key_values is None:
+            past_key_values = RemotePastKeyValues()
+        past_key_values.update_seen(hidden_states.size(1))
+
         # Remove prefix
         if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
@@ -97,9 +101,10 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         # Add last hidden state
         hidden_states = self.norm(hidden_states)
         hidden_states = hidden_states.view(output_shape)
+
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=RemotePastKeyValues(),
+            past_key_values=past_key_values,
             hidden_states=None,
             attentions=None,
         )
diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index e4d29fc..5d285ef 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -26,9 +26,7 @@ logger = get_logger(__name__)
 
 
 def check_peft_repository(repo_id: str) -> bool:
-    fs = HfFileSystem()
-    list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
-    return len(list_of_files) > 0
+    return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")
 
 
 def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
diff --git a/tests/test_optimized_layers.py b/tests/test_optimized_layers.py
index 84cbfff..70f763e 100644
--- a/tests/test_optimized_layers.py
+++ b/tests/test_optimized_layers.py
@@ -2,6 +2,8 @@ from typing import Optional, Tuple
 
 import pytest
 import torch
+from transformers.cache_utils import DynamicCache
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
@@ -116,6 +118,8 @@ class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
             past_key_values_length = past_key_value[0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
             past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
+        elif use_cache:
+            past_key_value = DynamicCache()
 
         if position_ids is None:
             device = hidden_states.device
@@ -131,8 +135,9 @@ class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = LlamaModel._prepare_decoder_attention_mask(
-            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
         )
 
         outputs = super().forward(
@@ -156,19 +161,20 @@ class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
 
     def _reorder_cache_from_bloom_to_llama(
         self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
-    ) -> Tuple[torch.Tensor]:
+    ) -> DynamicCache:
         key_states, value_states = key_value
         key_states = key_states.permute(0, 2, 1)
         key_states = key_states.view(
             batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
         )
         value_states = value_states.view(*key_states.shape)
-        return (key_states, value_states)
+        past_key_values = ((key_states, value_states),)
+        return DynamicCache.from_legacy_cache(past_key_values)
 
     def _reorder_cache_from_llama_to_bloom(
-        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+        self, key_value: DynamicCache, batch_size: int, seq_length: int
     ) -> Tuple[torch.Tensor]:
-        key_states, value_states = key_value
+        key_states, value_states = key_value.to_legacy_cache()[0]
         value_states = value_states.view(
             batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
         )
@@ -195,7 +201,7 @@ def test_optimized_block(device):
     if config.model_type == "falcon":
         unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
     elif config.model_type == "llama":
-        unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype)
+        unopt_block = UnoptimizedWrappedLlamaBlock(config, layer_idx=0).to(dtype)
     else:
         pytest.skip(f"This test is not applicable to {config.model_type} models")
 