pull/19/head
dbaranchuk 2 years ago
parent b3cc9e0d99
commit 6bffeff0a1

@ -319,7 +319,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.transformer = BloomModel(config)
self.lm_head = LMHeadForCausalLM(config)
self.lm_head = LMHeadForCausalLM(config, self.transformer.word_embeddings)
# Initialize weights and apply final processing
self.post_init()
@ -328,7 +328,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
return self.lm_head.word_embeddings
def set_output_embeddings(self, new_embeddings):
self.lm_head.word_embeddings = new_embeddings.weight
self.lm_head.word_embeddings.weight = new_embeddings.weight
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
@ -418,18 +418,19 @@ class BloomForCausalLM(BloomPreTrainedModel):
class LMHeadForCausalLM(nn.Module):
def __init__(self, config, word_embeddings: nn.Embedding):
super().__init__()
self.word_embeddings = word_embeddings.weight
self.word_embeddings = word_embeddings
self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
def forward(self, hidden_states):
if self.word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
'cpu' in self.word_embeddings.device:
word_embeddings = self.word_embeddings.weight
if word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
word_embeddings.device.type == 'cpu':
# We use 'chunked_forward' only for half-precision computations on CPU.
lm_logits = self.chunked_forward(hidden_states)
else:
# Switch dtype in case word_embeddings are fp16/bf16
hidden_states = hidden_states.to(self.word_embeddings.dtype)
lm_logits = F.linear(hidden_states, self.word_embeddings).float()
hidden_states = hidden_states.to(word_embeddings.dtype)
lm_logits = F.linear(hidden_states, word_embeddings).float()
return lm_logits
def chunked_forward(self, hidden_states):
@ -438,11 +439,13 @@ class LMHeadForCausalLM(nn.Module):
"""
assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
word_embeddings = self.word_embeddings.weight
hidden_states = hidden_states.float()
num_embeddings = self.word_embeddings.shape[1]
num_embeddings = word_embeddings.shape[0]
output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
for i in range(0, num_embeddings, self.chunk_size):
chunk = self.word_embeddings[..., i:i+self.chunk_size].float()
chunk = word_embeddings[i:i+self.chunk_size].float()
output[..., i:i+self.chunk_size] = F.linear(hidden_states, chunk)
return output

@ -54,6 +54,6 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
def __init__(self, config: DistributedBloomConfig):
BloomPreTrainedModel.__init__(self, config)
self.transformer = DistributedBloomModel(config)
self.lm_head = LMHeadForCausalLM(config)
self.lm_head = LMHeadForCausalLM(config, self.transformer.word_embeddings)
# Initialize weights and apply final processing
self.post_init()

Loading…
Cancel
Save