lm_head
dbaranchuk 2 years ago
parent b3cc9e0d99
commit 6bffeff0a1

@@ -319,7 +319,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.transformer = BloomModel(config)
-        self.lm_head = LMHeadForCausalLM(config)
+        self.lm_head = LMHeadForCausalLM(config, self.transformer.word_embeddings)
 
         # Initialize weights and apply final processing
         self.post_init()
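Note on the change above: the head now receives the transformer's input-embedding module instead of building its own projection, so the input and output embeddings share one weight matrix. A minimal standalone sketch of that tying pattern (shapes and variable names here are illustrative, not from this commit):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    emb = nn.Embedding(50257, 64)          # shared [vocab, hidden] matrix
    hidden = torch.randn(2, 3, 64)         # [batch, seq, hidden] activations
    logits = F.linear(hidden, emb.weight)  # reuse the embedding as the output projection
    assert logits.shape == (2, 3, 50257)   # [batch, seq, vocab]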
@@ -328,7 +328,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         return self.lm_head.word_embeddings
 
     def set_output_embeddings(self, new_embeddings):
-        self.lm_head.word_embeddings = new_embeddings.weight
+        self.lm_head.word_embeddings.weight = new_embeddings.weight
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
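Why the fix writes into `.weight`: the old line replaced the stored `nn.Embedding` module with a bare tensor, which breaks any later code that expects a module and severs the weight tie. Assigning the new `Parameter` into the existing module keeps both intact. A small sketch with assumed sizes:

    import torch.nn as nn

    head = nn.Embedding(10, 4)        # stands in for lm_head.word_embeddings
    new = nn.Embedding(10, 4)         # stands in for new_embeddings
    head.weight = new.weight          # swap the Parameter, keep the module
    assert head.weight is new.weight  # both modules now share one tensor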
@@ -418,18 +418,19 @@ class BloomForCausalLM(BloomPreTrainedModel):
 class LMHeadForCausalLM(nn.Module):
     def __init__(self, config, word_embeddings: nn.Embedding):
         super().__init__()
-        self.word_embeddings = word_embeddings.weight
+        self.word_embeddings = word_embeddings
         self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
 
     def forward(self, hidden_states):
-        if self.word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
-                'cpu' in self.word_embeddings.device:
+        word_embeddings = self.word_embeddings.weight
+        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
+                word_embeddings.device.type == 'cpu':
             # We use 'chunked_forward' only for half-precision computations on CPU.
             lm_logits = self.chunked_forward(hidden_states)
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
-            hidden_states = hidden_states.to(self.word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, self.word_embeddings).float()
+            hidden_states = hidden_states.to(word_embeddings.dtype)
+            lm_logits = F.linear(hidden_states, word_embeddings).float()
         return lm_logits
 
     def chunked_forward(self, hidden_states):
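Two bugs fall out of the forward() hunk above: `self.word_embeddings` used to hold a raw weight tensor rather than the module, and `'cpu' in self.word_embeddings.device` is not a valid membership test on a `torch.device` (it raises `TypeError`); `device.type == 'cpu'` is the standard check. A quick illustration with an assumed tensor:

    import torch

    w = torch.randn(8, 4, dtype=torch.float16)  # fp16 weight, on CPU by default
    use_chunked = w.dtype in (torch.float16, torch.bfloat16) and w.device.type == "cpu"
    print(use_chunked)                          # True: take the chunked_forward path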
@@ -438,11 +439,13 @@ class LMHeadForCausalLM(nn.Module):
         """
         assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
 
+        word_embeddings = self.word_embeddings.weight
+
         hidden_states = hidden_states.float()
-        num_embeddings = self.word_embeddings.shape[1]
+        num_embeddings = word_embeddings.shape[0]
         output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
         for i in range(0, num_embeddings, self.chunk_size):
-            chunk = self.word_embeddings[..., i:i+self.chunk_size].float()
+            chunk = word_embeddings[i:i+self.chunk_size].float()
             output[..., i:i+self.chunk_size] = F.linear(hidden_states, chunk)
         return output
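For context on why this path exists at all: half-precision matmuls are slow or unsupported on CPU (the config name, chunk_size_for_efficient_fp16_on_cpu, suggests as much), so chunked_forward upcasts the fp16/bf16 weight one vocab slice at a time instead of materializing the whole matrix in fp32. A self-contained sketch of the same idea (the helper name and sizes are assumptions, not part of the commit):

    import torch
    import torch.nn.functional as F

    def chunked_lm_head(hidden_states, weight, chunk_size=1024):
        # hidden_states: [..., hidden]; weight: [vocab, hidden] in fp16/bf16 on CPU
        hidden_states = hidden_states.float()
        vocab_size = weight.shape[0]
        logits = torch.zeros(*hidden_states.shape[:-1], vocab_size)
        for i in range(0, vocab_size, chunk_size):
            chunk = weight[i:i + chunk_size].float()           # upcast one vocab slice
            logits[..., i:i + chunk_size] = F.linear(hidden_states, chunk)
        return logits

    w = torch.randn(5000, 64, dtype=torch.bfloat16)
    h = torch.randn(2, 3, 64, dtype=torch.bfloat16)
    print(chunked_lm_head(h, w).shape)                         # torch.Size([2, 3, 5000])

Peak extra memory is one chunk_size x hidden fp32 slice instead of the full vocab x hidden matrix, at the cost of a Python-level loop over the vocabulary.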

@@ -54,6 +54,6 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
         self.transformer = DistributedBloomModel(config)
-        self.lm_head = LMHeadForCausalLM(config)
+        self.lm_head = LMHeadForCausalLM(config, self.transformer.word_embeddings)
         # Initialize weights and apply final processing
         self.post_init()
