WIP: make DistributedBloom compliant with HF interface

standardize
justheuristic 2 years ago
parent 6c437c9249
commit 4695071ad2

@@ -4,6 +4,8 @@ Early dev prototype for decentralized bloom. Not for public eyes **yet**.
Roadmap: [issue #12](https://github.com/learning-at-home/bloom-demo/issues/12)
Latest news @ main branch (max 5):
- [Jul 4] @dbaranchuk implemented chained rpc_forward and rpc_backward (for prompt tuning)
- [Jul 3] @dbaranchuk optimized DistributedBloom to reduce embeddings/logits RAM usage
- [Jul 1] @yozh added RemoteSequential and test for full model exact match
- [June 28] @dbaranchuk added quick deployment scripts for testnet

@@ -5,7 +5,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tqdm.auto import trange
from src.bloom.block import BloomBlock
from src.bloom.model import DistributedBloomConfig
from src.bloom.model import BloomConfig
from src.bloom.ops import build_alibi_tensor
use_hivemind_log_handler("in_root_logger")
@@ -39,7 +39,7 @@ if __name__ == "__main__":
if args.device is None:
args.device = "cuda" if torch.cuda.is_available() else "cpu"
config = DistributedBloomConfig.from_json_file(args.config)
config = BloomConfig.from_json_file(args.config)
block = BloomBlock(config, args.layer_index).to(args.device)
cache = None

@@ -43,16 +43,8 @@ class BloomAttention(nn.Module):
self.layer_number,
)
if config.compression == "qint8":
self.query_key_value = nn.quantized.dynamic.modules.Linear(
self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8
)
self.dense = nn.quantized.dynamic.modules.Linear(
self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
)
else:
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
self.attention_dropout = nn.Dropout(config.attention_dropout)
@@ -173,16 +165,8 @@ class BloomMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
if config.compression == "qint8":
self.dense_h_to_4h = nn.quantized.dynamic.modules.Linear(
self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8
)
self.dense_4h_to_h = nn.quantized.dynamic.modules.Linear(
4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
)
else:
self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
self.hidden_dropout = config.hidden_dropout
self.gelu_impl = BloomGelu()
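The two hunks above drop the inline `config.compression == "qint8"` branches from `BloomAttention` and `BloomMLP`, leaving plain `nn.Linear` layers. If int8 compression is still wanted, a roughly equivalent effect can be obtained after a block is built. A minimal sketch (not part of this commit; `quantize_block_linears` is a hypothetical helper) using PyTorch dynamic quantization:

```python
import torch
from torch import nn

def quantize_block_linears(block: nn.Module) -> nn.Module:
    # Hypothetical post-hoc alternative to the removed "qint8" branches:
    # every nn.Linear inside the block is replaced by a dynamically quantized
    # version (int8 weights, activations quantized on the fly at inference).
    return torch.quantization.quantize_dynamic(block, {nn.Linear}, dtype=torch.qint8)

# e.g.: block = quantize_block_linears(BloomBlock(config, layer_index).eval())
```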

@@ -3,8 +3,10 @@ PyTorch BLOOM model that implements several memory-efficient modes.
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
from typing import Tuple
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from hivemind import use_hivemind_log_handler
from torch import nn
@@ -13,25 +15,19 @@ from transformers.file_utils import (add_code_sample_docstrings, add_start_docst
add_start_docstrings_to_model_forward)
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.utils import logging
from src.bloom.block import BloomBlock
from src.bloom.ops import build_alibi_tensor
use_hivemind_log_handler("in_root_logger")
logger = logging.get_logger(__file__)
_CHECKPOINT_FOR_DOC = "bigscience/Bloom"
_CONFIG_FOR_DOC = "DistributedBloomConfig"
_CONFIG_FOR_DOC = "BloomConfig"
_TOKENIZER_FOR_DOC = "BloomTokenizer"
class DistributedBloomConfig(_VanillaBloomConfig):
compression: str = "none"
slow_but_exact: bool = False
class BloomPreTrainedModel(PreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
"""
@@ -39,7 +35,7 @@ class BloomPreTrainedModel(PreTrainedModel):
models.
"""
config_class = DistributedBloomConfig
config_class = BloomConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["BloomBlock"]
@@ -312,17 +308,107 @@ class BloomModel(BloomPreTrainedModel):
@add_start_docstrings(
"""
The Bloom interface for various applications, e.g., inference, classification...
The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
embeddings).
""",
BLOOM_START_DOCSTRING,
)
class BloomForYou(BloomPreTrainedModel):
class BloomForCausalLM(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.transformer = BloomModel(config)
self.lm_head = None
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.transformer.word_embeddings
def set_output_embeddings(self, new_embeddings):
self.transformer.word_embeddings.weight = new_embeddings.weight
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
# only the last token of input_ids is needed if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
}
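For reference, a tiny worked example of the `position_ids` construction above, with an illustrative left-padded attention mask (values are not from the source):

```python
import torch

# One left-padded row: two padding tokens followed by three real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1  # tensor([[-1, -1, 0, 1, 2]])
position_ids.masked_fill_(attention_mask == 0, 1)    # tensor([[ 1,  1, 0, 1, 2]]); padded slots get a dummy index
```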
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=CausalLMOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
labels=None,
return_dict=None,
**kwargs
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)
word_embeddings = self.transformer.word_embeddings.weight
# Switch dtype in case word_embeddings are fp16/bf16
hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
lm_logits = F.linear(hidden_states, word_embeddings).float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
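With `prepare_inputs_for_generation` and `_reorder_cache` defined, `BloomForCausalLM` plugs into the standard Hugging Face generation loop. A hedged usage sketch, where the checkpoint name is a placeholder and `AutoTokenizer` is assumed rather than taken from this commit:

```python
from transformers import AutoTokenizer

# Placeholder checkpoint name; any BLOOM checkpoint with tied input embeddings would do.
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

inputs = tokenizer("A cat sat on", return_tensors="pt")
# num_beams > 1 exercises _reorder_cache; caching exercises prepare_inputs_for_generation.
outputs = model.generate(**inputs, max_new_tokens=8, num_beams=2)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```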

@@ -1,15 +1,13 @@
# this code is in active development, interfaces may change
import os
from typing import Optional, Tuple, Union
from typing import Optional, Union, Tuple
import hivemind
import torch
from hivemind import DHT, get_logger, use_hivemind_log_handler
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from src.bloom import BloomForYou, DistributedBloomConfig
from src.bloom.model import BloomModel, BloomForCausalLM, BloomConfig
from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
from src.bloom.model import BloomPreTrainedModel
from src.client.remote_sequential import RemoteSequential
from src.data_structures import UID_DELIMITER
@@ -17,111 +15,38 @@ use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
class DistributedBloomForYou(BloomForYou):
class DistributedBloomConfig(BloomConfig):
"""
A bloom config that contains information about DHT peers.
To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
"""
initial_peers: Tuple[str, ...] = () # a list of initial peers for hivemind DHT
dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models
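For illustration, a minimal sketch of how these fields fit together; the model name, peer multiaddress, and `my_running_hivemind_dht` below are placeholders, not taken from this commit:

```python
# Placeholder model name and peer address, for illustration only.
config = DistributedBloomConfig.from_pretrained("bigscience/test-bloomd")
config.dht_prefix = "bigscience/test-bloomd"
config.initial_peers = ("/ip4/127.0.0.1/tcp/31337/p2p/QmPlaceholderPeerID",)
# ...or reuse an existing DHT instead: config.dht = my_running_hivemind_dht

model = DistributedBloomModel(config)  # transformer blocks are served by remote peers via RemoteSequential
```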
class DistributedBloomModel(BloomModel):
"""BloomModel, but all transformer layers are hosted by the swarm"""
def __init__(self, config: DistributedBloomConfig):
# note: self.config is only set by super().__init__, so validate the raw config here
assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization
super().__init__(config)
assert len(self.transformer.h) == 0
assert len(self.h) == 0
config.n_layer = n_layer
self.transformer.h = RemoteSequential(config, dht, prefix)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
if "initial_peers" not in kwargs:
raise ValueError("Please specify initial_peers=...")
dht = hivemind.DHT(
initial_peers=kwargs.pop("initial_peers"), client_mode=kwargs.pop("client_mode", True), start=True
)
if "prefix" not in kwargs:
logger.debug(f"No DHT prefix specified; using automatic prefix {pretrained_model_name_or_path}")
assert (
UID_DELIMITER not in pretrained_model_name_or_path
), f"Cannot infer prefix automatically from {pretrained_model_name_or_path}; please specify prefix=..."
prefix = kwargs.pop("prefix", pretrained_model_name_or_path)
dht = config.dht if config.dht is not None else hivemind.DHT(
initial_peers=config.initial_peers, client_mode=True, start=True)
assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
self.h = RemoteSequential(config, dht, config.dht_prefix)
config = DistributedBloomConfig.from_pretrained(pretrained_model_name_or_path, revision=CLIENT_BRANCH, **kwargs)
model = cls(config, dht, prefix)
model.transformer.load_state_dict(
_load_state_dict(pretrained_model_name_or_path, use_auth_token=kwargs.get("use_auth_token")), strict=True
)
return model
class DistributedBloomForCausalLM(DistributedBloomForYou):
class DistributedBloomForCausalLM(BloomForCausalLM):
"""Like BloomForCausalLM, but all transformer layers are hosted by the swarm"""
def __init__(self, config: DistributedBloomConfig):
BloomPreTrainedModel.__init__(self, config)  # skip BloomForCausalLM.__init__, which would build a local BloomModel
self.transformer = DistributedBloomModel(config)
# Initialize weights and apply final processing
self.post_init()
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
}
def forward(self, input_ids, labels=None, return_dict=None, **kwargs):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)
# Switch dtype in case word_embeddings are fp16
word_embeddings = self.transformer.word_embeddings.weight.t()
hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
lm_logits = (hidden_states @ word_embeddings).float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
