Shallow prompt tuning (#22)

pull/24/head
Dmitry Baranchuk 2 years ago committed by GitHub
parent 7e9f337a63
commit f5463812ad

@@ -10,10 +10,15 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from hivemind import use_hivemind_log_handler
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss, LayerNorm
from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
                                     add_start_docstrings_to_model_forward)
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.utils import logging
@@ -469,3 +474,130 @@ class LMHead(nn.Module):
            chunk = word_embeddings[i: i + self.chunk_size].float()
            output[..., i: i + self.chunk_size] = F.linear(hidden_states, chunk)
        return output

@add_start_docstrings(
    """
    The Bloom Model transformer with a sequence classification head on top (linear layer).

    [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    BLOOM_START_DOCSTRING,
)
class BloomForSequenceClassification(BloomPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = BloomModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
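
For reference, here is a tiny standalone sketch (not part of this diff) of the last-non-padding-token pooling that the new classification head performs in forward(); the pad id, token ids, and shapes below are made up for illustration:

import torch

pad_token_id = 3  # hypothetical pad id, chosen only for this demo
input_ids = torch.tensor([
    [15, 82, 97, 3, 3],   # two padding tokens -> last real token at index 2
    [11, 42, 7, 64, 9],   # no padding         -> last real token at index 4
])
logits = torch.randn(2, 5, 4)  # (batch_size, seq_len, num_labels)

# index of the last non-padding token in each row, as in forward() above
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1  # tensor([2, 4])
pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 4])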

@@ -2,10 +2,13 @@
import os
from typing import Optional, Tuple
import torch
import torch.nn as nn
import hivemind
from hivemind import get_logger, use_hivemind_log_handler
from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHead
from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHead, BloomForSequenceClassification
from src.client.remote_sequential import RemoteSequential
from src.data_structures import UID_DELIMITER
@@ -22,7 +25,8 @@ class DistributedBloomConfig(BloomConfig):
    initial_peers: Tuple[str, ...] = ()  # a list of initial peers for hivemind DHT
    dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
    dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for the LM head for efficient half-precision on CPU
    num_prefix_tokens: int = 0  # the number of prefix tokens used for prompt tuning


class DistributedBloomModel(BloomModel):
@@ -54,14 +58,81 @@ class DistributedBloomModel(BloomModel):
            p.requires_grad = value

class DistributedBloomPrefix(DistributedBloomModel):
    """DistributedBloomModel with prefix tokens for prompt tuning"""

    def __init__(self, config):
        super().__init__(config)
        assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0"
        self.prefix_length = config.num_prefix_tokens

        self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size)
        self.prefix_tokens = torch.arange(self.prefix_length).long()

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)
        return prompts

    def forward(
        self,
        input_ids: Optional[torch.LongTensor],
        inputs_embeds: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values=None,
        position_ids=None,
        head_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        assert input_ids is None or inputs_embeds is None, "You cannot specify both input_ids and inputs_embeds at the same time"
        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        batch_size = inputs_embeds.shape[0]

        if attention_mask is not None:
            prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        prompts = self.get_prompt(batch_size)
        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        transformer_outputs = super().forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Remove prefix
        last_hidden_state = transformer_outputs[0][:, self.prefix_length:]
        transformer_outputs['last_hidden_state'] = last_hidden_state
        return transformer_outputs


class DistributedBloomForCausalLM(BloomForCausalLM):
    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        self.transformer = DistributedBloomModel(config)
        if config.num_prefix_tokens > 0:
            self.transformer = DistributedBloomPrefix(config)
        else:
            self.transformer = DistributedBloomModel(config)
        self.lm_head = LMHead(config, self.transformer.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()
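
To make the prefix logic above easier to follow, here is a toy, self-contained sketch (not part of the commit) of what DistributedBloomPrefix does before delegating to the parent forward(): learned prompt embeddings are prepended to the token embeddings and the attention mask is extended to match. All sizes below are made up.

import torch
import torch.nn as nn

batch_size, seq_len, hidden_size, prefix_length = 2, 4, 8, 3
word_embeddings = nn.Embedding(100, hidden_size)              # stand-in for the model's input embeddings
prompt_embeddings = nn.Embedding(prefix_length, hidden_size)  # the only new trainable parameters

input_ids = torch.randint(0, 100, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len)

inputs_embeds = word_embeddings(input_ids)                                       # (2, 4, 8)
prefix_tokens = torch.arange(prefix_length).unsqueeze(0).expand(batch_size, -1)  # (2, 3)
prompts = prompt_embeddings(prefix_tokens)                                       # (2, 3, 8)

inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)                       # (2, 7, 8)
attention_mask = torch.cat([torch.ones(batch_size, prefix_length), attention_mask], dim=1)  # (2, 7)
print(inputs_embeds.shape, attention_mask.shape)

After the backbone runs, the first prefix_length positions of the hidden states are sliced off again, which is what the "Remove prefix" step above does.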
@@ -70,3 +141,17 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
    def set_output_embeddings(self, new_embeddings):
        self.lm_head.word_embeddings.weight = new_embeddings.weight


class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        super().__init__(config)
        if config.num_prefix_tokens > 0:
            self.transformer = DistributedBloomPrefix(config)
        else:
            self.transformer = DistributedBloomModel(config)

        # Initialize weights and apply final processing
        self.post_init()
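
Finally, a rough end-to-end usage sketch, not part of this commit: the checkpoint name, the initial peers, and the src.client.remote_model import path are placeholders/assumptions. The idea of shallow prompt tuning is to train only the input-level prefix embeddings (plus, here, the classification head), so everything else is frozen.

# Rough usage sketch, NOT part of this commit. MODEL_NAME, INITIAL_PEERS and the
# src.client.remote_model import path are placeholders/assumptions.
import torch
from transformers import AutoTokenizer

from src.client.remote_model import DistributedBloomConfig, DistributedBloomForSequenceClassification

MODEL_NAME = "bigscience/test-bloomd-6b3"              # placeholder checkpoint name
INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/1337/p2p/Qm..."]  # placeholder peer multiaddr

config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
config.initial_peers = INITIAL_PEERS
config.num_prefix_tokens = 16  # > 0 makes the backbone a DistributedBloomPrefix
config.num_labels = 2

model = DistributedBloomForSequenceClassification.from_pretrained(MODEL_NAME, config=config)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Shallow prompt tuning: freeze everything except the prefix embeddings and the head.
for p in model.parameters():
    p.requires_grad = False
for p in model.transformer.prompt_embeddings.parameters():
    p.requires_grad = True
for p in model.score.parameters():
    p.requires_grad = True

batch = tokenizer(["great movie", "terrible movie"], padding=True, return_tensors="pt")
labels = torch.tensor([1, 0])
outputs = model(**batch, labels=labels)
outputs.loss.backward()  # gradients flow only into the prompt embeddings and the score head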
