Merge remote-tracking branch 'origin/main' into optimize_seq

optimize_seq
justheuristic 2 years ago
commit a5fdb5753e

@@ -23,6 +23,7 @@ from transformers.modeling_outputs import (
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
from transformers.utils import logging
from src.bloom.block import BloomBlock
@@ -35,42 +36,6 @@ _CONFIG_FOR_DOC = "BloomConfig"
_TOKENIZER_FOR_DOC = "BloomTokenizer"
class BloomPreTrainedModel(PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BloomConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BloomBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, BloomModel):
            module.gradient_checkpointing = value
BLOOM_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
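The hunks above drop the local copy of BloomPreTrainedModel in favor of the class imported from transformers.models.bloom.modeling_bloom. Below is a minimal sketch of what downstream code still gets from the upstream class; ToyBloomHead is a hypothetical name, and only post_init() plus the inherited _init_weights behavior shown above are assumed.

import torch.nn as nn
from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel


class ToyBloomHead(BloomPreTrainedModel):
    """Hypothetical subclass; inherits _init_weights and the checkpointing hooks from upstream."""

    def __init__(self, config: BloomConfig):
        super().__init__(config)
        self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
        # post_init() applies the inherited _init_weights: normal(0, initializer_range)
        # for Linear weights and zeros for biases.
        self.post_init()


config = BloomConfig(hidden_size=64, n_layer=2, n_head=4)
model = ToyBloomHead(config)
assert float(model.proj.bias.abs().sum()) == 0.0  # bias is zero-initialized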

@@ -187,8 +187,11 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        super().__init__(config)
        BloomPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.transformer = DistributedBloomModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
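The hunk above appears to swap super().__init__(config) for a direct BloomPreTrainedModel.__init__(self, config) call, presumably so that BloomForSequenceClassification.__init__ never builds a full local BloomModel before the distributed backbone is attached. A generic sketch of that grandparent-init pattern; FullModel and DistributedHead are made-up stand-ins, not Petals classes.

import torch.nn as nn


class FullModel(nn.Module):
    """Stand-in for BloomForSequenceClassification: its __init__ builds a heavy local backbone."""

    def __init__(self, hidden_size: int):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=4)
        self.backbone = nn.TransformerEncoder(layer, num_layers=2)  # expensive in the real case


class DistributedHead(FullModel):
    """Stand-in for DistributedBloomForSequenceClassification."""

    def __init__(self, hidden_size: int, num_labels: int):
        # Call the grandparent initializer directly so FullModel.__init__ never runs
        # and the local backbone is never allocated.
        nn.Module.__init__(self)
        self.score = nn.Linear(hidden_size, num_labels, bias=False)


head = DistributedHead(hidden_size=64, num_labels=2)
assert not hasattr(head, "backbone") and hasattr(head, "score")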

@@ -81,8 +81,6 @@ class TransformerBackend(ModuleBackend):
        assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
        assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
        assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
        assert torch.allclose(new_v[:, : past_v.shape[1]], past_v)
        assert torch.allclose(new_k[:, : past_k.shape[1]], past_k)
        cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
        cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
        return (hidden_states,)
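For readers unfamiliar with the cache layout, here is a toy illustration of the two write lines above. The [2, batch, length, head_dim] shape, with cache[0] holding keys and cache[1] holding values, is an assumption inferred from the indexing rather than taken from the source.

import torch

batch, max_length, head_dim = 2, 8, 4
cache = torch.zeros(2, batch, max_length, head_dim)  # assumed layout: cache[0]=keys, cache[1]=values

prefix_length, new_length = 3, 5                  # 3 tokens already cached, 2 freshly computed
new_k = torch.randn(batch, new_length, head_dim)  # keys for all tokens seen so far
new_v = torch.randn(batch, new_length, head_dim)

# Only the freshly computed positions are written; the cached prefix is left untouched.
cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]

assert torch.equal(cache[0, :, prefix_length:new_length], new_k[:, prefix_length:new_length])
assert torch.equal(cache[0, :, :prefix_length], torch.zeros(batch, prefix_length, head_dim))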

@@ -4,14 +4,13 @@ import torch
def replace_8bit_linear(model, threshold=6.0):
    """
    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
    A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
    library. This enables running your models in mixed int8 precision as described in the paper `GPT3.int8():
    8-bit Matrix Multiplication for Transformers at Scale`. Make sure that `bitsandbytes` compiled for your hardware's
    CUDA version is installed before running this function: `pip install -i https://test.pypi.org/simple/
    bitsandbytes-cudaXXX`, where `XXX` is your CUDA version (e.g., 11.6 = 116).
    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
    be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
    CPU/GPU memory is required to run this function.
    The function is run recursively and replaces all `torch.nn.Linear` modules except for `lm_head` and `score`, which
    should be kept as regular `torch.nn.Linear` modules.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
@@ -23,12 +22,15 @@ def replace_8bit_linear(model, threshold=6.0):
        if len(list(module.children())) > 0:
            replace_8bit_linear(module, threshold)

        if isinstance(module, torch.nn.Linear) and n != "lm_head":
        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
            model._modules[n] = bnb.nn.Linear8bitLt(
                module.in_features,
                module.out_features,
                module.bias is not None,
                has_fp16_weights=False,
                threshold=threshold,
            ).to(model._modules[n].weight.device)
            )
            model._modules[n].weight = bnb.nn.Int8Params(
                module.weight.data, requires_grad=False, has_fp16_weights=False
            ).to(module.weight.dtype)
    return model
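A hedged usage sketch of replace_8bit_linear as modified above. It assumes a CUDA device, a matching bitsandbytes build, and that the function is importable from this file; the toy nn.ModuleDict layout and module names are illustrative only.

import torch.nn as nn

# Assumes `replace_8bit_linear` (defined above) is in scope and bitsandbytes is installed.
model = nn.ModuleDict(
    {
        "dense": nn.Linear(1024, 1024),
        "lm_head": nn.Linear(1024, 50257),
        "score": nn.Linear(1024, 2, bias=False),
    }
).half().cuda()

model = replace_8bit_linear(model, threshold=6.0)

# The inner projection is expected to become bnb.nn.Linear8bitLt, while the
# `lm_head` and `score` heads are skipped and stay as fp16 nn.Linear modules.
print(type(model["dense"]).__name__)    # expected: Linear8bitLt
print(type(model["lm_head"]).__name__)  # expected: Linear
print(type(model["score"]).__name__)    # expected: Linear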
