@@ -95,12 +95,18 @@ class DistributedBloomModel(BloomModel):
             self.prefix_tokens = torch.arange(self.pre_seq_len).long()
 
             with force_non_empty_weights():
-                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
+                if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
+                    logger.info(
+                        "Prompt embeddings and their optimizer statistics will be kept in float32 "
+                        "to increase ptune quality"
+                    )
+                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
                 if config.tuning_mode == "deep_ptune":
                     self.intermediate_prompt_embeddings = nn.Embedding(
                         self.pre_seq_len,
-                        config.num_hidden_layers * config.hidden_size
+                        config.num_hidden_layers * config.hidden_size,
                         # ^-- TODO: should be num_hidden_layers - 1
+                        dtype=torch.float32,
                     )
         elif config.tuning_mode:
             raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
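This hunk fixes a precision problem in prompt tuning (ptune): when the frozen backbone runs in float16 or bfloat16, trainable prompt embeddings stored in the same half-precision dtype lose most of their small optimizer updates to rounding, so the embeddings (and, under deep_ptune, the per-layer intermediate prompts) are now created in float32. The following minimal sketch illustrates the same pattern outside this codebase; the `TinyPTune` class and all of its names are hypothetical, not part of the actual model:

```python
import torch
import torch.nn as nn

class TinyPTune(nn.Module):
    """Hypothetical illustration: float32 master prompts for a half-precision model."""

    def __init__(self, pre_seq_len: int, hidden_size: int):
        super().__init__()
        # Kept in float32 so small Adam updates are not rounded away in fp16/bf16.
        self.prompt_embeddings = nn.Embedding(pre_seq_len, hidden_size, dtype=torch.float32)
        self.prefix_tokens = torch.arange(pre_seq_len).long()

    def get_prompt(self, batch_size: int, model_dtype: torch.dtype) -> torch.Tensor:
        prefix = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prompts = self.prompt_embeddings(prefix)  # float32 master copy
        return prompts.to(model_dtype)            # cast only at the model boundary

ptune = TinyPTune(pre_seq_len=16, hidden_size=64)
prompts = ptune.get_prompt(batch_size=2, model_dtype=torch.float16)
print(prompts.shape, prompts.dtype)  # torch.Size([2, 16, 64]) torch.float16
```

The optimizer is then built over these parameters, and since Adam allocates its state with the parameter's dtype, the statistics inherit float32 as well, which is what the logged message refers to.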
@@ -122,7 +128,9 @@ class DistributedBloomModel(BloomModel):
             intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
         else:
             intermediate_prompts = DUMMY
-        return prompts, intermediate_prompts
+
+        dtype = self.word_embeddings.weight.dtype
+        return prompts.to(dtype), intermediate_prompts.to(dtype)
 
     def forward(
         self,
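The second hunk completes the boundary cast in `get_prompt`: both returned prompt tensors are converted to the dtype of the frozen word embeddings, so downstream half-precision layers receive matching inputs even though the trainable embeddings themselves stay float32 (the cast is also a safe no-op on the `DUMMY` placeholder). Judging by the `permute([2, 0, 1, 3])`, the intermediate prompts end up with the layer axis first, i.e. [num_layers, batch, seq_len, hidden]; the stand-in shapes below assume that layout. A hedged sketch of the cast with stand-in tensors rather than the real model:

```python
import torch
import torch.nn as nn

# Stand-ins, not the actual model: a half-precision "frozen" embedding table
# and float32 prompt tensors of plausible shapes.
word_embeddings = nn.Embedding(100, 64).half()
prompts = torch.randn(2, 16, 64, dtype=torch.float32)
intermediate_prompts = torch.randn(4, 2, 16, 64, dtype=torch.float32)  # [layers, batch, seq, hidden]

dtype = word_embeddings.weight.dtype  # torch.float16 here
prompts, intermediate_prompts = prompts.to(dtype), intermediate_prompts.to(dtype)
print(prompts.dtype, intermediate_prompts.dtype)  # torch.float16 torch.float16
```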