You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
petals/src/petals/bloom/modeling_utils.py

73 lines
2.7 KiB
Python

"""
PyTorch BLOOM model that implements several memory-efficient modes.
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from hivemind import get_logger
from torch import nn
from transformers import BloomConfig
# Module-level logger, created through hivemind's `get_logger` helper and keyed by this file's path.
logger = get_logger(__file__)
class LMHead(nn.Module):
    """
    Language modeling head that reuses the input embedding matrix as its projection weight.

    No separate tensor is created for the linear layer: the weight is read from the tied
    ``word_embeddings`` module on every call, which reduces initial memory consumption and
    may be crucial for large dictionaries. In addition, it provides an efficient way to deal
    with half-precision word embeddings on CPU (see ``chunked_forward``).

    The ``in_features`` / ``out_features`` / ``weight`` / ``bias`` properties mirror the
    attributes of ``nn.Linear``.
    """

    def __init__(self, config: "BloomConfig", word_embeddings: nn.Embedding):
        """
        :param config: model config; only ``chunk_size_for_efficient_fp16_on_cpu`` is read here
        :param word_embeddings: the embedding module whose weight is shared with this head
        """
        super().__init__()
        self.word_embeddings = word_embeddings
        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu

    @property
    def in_features(self) -> int:
        """Size of the last dimension of the hidden states consumed by ``forward``."""
        # The weight has shape (num_embeddings, embedding_dim) and is applied via F.linear,
        # so the input width is the embedding dimension.
        return self.word_embeddings.embedding_dim

    @property
    def out_features(self) -> int:
        """Size of the last dimension of the logits produced by ``forward``."""
        return self.word_embeddings.num_embeddings

    @property
    def weight(self):
        """The tied embedding matrix, used as the projection weight."""
        return self.word_embeddings.weight

    @property
    def bias(self):
        """Always ``None``: this head has no bias term."""
        return None

    def forward(self, hidden_states):
        """
        Project hidden states onto the vocabulary using the tied embedding matrix.

        :param hidden_states: tensor whose last dimension equals the embedding dimension
        :returns: logits with the last dimension replaced by the number of embeddings;
            float32 when the chunked CPU path is taken, otherwise the embeddings' dtype
        """
        word_embeddings = self.word_embeddings.weight

        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
            lm_logits = self.chunked_forward(hidden_states)
        else:
            # Switch dtype in case word_embeddings are fp16/bf16
            hidden_states = hidden_states.to(word_embeddings.dtype)
            lm_logits = F.linear(hidden_states, word_embeddings)
        return lm_logits

    def chunked_forward(self, hidden_states):
        """
        Split word embeddings into chunks and iteratively cast them to fp32 for the matmul.

        Only ``chunk_size`` rows of the embedding matrix are held in fp32 at any time, so
        ``chunk_size`` provides a trade-off between efficiency and extra memory consumption.

        :param hidden_states: tensor whose last dimension equals the embedding dimension
        :returns: float32 logits with the last dimension equal to the number of embeddings
        :raises AssertionError: if ``chunk_size`` is not positive
        """
        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"

        word_embeddings = self.word_embeddings.weight
        num_embeddings = self.word_embeddings.num_embeddings

        hidden_states = hidden_states.float()
        # Allocate from hidden_states so that the buffer is float32 and lives on the same
        # device as the values written into it, regardless of the global torch defaults.
        output = hidden_states.new_empty((*hidden_states.shape[:-1], num_embeddings))

        for i in range(0, num_embeddings, self.chunk_size):
            chunk = word_embeddings[i : i + self.chunk_size].float()
            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
        return output