@@ -151,55 +151,6 @@ class DistributedBloomModel(BloomModel):
         )
 
 
-class DistributedBloomPrefix(DistributedBloomModel):
-    """DistributedBloomModel with prefix tokens for prompt tuning"""
-
-    def __init__(self, config):
-        super().__init__(config)
-        assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0"
-        self.prefix_length = config.num_prefix_tokens
-
-        self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size)
-        self.prefix_tokens = torch.arange(self.prefix_length).long()
-
-    def get_prompt(self, batch_size):
-        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
-        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
-        prompts = self.prompt_embeddings(prefix_tokens)
-        return prompts
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
-    ):
-        assert (
-            input_ids is None or inputs_embeds is None
-        ), "You cannot specify both input_ids and inputs_embeds at the same time"
-        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
-
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
-
-        batch_size = inputs_embeds.shape[0]
-
-        if attention_mask is not None:
-            prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
-            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
-
-        prompts = self.get_prompt(batch_size)
-        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
-
-        transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
-
-        # Remove prefix
-        last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
-        transformer_outputs["last_hidden_state"] = last_hidden_state
-        return transformer_outputs
-
-
 class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
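For reference, a minimal sketch of how the removed `DistributedBloomPrefix` was typically driven: the model prepends `num_prefix_tokens` trainable embeddings to the input, runs the usual forward pass, and strips the prefix from the output. This is not taken from the PR; the `DistributedBloomConfig` import path and checkpoint name are assumptions, and the forward pass would need a live swarm connection to actually run.

```python
import torch

from petals.client import DistributedBloomConfig  # assumed import path

# Hypothetical setup; assumes this runs in the same module that defines
# DistributedBloomPrefix (the class removed in the diff above).
config = DistributedBloomConfig.from_pretrained("bigscience/bloom-petals")
config.num_prefix_tokens = 16  # number of trainable prompt embeddings

model = DistributedBloomPrefix(config)

input_ids = torch.tensor([[2, 17, 42]])  # (batch_size, seq_len)
attention_mask = torch.ones_like(input_ids)

outputs = model(input_ids=input_ids, attention_mask=attention_mask)

# forward() removes the prefix before returning, so the output length
# matches the input length rather than seq_len + num_prefix_tokens.
assert outputs.last_hidden_state.shape[1] == input_ids.shape[1]
```

During prompt tuning, only `prompt_embeddings` would be optimized while the remote transformer layers stay frozen on the swarm.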