refactoring

pull/19/head
dbaranchuk 2 years ago
parent 6bffeff0a1
commit 79280c4371

@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
-from typing import Tuple
+from typing import Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -319,16 +319,16 @@ class BloomForCausalLM(BloomPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.transformer = BloomModel(config)
-        self.lm_head = LMHeadForCausalLM(config, self.transformer.word_embeddings)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
 
     def get_output_embeddings(self):
-        return self.lm_head.word_embeddings
+        return self.lm_head
 
     def set_output_embeddings(self, new_embeddings):
-        self.lm_head.word_embeddings.weight = new_embeddings.weight
+        self.lm_head = new_embeddings
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
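
For context, a minimal sketch (toy sizes, plain PyTorch, not part of this commit) of what the switch to a plain nn.Linear head enables: tying the output projection to the input embeddings by sharing a single Parameter, which is the standard convention the accessors above rely on.

import torch
from torch import nn

hidden_size, vocab_size = 16, 100                      # toy sizes, hypothetical
word_embeddings = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
lm_head.weight = word_embeddings.weight                # tie: both modules share one Parameter

hidden_states = torch.randn(2, 5, hidden_size)
logits = lm_head(hidden_states)                        # shape (2, 5, vocab_size)
assert lm_head.weight.data_ptr() == word_embeddings.weight.data_ptr()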
@@ -361,7 +361,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
         output_type=CausalLMOutputWithCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    def forward(self, input_ids=None, labels=None, return_dict=None, **kwargs):
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@ -369,8 +382,21 @@ class BloomForCausalLM(BloomPreTrainedModel):
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
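
The loss computation further down this forward is not shown in the hunk; the docstring's note that labels "are shifted inside the model" corresponds to the usual causal-LM pattern, sketched below in plain PyTorch (names are illustrative, not copied from this file).

import torch
import torch.nn.functional as F

def causal_lm_loss(lm_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Predict token t+1 from position t: drop the last logit and the first label.
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,                             # positions set to -100 are masked out
    )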
@@ -415,7 +441,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
     """,
     BLOOM_START_DOCSTRING,
 )
-class LMHeadForCausalLM(nn.Module):
+class LMHead(nn.Module):
     def __init__(self, config, word_embeddings: nn.Embedding):
         super().__init__()
         self.word_embeddings = word_embeddings
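
The rest of LMHead lies outside this hunk; the hypothetical stand-in below shows how a tied head of this shape typically turns hidden states into logits, by projecting with the shared embedding matrix via F.linear rather than keeping a separate output weight.

import torch
import torch.nn.functional as F
from torch import nn

class TiedLMHead(nn.Module):                           # illustrative stand-in, not the class above
    def __init__(self, word_embeddings: nn.Embedding):
        super().__init__()
        self.word_embeddings = word_embeddings

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Project hidden states with the shared embedding matrix of shape (vocab_size, hidden_size).
        return F.linear(hidden_states, self.word_embeddings.weight)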

@@ -5,7 +5,7 @@ from typing import Optional, Tuple
 import hivemind
 from hivemind import get_logger, use_hivemind_log_handler
 
-from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHeadForCausalLM
+from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHead
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
@@ -54,6 +54,12 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
         self.transformer = DistributedBloomModel(config)
-        self.lm_head = LMHeadForCausalLM(config, self.transformer.word_embeddings)
+        self.lm_head = LMHead(config, self.transformer.word_embeddings)
         # Initialize weights and apply final processing
         self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head.word_embeddings
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head.word_embeddings.weight = new_embeddings.weight
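
A toy stand-in (not the real distributed model) illustrating what these accessors imply: get_output_embeddings returns the shared nn.Embedding, and set_output_embeddings rebinds its .weight, so the tie between lm_head and the transformer's word_embeddings is preserved.

from torch import nn

shared = nn.Embedding(100, 16)                         # plays the role of transformer.word_embeddings

class ToyHead(nn.Module):                              # keeps a reference to the shared embedding
    def __init__(self, word_embeddings: nn.Embedding):
        super().__init__()
        self.word_embeddings = word_embeddings

lm_head = ToyHead(shared)
new_embeddings = nn.Embedding(100, 16)

lm_head.word_embeddings.weight = new_embeddings.weight     # mirrors set_output_embeddings above

assert lm_head.word_embeddings is shared                   # module identity is unchanged
assert shared.weight is new_embeddings.weight              # the weight was rebound on the shared module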
