You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

75 lines
2.8 KiB

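"""
PyTorch BLOOM model that implements several memory-efficient modes.
Based on the Hugging Face `transformers` BLOOM implementation (the exact upstream reference is missing from this copy).
See commit history for authorship.
"""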
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from hivemind import use_hivemind_log_handler
from torch import nn
from transformers import BloomConfig
from transformers.utils import logging
# Module-level logger, created through the `transformers.utils.logging` helper imported above
# and named after this file's path (`__file__`).
logger = logging.get_logger(__file__)
class LMHead(nn.Module):
    """
    Language modeling head that reuses the input embedding matrix as its projection weight.

    It does not create an extra tensor for the linear layer: the weights are tied to the input
    embeddings. Thus, it reduces initial memory consumption, which might be crucial for large
    dictionaries. In addition, it provides an efficient way to deal with half-precision word
    embeddings on CPU (see ``chunked_forward``).
    """

    def __init__(self, config: "BloomConfig", word_embeddings: nn.Embedding):
        """
        Args:
            config: model configuration; only ``chunk_size_for_efficient_fp16_on_cpu`` is read here.
            word_embeddings: the input embedding module whose weight is reused for the projection.
        """
        # nn.Module must be initialised before a submodule can be assigned as an attribute.
        super().__init__()
        self.word_embeddings = word_embeddings
        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu

    # The accessors below mirror the attribute names of ``nn.Linear`` so that this head can be
    # inspected like a regular linear layer.
    # NOTE(review): ``forward`` projects ``embedding_dim`` -> ``num_embeddings``, so relative to
    # the ``nn.Linear`` convention ``in_features`` / ``out_features`` look swapped. The values are
    # kept as they were so existing callers see the same numbers - confirm before changing.

    @property
    def in_features(self) -> int:
        """Return ``num_embeddings`` of the tied embedding module."""
        return self.word_embeddings.num_embeddings

    @property
    def out_features(self) -> int:
        """Return ``embedding_dim`` of the tied embedding module."""
        return self.word_embeddings.embedding_dim

    @property
    def weight(self):
        """Return the tied embedding weight (no copy is made)."""
        return self.word_embeddings.weight

    @property
    def bias(self):
        """This head has no bias term; always ``None``."""
        return None

    def forward(self, hidden_states):
        """
        Project hidden states onto the vocabulary using the tied embedding weight.

        Args:
            hidden_states: tensor whose last dimension is multiplied against the embedding matrix.

        Returns:
            Logits with the last dimension replaced by ``num_embeddings``. On the chunked CPU path
            the result is fp32; otherwise it has the dtype of the embedding weight.
        """
        word_embeddings = self.word_embeddings.weight

        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
        is_half_precision = word_embeddings.dtype in (torch.float16, torch.bfloat16)
        if is_half_precision and word_embeddings.device.type == "cpu":
            lm_logits = self.chunked_forward(hidden_states)
        else:
            # Switch dtype in case word_embeddings are fp16/bf16
            hidden_states = hidden_states.to(word_embeddings.dtype)
            lm_logits = F.linear(hidden_states, word_embeddings)
        return lm_logits

    def chunked_forward(self, hidden_states):
        """
        Split word embeddings into chunks and iteratively cast them to fp32 for the matmul.

        Only ``chunk_size`` rows of the embedding matrix are held in fp32 at a time, so
        ``chunk_size`` provides a trade-off between efficiency and extra memory consumption.

        Args:
            hidden_states: tensor whose last dimension is multiplied against the embedding matrix.

        Returns:
            An fp32 tensor of shape ``(*hidden_states.shape[:-1], num_embeddings)``.

        Raises:
            AssertionError: if ``self.chunk_size`` is not positive.
        """
        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"

        word_embeddings = self.word_embeddings.weight
        num_embeddings = self.word_embeddings.num_embeddings

        hidden_states = hidden_states.float()
        # Allocate the result explicitly in fp32 on the input's device, so a changed global
        # default dtype cannot silently downcast the fp32 matmul results.
        output = torch.empty(
            *hidden_states.shape[:-1],
            num_embeddings,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        for i in range(0, num_embeddings, self.chunk_size):
            # Slicing past the end is safe: the last chunk is simply shorter.
            chunk = word_embeddings[i : i + self.chunk_size].float()
            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
        return output