|
|
|
@ -1,4 +1,4 @@
|
|
|
|
|
# this code is in active development, interfaces may change
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
from typing import List, Optional
|
|
|
|
|
|
|
|
|
|
import hivemind
|
|
|
|
@ -38,9 +38,35 @@ class DistributedBloomConfig(BloomConfig):
|
|
|
|
|
tuning_mode: Optional[str] = None # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Reference to nn.Module.register_parameter captured when this module is imported.
# NOTE(review): presumably this is the unpatched PyTorch implementation, i.e. the module
# is imported outside of any accelerate.init_empty_weights() context -- confirm.
original_register_parameter = nn.Module.register_parameter


@contextmanager
def force_non_empty_weights():
    """
    Temporarily bypass the accelerate.init_empty_weights() context manager
    (which forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
    The transformers library should replace all meta tensors with empty tensors by itself,
    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).

    While the context is active, nn.Module.register_parameter is set to the implementation
    saved in `original_register_parameter`; on exit (normal or via an exception) whatever
    implementation was installed at entry is put back.

    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
    """

    # Snapshot and swap *before* entering `try`, so that `finally` only ever restores
    # a value that was actually captured (never an unbound local).
    possibly_patched_register_parameter = nn.Module.register_parameter
    nn.Module.register_parameter = original_register_parameter
    try:
        yield
    finally:
        # Re-install whatever was active at entry (possibly accelerate's patched version)
        nn.Module.register_parameter = possibly_patched_register_parameter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedBloomModel(BloomModel):
|
|
|
|
|
"""BloomModel, but all transformer layers are hosted by the swarm"""
|
|
|
|
|
|
|
|
|
|
_keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
|
|
|
|
|
r"^(intermediate_)?prompt_embeddings\.weight$",
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
config_class = DistributedBloomConfig
|
|
|
|
|
|
|
|
|
|
def __init__(self, config: DistributedBloomConfig):
|
|
|
|
@ -66,16 +92,16 @@ class DistributedBloomModel(BloomModel):
|
|
|
|
|
if config.tuning_mode and "ptune" in config.tuning_mode:
|
|
|
|
|
assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
|
|
|
|
|
self.pre_seq_len = config.pre_seq_len
|
|
|
|
|
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
|
|
|
|
|
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
|
|
|
|
|
|
|
|
|
if config.tuning_mode == "deep_ptune":
|
|
|
|
|
self.intermediate_prompt_embeddings = nn.Embedding(
|
|
|
|
|
self.pre_seq_len,
|
|
|
|
|
config.num_hidden_layers * config.hidden_size
|
|
|
|
|
# ^-- TODO: should be num_hidden_layers - 1
|
|
|
|
|
)
|
|
|
|
|
self.intermediate_prompt_embeddings.weight.data.zero_()
|
|
|
|
|
with force_non_empty_weights():
|
|
|
|
|
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
|
|
|
|
|
if config.tuning_mode == "deep_ptune":
|
|
|
|
|
self.intermediate_prompt_embeddings = nn.Embedding(
|
|
|
|
|
self.pre_seq_len,
|
|
|
|
|
config.num_hidden_layers * config.hidden_size
|
|
|
|
|
# ^-- TODO: should be num_hidden_layers - 1
|
|
|
|
|
)
|
|
|
|
|
elif config.tuning_mode:
|
|
|
|
|
raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
|
|
|
|
|
|
|
|
|
@ -155,6 +181,12 @@ class DistributedBloomModel(BloomModel):
|
|
|
|
|
class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
|
|
|
|
|
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
|
|
|
|
|
|
|
|
|
|
_keys_to_ignore_on_load_missing = (
|
|
|
|
|
BloomForCausalLM._keys_to_ignore_on_load_missing
|
|
|
|
|
+ DistributedBloomModel._keys_to_ignore_on_load_missing
|
|
|
|
|
+ [r"^lm_head.word_embeddings\.weight$"] # Missing since they are shared with input embeddings
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
config_class = DistributedBloomConfig
|
|
|
|
|
|
|
|
|
|
def __init__(self, config: DistributedBloomConfig):
|
|
|
|
@ -185,6 +217,11 @@ class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
|
|
|
|
|
_keys_to_ignore_on_load_missing = (
|
|
|
|
|
BloomForSequenceClassification._keys_to_ignore_on_load_missing
|
|
|
|
|
+ DistributedBloomModel._keys_to_ignore_on_load_missing
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
config_class = DistributedBloomConfig
|
|
|
|
|
|
|
|
|
|
def __init__(self, config: DistributedBloomConfig):
|
|
|
|
|