Bump transformers to 4.25.1 (#151)

- latest accelerate, transformers, huggingface_hub
- rearrange attention caches to match the new layer_past layout from https://github.com/huggingface/transformers/pull/18344
- remove unused code
- fix an edge case where a session crashes when it receives a sequence of length 0
- assert the transformers version when importing WrappedBloomBlock (a small import sketch follows this list)
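
For reference, a minimal sketch of the new import-time guard (the environment variable and module path are taken from petals/bloom/block.py in this diff; skipping the check is only an escape hatch, not the recommended setup):

import os

# By default the import below requires transformers 4.25.x; setting this variable skips the assert
os.environ["PETALS_IGNORE_DEPENDENCY_VERSION"] = "1"
from petals.bloom.block import WrappedBloomBlock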

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 1 year ago committed by GitHub
parent e4dc938dfe
commit b04982c1a2

@ -33,9 +33,9 @@ python_requires = >=3.7
install_requires =
torch>=1.12
bitsandbytes==0.34.0
accelerate==0.10.0
huggingface-hub==0.7.0
transformers==4.21.3
accelerate==0.15.0
huggingface-hub==0.11.1
transformers==4.25.1
protobuf>=3.20.3,<4.0dev
hivemind==1.1.3
humanfriendly

@ -1,2 +0,0 @@
from petals.bloom.block import BloomBlock
from petals.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel

@ -3,253 +3,57 @@ Bloom intermediate layer
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
import math
import os
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.quantized.dynamic.modules.linear
import transformers
from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor
from petals.bloom.ops import (
BloomGelu,
BloomScaledSoftmax,
attention_mask_func,
build_alibi_tensor,
dropout_add,
pre_process_alibi_for_pad,
split_tensor_along_last_dim,
)
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
assert transformers.__version__.startswith("4.25."), "Please install transformers 4.25.1"
class BloomAttention(nn.Module):
def __init__(self, config, layer_number=None):
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.n_head
self.head_dim = self.hidden_size // self.num_heads
self.split_size = self.hidden_size
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
self.masked_softmax_fusion = config.masked_softmax_fusion
self.hidden_dropout = config.hidden_dropout
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
f" {self.num_heads})."
)
# Layer-wise attention scaling
self.layer_number = max(1, layer_number)
self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
# Scaled Softmax
self.scale_mask_softmax = BloomScaledSoftmax(
self.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
self.layer_number,
)
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
self.attention_dropout = nn.Dropout(config.attention_dropout)
class WrappedBloomBlock(BloomBlock):
def forward(
self,
hidden_states,
residual,
layer_past=None,
attention_mask=None,
alibi=None,
head_mask=None,
use_cache=False,
output_attentions=False,
hidden_states: torch.Tensor,
*args,
attention_mask: Optional[torch.Tensor] = None,
alibi: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs
):
assert attention_mask is None
batch_size, seq_length = hidden_states.shape[:2]
past_length = 0 if layer_past is None else layer_past[0].shape[-1]
seq_length_with_past = seq_length + past_length
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
if alibi is None:
current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
alibi = build_alibi_tensor(
current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
)
# hidden_states: [batch_size, seq_length, hidden_size]
# apply preprocessing if the input is padded
if attention_mask is not None:
alibi = pre_process_alibi_for_pad(alibi, attention_mask)
# otherwise repeat alibi tensor with the batch size
else:
alibi = alibi.repeat(hidden_states.shape[0], 1, 1)
mixed_x_layer = self.query_key_value(hidden_states)
# [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
if use_cache is True:
present = (key_layer, value_layer)
else:
present = None
# [batch_size, head_dim, q_length, k_length]
output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
# [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
# [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
# Raw attention scores. [batch_size * num_heads, q_length, k_length]
beta = 1.0 / self.layer_number
matmul_result = torch.baddbmm(
alibi,
query_layer.transpose(1, 0),
key_layer.transpose(1, 0).transpose(1, 2),
beta=beta,
alpha=(1.0 / self.norm_factor),
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
return super().forward(
hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
)
# change view to [batch_size, num_heads, q_length, k_length]
attention_scores = matmul_result.view(*output_size)
# attention scores and attention mask [b, np, sq, sk]
max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# context layer shape: [batch_size, num_heads, q_length, head_dim]
output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
# change view [k_length, batch_size x num_heads, head_dim]
value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)
# change view [batch_size x num_heads, q_length, k_length]
attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
# change view [batch_size, num_heads, q_length, head_dim]
context_layer = context_layer.view(*output_size)
# [batch_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
context_layer = context_layer.view(*new_context_layer_shape)
# Output. [q_length, batch_size, hidden_size]
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
output_tensor = self.dense(context_layer)
output = output_tensor.transpose(1, 0)
output = dropout_add(output, residual, self.hidden_dropout, self.training)
outputs = (output, present)
if output_attentions:
outputs += (attention_probs,)
return outputs
class BloomMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
self.hidden_dropout = config.hidden_dropout
self.gelu_impl = BloomGelu()
def forward(self, hidden_states, residual):
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
intermediate_output = self.dense_4h_to_h(hidden_states)
output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
return output
class BloomBlock(nn.Module):
def __init__(self, config, layer_number=None):
super().__init__()
self.hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
self.n_head = config.n_head
self.self_attention = BloomAttention(config, layer_number=layer_number)
self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config)
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
self.hidden_dropout = config.hidden_dropout
def forward(
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
use_cache=False,
output_attentions=False,
alibi=None,
):
# hidden_states: [batch_size, seq_length, hidden_size]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Layer norm post the self attention.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
def _prepare_attn_mask(
self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
) -> torch.BoolTensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape
if src_length > 1:
combined_attention_mask = _make_causal_mask(
torch.Size(input_shape), device=device, past_key_values_length=past_key_values_length
)
# Self attention.
attn_outputs = self.self_attention(
layernorm_output,
residual,
layer_past=layer_past,
attention_mask=attention_mask,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
)
attention_output = attn_outputs[0]
outputs = attn_outputs[1:]
layernorm_output = self.post_attention_layernorm(attention_output)
# Get residual
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# MLP.
output = self.mlp(layernorm_output, residual)
if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
return combined_attention_mask
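
A minimal usage sketch of the wrapped block (illustrative config sizes; assumes transformers 4.25.x so the import-time assert above passes). attention_mask and alibi are built inside forward(), so only hidden states are required, and the returned cache follows the transformers 4.25 layer_past layout:

import torch
from transformers import BloomConfig
from petals.bloom.block import WrappedBloomBlock

config = BloomConfig(hidden_size=64, n_head=4)  # small, illustrative config
block = WrappedBloomBlock(config)
hidden_states = torch.randn(1, 5, config.hidden_size)
output, present = block(hidden_states, use_cache=True)
key, value = present
head_dim = config.hidden_size // config.n_head
assert key.shape == (1 * config.n_head, head_dim, 5)    # [batch * num_heads, head_dim, seq_len]
assert value.shape == (1 * config.n_head, 5, head_dim)  # [batch * num_heads, seq_len, head_dim]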

@ -13,9 +13,10 @@ from typing import Optional, OrderedDict, Union
import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from transformers.modeling_utils import WEIGHTS_NAME
from transformers.utils.hub import cached_path, hf_bucket_url
from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.utils import get_file_from_repo
from petals.bloom import BloomBlock, BloomConfig
from petals.bloom.block import WrappedBloomBlock
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
use_hivemind_log_handler("in_root_logger")
@ -23,10 +24,6 @@ logger = get_logger(__file__)
CLIENT_BRANCH = "main"
BLOCK_BRANCH_PREFIX = "block_"
USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
FORCE_DOWNLOAD = False
RESUME_DOWNLOAD = False
LOCAL_FILES_ONLY = False
def load_pretrained_block(
@ -36,15 +33,15 @@ def load_pretrained_block(
torch_dtype: Union[torch.dtype, str] = "auto",
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None,
) -> BloomBlock:
"""Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
) -> WrappedBloomBlock:
"""Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
if config is None:
config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR
block = BloomBlock(config, layer_number=block_index)
block = WrappedBloomBlock(config)
state_dict = _load_state_dict(
converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
)
@ -70,20 +67,14 @@ def _load_state_dict(
cache_dir: Optional[str] = None,
) -> OrderedDict[str, torch.Tensor]:
revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
# Load from URL or cache if already cached
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=FORCE_DOWNLOAD,
proxies=None,
resume_download=RESUME_DOWNLOAD,
local_files_only=LOCAL_FILES_ONLY,
archive_file = get_file_from_repo(
pretrained_model_name_or_path,
filename=WEIGHTS_NAME,
revision=revision,
use_auth_token=use_auth_token,
user_agent=USER_AGENT,
cache_dir=cache_dir,
)
state_dict = torch.load(resolved_archive_file, map_location="cpu")
state_dict = torch.load(archive_file, map_location="cpu")
return state_dict

@ -1,595 +0,0 @@
"""
PyTorch BLOOM model that implements several memory-efficient modes.
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from hivemind import use_hivemind_log_handler
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from transformers.file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
SequenceClassifierOutputWithPast,
)
from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
from transformers.utils import logging
from petals.bloom.block import BloomBlock
use_hivemind_log_handler("in_root_logger")
logger = logging.get_logger(__file__)
_CHECKPOINT_FOR_DOC = "bigscience/Bloom"
_CONFIG_FOR_DOC = "BloomConfig"
_TOKENIZER_FOR_DOC = "BloomTokenizer"
BLOOM_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`MemoryEfficientBloomConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
BLOOM_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else
`past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
sequence tokens in the vocabulary.
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
`input_ids`.
Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
their past given to this model should not be passed as `input_ids` as they have already been computed.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
`past_key_values`).
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
class _BloomPreTrainedModelWithModifiedDefaults(BloomPreTrainedModel):
@classmethod
def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
"low_cpu_mem_usage(`bool`, *optional*)",
"low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
)
@add_start_docstrings(
"The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
BLOOM_START_DOCSTRING,
)
class BloomModel(_BloomPreTrainedModelWithModifiedDefaults):
def __init__(self, config):
super().__init__(config)
assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
self.embed_dim = config.hidden_size
self.n_head = config.n_head
# Embedding + LN Embedding
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
# Transformer blocks
self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.word_embeddings
def set_input_embeddings(self, new_embeddings):
self.word_embeddings = new_embeddings
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
if position_ids is not None:
logger.warning("position_ids are ignored in this bloom implementation")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_head x N x N
# head_mask has shape n_layer x batch x n_head x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
# Note: it supports only float32 or bfloat16 inputs
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
output_shape = input_shape + (hidden_states.size(-1),)
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# Compute alibi tensor: check build_alibi_tensor documentation
current_sequence_length = hidden_states.shape[1]
if past_key_values and past_key_values[0]:
current_sequence_length += past_key_values[0][0].shape[1]
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions, alibi=None)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=None,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states = hidden_states.view(output_shape)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
@add_start_docstrings(
"""
The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
embeddings).
""",
BLOOM_START_DOCSTRING,
)
class BloomForCausalLM(_BloomPreTrainedModelWithModifiedDefaults):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.transformer = BloomModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
}
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
@add_start_docstrings(
"""
The modified language modeling head, which does not create an extra tensor for the linear layer with weights tied to the
input embeddings. Thus, it reduces initial memory consumption, which might be crucial for large vocabularies.
In addition, it provides an efficient way to deal with half-precision word embeddings on CPU.
""",
BLOOM_START_DOCSTRING,
)
class LMHead(nn.Module):
def __init__(self, config, word_embeddings: nn.Embedding):
super().__init__()
self.word_embeddings = word_embeddings
self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
@property
def in_features(self) -> int:
return self.word_embeddings.num_embeddings
@property
def out_features(self) -> int:
return self.word_embeddings.embedding_dim
@property
def weight(self):
return self.word_embeddings.weight
@property
def bias(self):
return None
def forward(self, hidden_states):
word_embeddings = self.word_embeddings.weight
# We use 'chunked_forward' only when embeddings are in half-precision on CPU.
if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
lm_logits = self.chunked_forward(hidden_states)
else:
# Switch dtype in case word_embeddings are fp16/bf16
hidden_states = hidden_states.to(word_embeddings.dtype)
lm_logits = F.linear(hidden_states, word_embeddings)
return lm_logits
def chunked_forward(self, hidden_states):
"""Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
chunk_size: provides trade-off between efficiency and extra memory consumption.
"""
assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
word_embeddings = self.word_embeddings.weight
num_embeddings = self.word_embeddings.num_embeddings
hidden_states = hidden_states.float()
output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
for i in range(0, num_embeddings, self.chunk_size):
chunk = word_embeddings[i : i + self.chunk_size].float()
output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
return output
@add_start_docstrings(
"""
The Bloom Model transformer with a sequence classification head on top (linear layer).
[`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-1) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
BLOOM_START_DOCSTRING,
)
class BloomForSequenceClassification(_BloomPreTrainedModelWithModifiedDefaults):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = BloomModel(config)
self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

@ -0,0 +1,74 @@
"""
PyTorch BLOOM model that implements several memory-efficient modes.
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from hivemind import use_hivemind_log_handler
from torch import nn
from transformers import BloomConfig
from transformers.utils import logging
use_hivemind_log_handler("in_root_logger")
logger = logging.get_logger(__file__)
class LMHead(nn.Module):
"""
The modified language modeling head, which does not create an extra tensor for the linear layer with weights tied to the
input embeddings. Thus, it reduces initial memory consumption, which might be crucial for large vocabularies.
In addition, it provides an efficient way to deal with half-precision word embeddings on CPU.
"""
def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding):
super().__init__()
self.word_embeddings = word_embeddings
self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
@property
def in_features(self) -> int:
return self.word_embeddings.num_embeddings
@property
def out_features(self) -> int:
return self.word_embeddings.embedding_dim
@property
def weight(self):
return self.word_embeddings.weight
@property
def bias(self):
return None
def forward(self, hidden_states):
word_embeddings = self.word_embeddings.weight
# We use 'chunked_forward' only when embeddings are in half-precision on CPU.
if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
lm_logits = self.chunked_forward(hidden_states)
else:
# Switch dtype in case word_embeddings are fp16/bf16
hidden_states = hidden_states.to(word_embeddings.dtype)
lm_logits = F.linear(hidden_states, word_embeddings)
return lm_logits
def chunked_forward(self, hidden_states):
"""Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
chunk_size: provides trade-off between efficiency and extra memory consumption.
"""
assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
word_embeddings = self.word_embeddings.weight
num_embeddings = self.word_embeddings.num_embeddings
hidden_states = hidden_states.float()
output = torch.empty(*hidden_states.shape[:-1], num_embeddings)
for i in range(0, num_embeddings, self.chunk_size):
chunk = word_embeddings[i : i + self.chunk_size].float()
output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
return output
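
A small self-contained check of the chunked head (a plain BloomConfig does not define chunk_size_for_efficient_fp16_on_cpu, so it is set by hand here; in Petals this value normally comes from the distributed model config):

import torch
import torch.nn.functional as F
from transformers import BloomConfig
from petals.bloom.modeling_utils import LMHead

config = BloomConfig(hidden_size=32, vocab_size=128)
config.chunk_size_for_efficient_fp16_on_cpu = 16  # attribute read by LMHead.__init__
embeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size).half()  # fp16 weights on CPU
lm_head = LMHead(config, embeddings)

hidden_states = torch.randn(1, 4, config.hidden_size)
logits = lm_head(hidden_states)  # fp16-on-CPU embeddings dispatch to chunked_forward
reference = F.linear(hidden_states.float(), embeddings.weight.float())
assert logits.shape == (1, 4, config.vocab_size)
assert torch.allclose(logits, reference, atol=1e-3)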

@ -1,242 +0,0 @@
"""
Utility operations used in the BLOOM model
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
import math
import torch
import torch.autograd
import torch.nn.functional as F
from torch import nn
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
Args:
tensor: ([`torch.tensor`], *required*):
input tensor to split
num_partitions ([`int`], *required*):
number of partitions to split the tensor
contiguous_split_chunks ([`bool`], *optional*, default=`False`):
If True, make each chunk contiguous in memory.
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
numerator, denominator = tensor.size()[last_dim], num_partitions
if not (numerator % denominator == 0):
raise ValueError(f"{numerator} is not divisible by {denominator}")
last_dim_size = numerator // denominator
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
def attention_mask_func(attention_scores, attention_mask, causal_mask):
if attention_mask.dtype == torch.bool:
attention_mask_bool = ~attention_mask
else:
attention_mask_bool = (1 - attention_mask).bool()
query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
padded_causal_mask = (
attention_mask_bool[:, None, key_length - query_length : key_length, None]
+ ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
).bool()
padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
# Make use of floats
return (
attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
padded_causal_mask,
)
def build_alibi_tensor(
max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409. The alibi tensor is not causal as the original paper mentions; it
relies on the translation invariance of softmax for a quick implementation: for a tensor l and a fixed value a,
`softmax(l + a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
Args:
max_seq_len: (`int`, *required*):
max sequence length
n_head: (`int`, *required*):
number of heads
dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
device of the output alibi tensor
Returns:
a tensor shaped (n_head, 1, max_seq_len)
"""
closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != n_head:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
)
num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
"""
Pre-process the alibi tensor for padding.
Args:
alibi: ([`torch.tensor`], *required*):
alibi tensor to pre-process
attention_mask: ([`torch.tensor`], *required*):
attention mask to pre-process
"""
assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"
unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
# ^-- [batch, max_len], values correspond to element indices after removing padding
# We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)
def dropout_add(x, residual, prob, training):
"""
Dropout add function
Args:
x (`torch.tensor`, *required*):
input tensor
residual (`torch.tensor`, *required*):
residual tensor
prob (`float`, *required*):
dropout probability
training (`bool`, *required*):
training mode
"""
out = nn.functional.dropout(x, p=prob, training=training)
out = residual + out
return out
def bloom_gelu_forward(x):
"""
Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
make the model jitable.
Args:
x (`torch.tensor`, *required*):
input hidden states
"""
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
def bloom_gelu_back(g, x):
"""
Gradient of the tanh approximation of GELU. For reference, the gradient of the exact GELU is
0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x).
Args:
g (`torch.tensor`, *required*):
gradient output tensor
x (`torch.tensor`, *required*):
input tensor
"""
x = x[0] # x is a tuple of 1 element, needs to unpack it first
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff * g
class GeLUFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return bloom_gelu_forward(input)
@staticmethod
def backward(ctx, grad_output):
input = ctx.saved_tensors
tmp = bloom_gelu_back(grad_output, input)
return tmp
class BloomGelu(nn.Module):
"""
A wrapper that uses the simple forward function in inference mode to keep the model torchscriptable,
and the autograd function in training mode to get accurate gradients. Partly copied from
Megatron-DeepSpeed code and adapted for our needs.
See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
"""
def __init__(self):
super().__init__()
def forward(self, x):
if self.training:
return GeLUFunction.apply(x)
else:
return bloom_gelu_forward(x)
class BloomScaledSoftmax(nn.Module):
"""
fused operation: scaling + mask + softmax
Args:
scaled_masked_softmax_fusion (`bool`, *required*):
flag to indicate whether the user wants to use softmax fusion
mask_func (`function`, *required*):
mask function to be applied.
softmax_in_fp32 (`bool`, *required*):
if true, softmax is performed at fp32 precision.
scale (`float`, *required*):
scaling factor used in input tensor scaling.
"""
def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
super().__init__()
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
if not (self.scale is None or softmax_in_fp32):
raise ValueError("softmax should be in fp32 when scaled")
def forward(self, input, mask, max_positions):
input_dtype = input.dtype
input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype
if self.scale is not None:
input = input * self.scale
if mask is None:
mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
mask = mask.to(input.device)
causal_mask = (
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
.view(1, 1, max_positions, max_positions)
.to(input.device)
)
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
if input_in_16bit and self.softmax_in_fp32:
probs = probs.to(dtype=input_dtype)
return probs
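
For reference, the head-slope schedule computed above (and reused by build_alibi_tensor in transformers 4.25) is a simple geometric sequence; a quick check for eight heads:

import math
import torch

n_head = 8
closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))  # equals 2 ** (-8 / n_head) for power-of-two head counts
slopes = torch.pow(torch.tensor(base), torch.arange(1, 1 + closest_power_of_2))
expected = torch.tensor([2.0 ** -(i + 1) for i in range(n_head)])  # 2^-1, 2^-2, ..., 2^-8
assert torch.allclose(slopes.float(), expected)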

@ -8,8 +8,8 @@ import transformers
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from huggingface_hub import Repository
from tqdm.auto import tqdm
from transformers.models.bloom.modeling_bloom import BloomModel
from petals.bloom import BloomModel
from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
from petals.client import DistributedBloomConfig

@ -3,10 +3,10 @@ import argparse
import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tqdm.auto import trange
from transformers import BloomConfig
from transformers.models.bloom.modeling_bloom import build_alibi_tensor
from petals.bloom.block import BloomBlock
from petals.bloom.model import BloomConfig
from petals.bloom.ops import build_alibi_tensor
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@ -31,7 +31,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
parser.add_argument("--layer_index", default=0, type=int, help="Optional path to saved block state dict")
parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
args = parser.parse_args()
@ -40,7 +39,7 @@ if __name__ == "__main__":
args.device = "cuda" if torch.cuda.is_available() else "cpu"
config = BloomConfig.from_json_file(args.config)
block = BloomBlock(config, args.layer_index).to(args.device)
block = BloomBlock(config).to(args.device)
cache = None

@ -7,15 +7,15 @@ import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger, loglevel, use_hivemind_log_handler
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from petals.bloom.model import (
from transformers.models.bloom import (
BloomConfig,
BloomForCausalLM,
BloomForSequenceClassification,
BloomModel,
BloomPreTrainedModel,
LMHead,
)
from petals.bloom.modeling_utils import LMHead
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_sequential import RemoteSequential
from petals.constants import PUBLIC_INITIAL_PEERS
@ -66,7 +66,20 @@ def force_non_empty_weights():
nn.Module.register_parameter = possibly_patched_register_parameter
class DistributedBloomModel(BloomModel):
class _LowCPUMemoryMixin:
@classmethod
def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
"low_cpu_mem_usage(`bool`, *optional*)",
"low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
)
class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
"""BloomModel, but all transformer layers are hosted by the swarm"""
_keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
@ -192,7 +205,7 @@ class DistributedBloomModel(BloomModel):
)
class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, BloomForCausalLM):
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
_keys_to_ignore_on_load_missing = (
@ -230,7 +243,7 @@ class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
self.lm_head.bias[...] = new_lm_head.bias
class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequenceClassification):
_keys_to_ignore_on_load_missing = (
BloomForSequenceClassification._keys_to_ignore_on_load_missing
+ DistributedBloomModel._keys_to_ignore_on_load_missing
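
A toy, self-contained illustration of the _LowCPUMemoryMixin pattern above (the _Base, _LowCPUMemoryMixinDemo, and _Model names are made up for this sketch): the mixin intercepts from_pretrained and fills in low_cpu_mem_usage=True unless the caller overrides it.

from typing import Optional

class _Base:
    @classmethod
    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
        return low_cpu_mem_usage  # stands in for the real transformers loader

class _LowCPUMemoryMixinDemo:
    @classmethod
    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True
        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)

class _Model(_LowCPUMemoryMixinDemo, _Base):
    pass

assert _Model.from_pretrained() is True                          # default flipped to True
assert _Model.from_pretrained(low_cpu_mem_usage=False) is False  # explicit override still honored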

@ -57,7 +57,7 @@ class RemoteSequenceManager:
update_period: float = 30,
request_timeout: float = 30,
min_backoff: float = 1,
ban_timeout: float = 60,
ban_timeout: float = 15,
sequence_info: Optional[RemoteSequenceInfo] = None,
rpc_info: Optional[dict] = None,
banned_peers: Optional[Blacklist] = None,

@ -6,7 +6,7 @@ from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
from petals.bloom.from_pretrained import BloomBlock
from petals.bloom.block import WrappedBloomBlock
from petals.server.memory_cache import MemoryCache
from petals.server.task_pool import PrioritizedTaskPool
from petals.utils.misc import is_dummy
@ -16,11 +16,11 @@ logger = get_logger(__file__)
class TransformerBackend(ModuleBackend):
"""A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
"""A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference"""
def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
super().__init__(*args, **kwargs)
assert isinstance(self.module, BloomBlock)
assert isinstance(self.module, WrappedBloomBlock)
self.memory_cache = memory_cache
for name, param in self.module.named_parameters():
assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
@ -50,6 +50,7 @@ class TransformerBackend(ModuleBackend):
)
def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
num_heads, head_dim = self.module.self_attention.num_heads, self.module.self_attention.head_dim
with torch.inference_mode():
attention_cache_handle = int(cache_metadata[0, 0].item())
prefix_length = int(cache_metadata[0, 1].item())
@ -59,24 +60,31 @@ class TransformerBackend(ModuleBackend):
), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
with self.memory_cache.use_cache(attention_cache_handle) as cache:
assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
batch_size = cache.shape[1]
max_length = cache.numel() // (2 * batch_size * head_dim * num_heads)
assert isinstance(self.module, WrappedBloomBlock) and cache.shape[0] == 2 and cache.ndim == 3
if not is_dummy(hypo_ids):
assert hypo_ids.shape[0] == cache.shape[1]
cache[:, :] = cache[:, hypo_ids] # in-place reorder cache by hypo ids
layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
logger.debug(f"Metadata: {cache_metadata}, past_k.shape={past_k.shape}, past_v.shape={past_v.shape}")
hidden_states, (new_k, new_v) = self.module.forward(
hidden_states, layer_past=layer_past, use_cache=True
)
key_cache = cache[0].view(batch_size, num_heads, head_dim, max_length)
value_cache = cache[1].view(batch_size, num_heads, max_length, head_dim)
# todo remove these asserts once we pass all tests
new_length = new_v.shape[1]
key_past = key_cache.flatten(0, 1)[:, :, :prefix_length] # [batch * num_heads, head_dim, kv_length]
value_past = value_cache.flatten(0, 1)[:, :prefix_length, :] # [batch * num_heads, kv_length, head_dim]
logger.debug(
f"Metadata: {cache_metadata}, past_k.shape={key_past.shape}, past_v.shape={value_past.shape}"
)
hidden_states, (new_key, new_value) = self.module.forward(
hidden_states, layer_past=(key_past, value_past), use_cache=True
)
new_length = new_key.shape[-1]
assert new_length > prefix_length
assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
assert new_key.shape[0] == key_past.shape[0] and new_value.shape[0] == value_past.shape[0]
assert new_key.shape[-1] == new_length and new_value.shape[-2] == new_length
new_key = new_key.view(batch_size, num_heads, head_dim, -1)
new_value = new_value.view(batch_size, num_heads, -1, head_dim)
key_cache[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
value_cache[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
return (hidden_states,)
def get_pools(self) -> Sequence[PrioritizedTaskPool]:
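
A standalone sketch of the new cache bookkeeping in inference_step (shapes are illustrative; the flat layout matches the allocate_cache descriptor added in handler.py below):

import torch

batch_size, num_heads, head_dim, max_length, prefix_length = 2, 4, 8, 16, 5

# One flattened, batch-first tensor stores both keys (index 0) and values (index 1)
cache = torch.zeros(2, batch_size, num_heads * head_dim * max_length)

key_cache = cache[0].view(batch_size, num_heads, head_dim, max_length)
value_cache = cache[1].view(batch_size, num_heads, max_length, head_dim)

# layer_past layout expected by transformers 4.25 BLOOM blocks:
key_past = key_cache.flatten(0, 1)[:, :, :prefix_length]      # [batch * num_heads, head_dim, kv_length]
value_past = value_cache.flatten(0, 1)[:, :prefix_length, :]  # [batch * num_heads, kv_length, head_dim]
assert key_past.shape == (batch_size * num_heads, head_dim, prefix_length)
assert value_past.shape == (batch_size * num_heads, prefix_length, head_dim)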

@ -2,8 +2,9 @@ from typing import Optional, Union
import torch
from accelerate import init_empty_weights
from transformers import BloomConfig
from petals.bloom import BloomBlock, BloomConfig
from petals.bloom.block import WrappedBloomBlock
def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
@ -22,7 +23,6 @@ def get_block_size(
*,
dtype: Optional[Union[str, torch.dtype]] = None,
load_in_8bit: Optional[bool] = None,
layer_index: int = 0,
eps: float = 0.01, # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc.
) -> int:
if location == "memory":
@ -31,7 +31,7 @@ def get_block_size(
), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'
with init_empty_weights():
block = BloomBlock(config, layer_index)
block = WrappedBloomBlock(config)
n_params = sum(param.numel() for param in block.parameters())
if location == "memory" and load_in_8bit:

@ -146,6 +146,9 @@ class TransformerConnectionHandler(ConnectionHandler):
for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
if hidden_states.numel() == 0:
continue # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
assert isinstance(
@ -343,10 +346,8 @@ class TransformerConnectionHandler(ConnectionHandler):
for backend in backends:
num_heads = backend.module.self_attention.num_heads
head_dim = backend.module.self_attention.head_dim
descr = TensorDescriptor(size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype)
# [key_or_value, batch_size, max_length, num_heads, head_dim]
descr = TensorDescriptor(size=(2, batch_size, num_heads * head_dim * max_length), dtype=backend.dtype)
# ^-- flattened batch-first tensor of both keys and values; based on BLOOM layer_past layout
handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(descr)))
total_size += descr.numel() * torch.finfo(descr.dtype).bits // 8
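
As a worked example of the per-block cache size this descriptor implies (assuming BLOOM-176B-like attention geometry, num_heads = 112 and head_dim = 128, with a bfloat16 backend; the numbers are illustrative, not taken from this diff):

import torch

batch_size, max_length = 1, 2048
num_heads, head_dim = 112, 128  # assumed BLOOM-176B-like values
numel = 2 * batch_size * num_heads * head_dim * max_length  # keys and values together
bytes_per_block = numel * torch.finfo(torch.bfloat16).bits // 8
print(bytes_per_block / 2**20)  # ~112 MiB of attention cache per block for one 2048-token session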

@ -16,9 +16,9 @@ from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from transformers import BloomConfig
from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
from petals.bloom.model import BloomConfig
from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
from petals.dht_utils import declare_active_modules, get_remote_module_infos

@ -9,10 +9,9 @@ from typing import Optional, Union
import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from transformers import BloomConfig
from petals.bloom.block import BloomBlock
from petals.bloom.model import BloomConfig
from petals.bloom.ops import build_alibi_tensor
from petals.bloom.block import WrappedBloomBlock
from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_8bit import replace_8bit_linear
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@ -115,10 +114,9 @@ def measure_compute_rps(
load_in_8bit: bool,
n_tokens: int = 16,
n_steps: int = 500,
layer_index: int = 0,
) -> float:
with torch.inference_mode():
block = BloomBlock(config, layer_index).to(dtype)
block = WrappedBloomBlock(config).to(dtype)
if load_in_8bit:
block = replace_8bit_linear(block)
block = block.to(device)
@ -127,10 +125,9 @@ def measure_compute_rps(
elapsed = 0
for step in range(n_steps + 1):
dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=dtype)
start_time = time.perf_counter()
_, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
_, cache = block.forward(dummy_input, use_cache=True, layer_past=cache)
if step >= 1: # Skip the 1st step to exclude the initialization time
elapsed += time.perf_counter() - start_time
device_rps = n_steps * n_tokens / elapsed

@ -0,0 +1,15 @@
import pytest
import torch
from test_utils import MODEL_NAME
from petals.client import DistributedBloomConfig
from petals.server.throughput import measure_compute_rps
@pytest.mark.forked
def test_throughput_basic():
config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
throughput = measure_compute_rps(
config, device=torch.device("cpu"), dtype=torch.bfloat16, load_in_8bit=False, n_steps=10
)
assert isinstance(throughput, float) and throughput > 0

@ -3,9 +3,9 @@ import torch
import transformers
from hivemind import get_logger, use_hivemind_log_handler
from test_utils import *
from transformers.generation_utils import BeamSearchScorer
from transformers.generation import BeamSearchScorer
from transformers.models.bloom import BloomForCausalLM
from petals.bloom.model import BloomForCausalLM
from petals.client.remote_model import DistributedBloomForCausalLM
use_hivemind_log_handler("in_root_logger")
@ -13,7 +13,8 @@ logger = get_logger(__file__)
@pytest.mark.forked
def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
@ -33,8 +34,15 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
embs = model.transformer.word_embeddings_layernorm(embs)
recurrent_outputs = []
with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
if pass_empty_tensors:
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
for t in range(embs.shape[1]):
recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
if t == int(embs.shape[1] // 2) and pass_empty_tensors:
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
recurrent_outputs = model.lm_head(recurrent_outputs)
