|
|
@ -77,7 +77,8 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
|
|
|
|
if inputs_embeds is None:
|
|
|
|
if inputs_embeds is None:
|
|
|
|
inputs_embeds = self.word_embeddings(input_ids)
|
|
|
|
inputs_embeds = self.word_embeddings(input_ids)
|
|
|
|
|
|
|
|
|
|
|
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
|
|
|
|
use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
|
|
|
|
|
|
|
|
if use_prompts:
|
|
|
|
batch_size = inputs_embeds.shape[0]
|
|
|
|
batch_size = inputs_embeds.shape[0]
|
|
|
|
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
|
|
|
prompts, intermediate_prompts = self.get_prompt(batch_size)
|
|
|
|
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
|
|
|
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
|
|
@ -94,7 +95,7 @@ class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMix
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# Remove prefix
|
|
|
|
# Remove prefix
|
|
|
|
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
|
|
|
|
if use_prompts:
|
|
|
|
hidden_states = hidden_states[:, self.pre_seq_len :]
|
|
|
|
hidden_states = hidden_states[:, self.pre_seq_len :]
|
|
|
|
|
|
|
|
|
|
|
|
# Add last hidden state
|
|
|
|
# Add last hidden state
|
|
|
|