diff --git a/src/bloom/model.py b/src/bloom/model.py
index 5d6afdb..a5c7d9e 100644
--- a/src/bloom/model.py
+++ b/src/bloom/model.py
@@ -23,6 +23,7 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig
+from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
 from transformers.utils import logging
 
 from src.bloom.block import BloomBlock
@@ -35,42 +36,6 @@ _CONFIG_FOR_DOC = "BloomConfig"
 _TOKENIZER_FOR_DOC = "BloomTokenizer"
 
 
-class BloomPreTrainedModel(PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
-    """
-    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
-    models.
-    """
-
-    config_class = BloomConfig
-    base_model_prefix = "transformer"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["BloomBlock"]
-
-    def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
-
-    def _init_weights(self, module):
-        """Initialize the weights."""
-        if isinstance(module, (nn.Linear)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, BloomModel):
-            module.gradient_checkpointing = value
-
-
 BLOOM_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the