LM head for CausalLM & chunked forward

pull/19/head
dbaranchuk 2 years ago
parent f055135b08
commit df42822f26

@ -319,14 +319,16 @@ class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.transformer = BloomModel(config)
self.lm_head = LMHeadForCausalLM(config)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.transformer.word_embeddings
return self.lm_head.word_embeddings
def set_output_embeddings(self, new_embeddings):
self.transformer.word_embeddings.weight = new_embeddings.weight
self.lm_head.word_embeddings = new_embeddings.weight
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
@ -368,11 +370,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)
word_embeddings = self.transformer.word_embeddings.weight
# Switch dtype in case word_embeddings are fp16/bf16
hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
lm_logits = F.linear(hidden_states, word_embeddings).float()
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
@ -406,3 +405,44 @@ class BloomForCausalLM(BloomPreTrainedModel):
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
@add_start_docstrings(
"""
The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
embeddings. It reduces initial memory consumption which might be crucial for large dictionaries. In addition, it provides
an effcient way to perform half-precision calculations on CPU.
""",
BLOOM_START_DOCSTRING,
)
class LMHeadForCausalLM(nn.Module):
def __init__(self, config, word_embeddings: nn.Embedding):
super().__init__()
self.word_embeddings = word_embeddings.weight
self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
def forward(self, hidden_states):
if self.word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
'cpu' in self.word_embeddings.device:
# We use 'chunked_forward' only for half-precision computations on CPU.
lm_logits = self.chunked_forward(hidden_states)
else:
# Switch dtype in case word_embeddings are fp16/bf16
hidden_states = hidden_states.to(self.word_embeddings.dtype)
lm_logits = F.linear(hidden_states, self.word_embeddings).float()
return lm_logits
def chunked_forward(self, hidden_states):
""" Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
chunk_size: provides trade-off between efficiency and extra memory consumption.
"""
assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
hidden_states = hidden_states.float()
num_embeddings = self.word_embeddings.shape[1]
output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
for i in range(0, num_embeddings, self.chunk_size):
chunk = self.word_embeddings[..., i:i+self.chunk_size].float()
output[..., i:i+self.chunk_size] = F.linear(hidden_states, chunk)
return output

@ -23,6 +23,7 @@ class DistributedBloomConfig(BloomConfig):
initial_peers: Tuple[str, ...] = () # a list of initial peers for hivemind DHT
dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models
chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
class DistributedBloomModel(BloomModel):

Loading…
Cancel
Save