From 1c89c5c7ffca940d46f97e9ba1fc7a93b8f64610 Mon Sep 17 00:00:00 2001 From: Artem Chumachenko Date: Tue, 6 Sep 2022 20:11:39 +0300 Subject: [PATCH] fix rebase --- src/client/remote_model.py | 49 -------------------------------------- 1 file changed, 49 deletions(-) diff --git a/src/client/remote_model.py b/src/client/remote_model.py index 9bea827..158b7a1 100644 --- a/src/client/remote_model.py +++ b/src/client/remote_model.py @@ -151,55 +151,6 @@ class DistributedBloomModel(BloomModel): ) -class DistributedBloomPrefix(DistributedBloomModel): - """DistributedBloomModel with prefix tokens for prompt tuning""" - - def __init__(self, config): - super().__init__(config) - assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0" - self.prefix_length = config.num_prefix_tokens - - self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size) - self.prefix_tokens = torch.arange(self.prefix_length).long() - - def get_prompt(self, batch_size): - prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1) - prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device) - prompts = self.prompt_embeddings(prefix_tokens) - return prompts - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, - ): - assert ( - input_ids is None or inputs_embeds is None - ), "You cannot specify both input_ids and inputs_embeds at the same time" - assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds" - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - batch_size = inputs_embeds.shape[0] - - if attention_mask is not None: - prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device) - attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) - - prompts = self.get_prompt(batch_size) - inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1) - - transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs) - - # Remove prefix - last_hidden_state = transformer_outputs[0][:, self.prefix_length :] - transformer_outputs["last_hidden_state"] = last_hidden_state - return transformer_outputs - - class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM): """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""