@@ -1,4 +1,4 @@
# this code is in active development, interfaces may change
from contextlib import contextmanager
from typing import List, Optional
import hivemind
@@ -38,9 +38,35 @@ class DistributedBloomConfig(BloomConfig):
    tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
original_register_parameter = nn.Module.register_parameter


@contextmanager
def force_non_empty_weights():
    """
    This context manager allows bypassing the accelerate.init_empty_weights() context manager
    (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
    The transformers library should replace all meta tensors with empty tensors by itself,
    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).

    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
    """
    try:
        possibly_patched_register_parameter = nn.Module.register_parameter
        nn.Module.register_parameter = original_register_parameter
        yield
    finally:
        nn.Module.register_parameter = possibly_patched_register_parameter
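
# Illustrative usage sketch (not part of this module; `init_empty_weights` comes from accelerate):
# materialize selected submodules even while the surrounding code runs under init_empty_weights().
#
#     from accelerate import init_empty_weights
#
#     with init_empty_weights():               # nn.Parameters become meta tensors here
#         with force_non_empty_weights():
#             emb = nn.Embedding(16, 1024)     # emb.weight is a real (uninitialized) CPU tensor
#         other = nn.Linear(1024, 1024)        # other.weight is a meta tensor again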


class DistributedBloomModel(BloomModel):
    """BloomModel, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
        r"^(intermediate_)?prompt_embeddings\.weight$",
    ]

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
@@ -66,16 +92,22 @@ class DistributedBloomModel(BloomModel):
        if config.tuning_mode and "ptune" in config.tuning_mode:
            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
            self.pre_seq_len = config.pre_seq_len
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()

            with force_non_empty_weights():
                if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
                    logger.info(
                        "Prompt embeddings and their optimizer statistics will be kept in float32 "
                        "to increase ptune quality"
                    )
                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
                if config.tuning_mode == "deep_ptune":
                    self.intermediate_prompt_embeddings = nn.Embedding(
                        self.pre_seq_len,
                        config.num_hidden_layers * config.hidden_size,
                        # ^-- TODO: should be num_hidden_layers - 1
                        dtype=torch.float32,
                    )
                    self.intermediate_prompt_embeddings.weight.data.zero_()
        elif config.tuning_mode:
            raise NotImplementedError(f"{config.tuning_mode} mode is not supported for now")
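
        # Illustrative configuration for the ptune branch above (the model name and values are
        # assumptions, not taken from this file):
        #
        #     config = DistributedBloomConfig.from_pretrained("bigscience/bloom-petals")
        #     config.tuning_mode, config.pre_seq_len = "deep_ptune", 16
        #     model = DistributedBloomModel(config)  # prompt embeddings are kept in float32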
@@ -96,7 +128,9 @@ class DistributedBloomModel(BloomModel):
            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
        else:
            intermediate_prompts = DUMMY

        dtype = self.word_embeddings.weight.dtype
        return prompts.to(dtype), intermediate_prompts.to(dtype)
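
    # Shape note for get_prompt (illustrative, inferred from the "deep_ptune" branch above):
    #   prompts:              (batch_size, pre_seq_len, hidden_size)
    #   intermediate_prompts: (num_hidden_layers, batch_size, pre_seq_len, hidden_size)
    # Both are cast to the word-embedding dtype so they match the dtype used by the remote layers.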
    def forward(
        self,
@@ -155,6 +189,12 @@ class DistributedBloomModel(BloomModel):
class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""

    _keys_to_ignore_on_load_missing = (
        BloomForCausalLM._keys_to_ignore_on_load_missing
        + DistributedBloomModel._keys_to_ignore_on_load_missing
        + [r"^lm_head.word_embeddings\.weight$"]  # Missing since they are shared with input embeddings
    )
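    # Illustrative check of the tied weights (the attribute layout is an assumption, not shown in
    # this file): the LM head reuses the input embedding matrix, so the checkpoint stores no
    # separate `lm_head.word_embeddings.weight` entry.
    #
    #     model = DistributedBloomForCausalLM(config)
    #     assert model.lm_head.word_embeddings.weight is model.transformer.word_embeddings.weight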
    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
@@ -185,6 +225,11 @@ class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
    _keys_to_ignore_on_load_missing = (
        BloomForSequenceClassification._keys_to_ignore_on_load_missing
        + DistributedBloomModel._keys_to_ignore_on_load_missing
    )
    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):