Fix ptune with `low_cpu_mem_usage=True` (as in Colab) (#103)

Fixes:

- An exception while creating a model with `ptune/deep_ptune` and `low_cpu_mem_usage=True` (which is currently default).
- dtype mismatch between the prompts and the rest of the model in `.forward()`.
pull/109/head
Alexander Borzunov 2 years ago committed by GitHub
parent 43ac6016ac
commit 0a1cd3b9ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,4 +1,4 @@
# this code is in active development, interfaces may change
from contextlib import contextmanager
from typing import List, Optional
import hivemind
@ -38,9 +38,35 @@ class DistributedBloomConfig(BloomConfig):
tuning_mode: Optional[str] = None # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
original_register_parameter = nn.Module.register_parameter
@contextmanager
def force_non_empty_weights():
"""
This context manager allows to bypass the accelerate.init_empty_weights() context manager
(that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
The transformers library should replace all meta tensors by empty tensors by itself
but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).
[1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
"""
try:
possibly_patched_register_parameter = nn.Module.register_parameter
nn.Module.register_parameter = original_register_parameter
yield
finally:
nn.Module.register_parameter = possibly_patched_register_parameter
class DistributedBloomModel(BloomModel):
"""BloomModel, but all transformer layers are hosted by the swarm"""
_keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
r"^(intermediate_)?prompt_embeddings\.weight$",
]
config_class = DistributedBloomConfig
def __init__(self, config: DistributedBloomConfig):
@ -66,16 +92,22 @@ class DistributedBloomModel(BloomModel):
if config.tuning_mode and "ptune" in config.tuning_mode:
assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
self.pre_seq_len = config.pre_seq_len
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
if config.tuning_mode == "deep_ptune":
self.intermediate_prompt_embeddings = nn.Embedding(
self.pre_seq_len,
config.num_hidden_layers * config.hidden_size
# ^-- TODO: should be num_hidden_layers - 1
)
self.intermediate_prompt_embeddings.weight.data.zero_()
with force_non_empty_weights():
if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
logger.info(
"Prompt embeddings and their optimizer statistics will be kept in float32 "
"to increase ptune quality"
)
self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
if config.tuning_mode == "deep_ptune":
self.intermediate_prompt_embeddings = nn.Embedding(
self.pre_seq_len,
config.num_hidden_layers * config.hidden_size,
# ^-- TODO: should be num_hidden_layers - 1
dtype=torch.float32,
)
elif config.tuning_mode:
raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
@ -96,7 +128,9 @@ class DistributedBloomModel(BloomModel):
intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
else:
intermediate_prompts = DUMMY
return prompts, intermediate_prompts
dtype = self.word_embeddings.weight.dtype
return prompts.to(dtype), intermediate_prompts.to(dtype)
def forward(
self,
@ -155,6 +189,12 @@ class DistributedBloomModel(BloomModel):
class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
_keys_to_ignore_on_load_missing = (
BloomForCausalLM._keys_to_ignore_on_load_missing
+ DistributedBloomModel._keys_to_ignore_on_load_missing
+ [r"^lm_head.word_embeddings\.weight$"] # Missing since they are shared with input embeddings
)
config_class = DistributedBloomConfig
def __init__(self, config: DistributedBloomConfig):
@ -185,6 +225,11 @@ class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
_keys_to_ignore_on_load_missing = (
BloomForSequenceClassification._keys_to_ignore_on_load_missing
+ DistributedBloomModel._keys_to_ignore_on_load_missing
)
config_class = DistributedBloomConfig
def __init__(self, config: DistributedBloomConfig):

Loading…
Cancel
Save