Bump transformers and accelerate versions (#554)

Bump the pinned versions of transformers (to 4.37.x) and accelerate (>=0.27.2), adapt the client to the new transformers Cache API, and remove the falcon-rw-1b CI tests
Denis Mazur committed via GitHub
parent d59c15c578
commit 0d91bbdac3

@@ -14,8 +14,6 @@ jobs:
- { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
- { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
- { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.8' }
- { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.11' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
- { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
fail-fast: false

@@ -34,10 +34,10 @@ python_requires = >=3.8
install_requires =
torch>=1.12
bitsandbytes==0.41.1
accelerate>=0.22.0
accelerate>=0.27.2
huggingface-hub>=0.11.1,<1.0.0
tokenizers>=0.13.3
transformers>=4.32.0,<4.35.0 # if you change this, please also change version assert in petals/__init__.py
transformers==4.37.1 # if you change this, please also change version assert in petals/__init__.py
speedtest-cli==2.1.3
pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet
hivemind==1.1.10.post2

@@ -17,13 +17,13 @@ from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs
__version__ = "2.3.0.dev1"
__version__ = "2.3.0.dev2"
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
assert (
version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
version.parse("4.37.1") <= version.parse(transformers.__version__) < version.parse("4.38.0")
), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.38.0"
def _override_bfloat16_mode_default():

@@ -211,6 +211,7 @@ class InferenceSession:
self._position = 0
self._max_length = max_length
self.output_ids = None
self.past_key_values = None
@property
def num_blocks(self) -> int:

@@ -1,11 +1,13 @@
import contextlib
import dataclasses
from contextvars import ContextVar
from typing import ContextManager, List, Optional
from typing import Any, ContextManager, Dict, List, Optional, Tuple
import torch
import transformers
from hivemind.utils.logging import get_logger
from torch import Tensor
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation.utils import ModelOutput
from petals.client.inference_session import InferenceSession
@@ -15,15 +17,29 @@ from petals.utils.misc import DUMMY, docstring_from
logger = get_logger(__name__)
@dataclasses.dataclass(frozen=True)
class RemotePastKeyValues:
"""A mock class representing the fact that `past_key_values` do exist but are stored on remote servers."""
class RemotePastKeyValues(Cache):
"""only keeps the number of seen tokens. pretends to be a legit cache"""
hypo_ids: Optional[torch.LongTensor] = None
def __init__(self) -> None:
super().__init__()
self.seen_tokens = 0
self.hypo_ids: Optional[torch.LongTensor] = None
def __getitem__(self, _index: int) -> List[torch.Tensor]:
return [DUMMY] # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
return self.seen_tokens
def get_max_length(self) -> Optional[int]:
return None
def update_seen(self, new_seen: int) -> None:
self.seen_tokens += new_seen
def reorder_cache(self, beam_idx):
pass
_skipped_tokens = ContextVar("skipped_tokens", default=0)
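For context, a minimal usage sketch of the new cache stub (assuming it is importable from `petals.client.remote_generation`, matching the upstream layout): it only counts tokens the remote servers have already processed while still satisfying the transformers `Cache` interface.

```python
# Hedged sketch: RemotePastKeyValues tracks only the number of already-processed tokens;
# the real key/value tensors stay on the remote servers.
from petals.client.remote_generation import RemotePastKeyValues  # assumed module path

past = RemotePastKeyValues()
past.update_seen(5)                   # e.g., a 5-token prompt was already sent to the swarm
assert past.get_seq_length() == 5     # Cache API: how many tokens the "cache" covers
assert past.get_max_length() is None  # no client-side length bound
```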
@@ -113,6 +129,11 @@ class RemoteGenerationMixin(_SkipTokensMixin):
# but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
_skipped_tokens.set(max(0, n_prev_tokens - 1))
if self._supports_cache_class and "past_key_values" not in kwargs:
past_key_values = RemotePastKeyValues()
past_key_values.update_seen(session.position)
kwargs["past_key_values"] = past_key_values
result = super().generate(inputs, *args, **kwargs)
sequences = result.sequences if isinstance(result, ModelOutput) else result
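A hedged end-to-end sketch of how this interacts with generation (the model name and lengths are illustrative, not taken from this diff): when the model advertises `_supports_cache_class`, `generate()` is seeded with a `RemotePastKeyValues` positioned at the current session offset, so logits processors such as `repetition_penalty` see the correct number of past tokens.

```python
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "petals-team/StableBeluga2"  # illustrative; any Petals-served model works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("The cache lives on remote servers", return_tensors="pt")["input_ids"]
with model.inference_session(max_length=64) as session:
    # generate() creates a RemotePastKeyValues seeded with session.position under the hood
    outputs = model.generate(inputs, session=session, max_new_tokens=16)
print(tokenizer.decode(outputs[0]))
```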

@@ -6,6 +6,7 @@ See commit history for authorship.
from typing import Optional, Tuple
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
@@ -26,7 +27,13 @@ class WrappedBloomBlock(BloomBlock):
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
if alibi is None:
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=hidden_states,
past_key_values_length=past_length,
)
attention_mask = attention_mask.bool()
return super().forward(
hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
)
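For reference, a small shape-only sketch of the public helper that replaces the removed private `BloomModel._prepare_attn_mask` (all sizes below are illustrative):

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, past_length, hidden = 2, 4, 3, 8
hidden_states = torch.zeros(batch_size, seq_length, hidden)
attention_mask = torch.ones(batch_size, past_length + seq_length)

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask=attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,          # only dtype/device are taken from this tensor
    past_key_values_length=past_length,
)
print(mask_4d.shape)  # torch.Size([2, 1, 4, 7]): (batch, 1, query_len, key_len)
```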

@@ -4,6 +4,7 @@ import hivemind
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
@@ -92,12 +93,16 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
if use_prompts:
hidden_states = hidden_states[:, self.pre_seq_len :]
if past_key_values is None:
past_key_values = RemotePastKeyValues()
past_key_values.update_seen(hidden_states.size(1))
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=RemotePastKeyValues(),
past_key_values=past_key_values,
hidden_states=None,
attentions=None,
)
@@ -107,6 +112,7 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
_keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
_keys_to_ignore_on_load_missing += [r"^lm_head\."] # Missing since they are shared with input embeddings
_keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
_supports_cache_class = True
config_class = DistributedBloomConfig
@@ -118,6 +124,58 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
# Initialize weights and apply final processing
self.post_init()
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
) -> dict:
# Omit tokens covered by past_key_values
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
def _temporary_reorder_cache(self, past_key_values, beam_idx):
return past_key_values
def get_output_embeddings(self):
return self.lm_head
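A standalone illustration of the slicing performed above (the tensors are hypothetical): tokens already covered by the remote cache are dropped, and `position_ids` are rebuilt from the attention mask so they start where the cache left off.

```python
import torch

past_length = 3                                   # tokens the remote servers have already seen
input_ids = torch.arange(5).unsqueeze(0)          # full sequence of 5 tokens
attention_mask = torch.ones(1, 5, dtype=torch.long)

new_input_ids = input_ids[:, past_length:]                # tensor([[3, 4]])
position_ids = attention_mask.long().cumsum(-1) - 1       # positions for the full sequence
position_ids = position_ids[:, -new_input_ids.shape[1]:]  # tensor([[3, 4]])
print(new_input_ids, position_ids)
```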

@@ -9,6 +9,7 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
@@ -84,8 +85,8 @@ class OptimizedLlamaAttention(LlamaAttention):
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos = cos[:, :, kv_seq_len - q_len :]
sin = sin[:, :, kv_seq_len - q_len :]
cos = cos[kv_seq_len - q_len :]
sin = sin[kv_seq_len - q_len :]
if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
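The indexing change reflects that, in the pinned transformers (4.37.x), `LlamaRotaryEmbedding` returns 2D `cos`/`sin` tensors of shape `(seq_len, head_dim)` rather than the 4D tensors of older releases, so the last `q_len` positions are sliced on dim 0 instead of dim 2. A hedged sketch of the new shapes (sizes are illustrative):

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

head_dim, kv_seq_len, q_len = 128, 10, 1
rotary = LlamaRotaryEmbedding(head_dim)
value_states = torch.zeros(1, 32, kv_seq_len, head_dim)  # only dtype/device are used

cos, sin = rotary(value_states, seq_len=kv_seq_len)
print(cos.shape)                  # torch.Size([10, 128]) -> indexed on dim 0, not dim 2
cos_q = cos[kv_seq_len - q_len:]  # rotary angles for the last q_len query positions
print(cos_q.shape)                # torch.Size([1, 128])
```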
@@ -244,8 +245,11 @@ class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
attention_mask = LlamaModel._prepare_decoder_attention_mask(
None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=hidden_states,
past_key_values_length=past_key_values_length,
)
outputs = super().forward(

@@ -90,6 +90,10 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
)
if past_key_values is None:
past_key_values = RemotePastKeyValues()
past_key_values.update_seen(hidden_states.size(1))
# Remove prefix
if use_prompts:
hidden_states = hidden_states[:, self.pre_seq_len :]
@@ -97,9 +101,10 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
# Add last hidden state
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=RemotePastKeyValues(),
past_key_values=past_key_values,
hidden_states=None,
attentions=None,
)

@@ -26,9 +26,7 @@ logger = get_logger(__name__)
def check_peft_repository(repo_id: str) -> bool:
fs = HfFileSystem()
list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
return len(list_of_files) > 0
return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")
def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
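The PEFT repository check now uses `HfFileSystem.exists()` (part of `huggingface_hub`'s fsspec filesystem) instead of globbing and counting matches. An equivalent standalone check, with a hypothetical repo id and the weights filename assumed to match the constant used above:

```python
from huggingface_hub import HfFileSystem

SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"  # assumption: same constant as above
repo_id = "some-user/some-peft-adapter"                 # hypothetical repository

fs = HfFileSystem()
print(fs.exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}"))  # True iff the adapter file is present
```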

@@ -2,6 +2,8 @@ from typing import Optional, Tuple
import pytest
import torch
from transformers.cache_utils import DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
@@ -116,6 +118,8 @@ class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
past_key_values_length = past_key_value[0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
elif use_cache:
past_key_value = DynamicCache()
if position_ids is None:
device = hidden_states.device
@@ -131,8 +135,9 @@ class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
attention_mask = LlamaModel._prepare_decoder_attention_mask(
None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
outputs = super().forward(
@@ -156,19 +161,20 @@ class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
def _reorder_cache_from_bloom_to_llama(
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
) -> Tuple[torch.Tensor]:
) -> DynamicCache:
key_states, value_states = key_value
key_states = key_states.permute(0, 2, 1)
key_states = key_states.view(
batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
)
value_states = value_states.view(*key_states.shape)
return (key_states, value_states)
past_key_values = ((key_states, value_states),)
return DynamicCache.from_legacy_cache(past_key_values)
def _reorder_cache_from_llama_to_bloom(
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
self, key_value: DynamicCache, batch_size: int, seq_length: int
) -> Tuple[torch.Tensor]:
key_states, value_states = key_value
key_states, value_states = key_value.to_legacy_cache()[0]
value_states = value_states.view(
batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
)
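A minimal sketch (shapes are illustrative) of the `DynamicCache` round trip the test helpers now perform: a legacy tuple-of-tuples goes in, the decoder layer consumes a `Cache` object, and the legacy tuple comes back out.

```python
import torch
from transformers.cache_utils import DynamicCache

# one layer's (key, value) in the legacy layout: (batch, num_kv_heads, seq_len, head_dim)
key = torch.zeros(1, 4, 7, 16)
value = torch.zeros(1, 4, 7, 16)

cache = DynamicCache.from_legacy_cache(((key, value),))
assert cache.get_seq_length() == 7
key_out, value_out = cache.to_legacy_cache()[0]
assert key_out.shape == key.shape and value_out.shape == value.shape
```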
@@ -195,7 +201,7 @@ def test_optimized_block(device):
if config.model_type == "falcon":
unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
elif config.model_type == "llama":
unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype)
unopt_block = UnoptimizedWrappedLlamaBlock(config, layer_idx=0).to(dtype)
else:
pytest.skip(f"This test is not applicable to {config.model_type} models")
