Fix llama's lm_head.weight.requires_grad (#330)

By default, llama's `lm_head.weight.requires_grad` was `True`, but we expect it to be `False`.
Alexander Borzunov committed 11 months ago (via GitHub)
parent 7a37513f77
commit 47a2b1ee65

@@ -26,7 +26,8 @@ class LMHead(nn.Module):
         super().__init__()
         if not config.tie_word_embeddings:
-            self.weight = nn.Parameter(torch.zeros((config.vocab_size, config.hidden_size), requires_grad=False))
+            self.weight = nn.Parameter(torch.zeros(config.vocab_size, config.hidden_size))
+            self.weight.requires_grad = False
         else:
             self.weight = None  # Will be set to get_input_embeddings().weight during loading the model
         self.bias = None
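
Why the old line was not enough: `nn.Parameter` has its own `requires_grad` argument that defaults to `True` and overrides the flag of the tensor it wraps, so passing `requires_grad=False` to `torch.zeros` had no effect. A minimal standalone sketch in plain PyTorch (the shapes below are made up for illustration):

```python
import torch
import torch.nn as nn

# Old pattern: the flag on the wrapped tensor is ignored, because
# nn.Parameter(data, requires_grad=True) defaults to True.
old_weight = nn.Parameter(torch.zeros((4, 8), requires_grad=False))
print(old_weight.requires_grad)  # True -- the bug

# New pattern: set the flag on the Parameter itself after construction.
new_weight = nn.Parameter(torch.zeros(4, 8))
new_weight.requires_grad = False
print(new_weight.requires_grad)  # False
```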

@@ -40,10 +40,6 @@ class PTuneMixin:
         elif config.tuning_mode:
             raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
 
-    def set_requires_grad(self, value):
-        for p in self.parameters():
-            p.requires_grad = value
-
     def get_prompt(self, batch_size):
         prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
         prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)

@@ -35,7 +35,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         self.h = RemoteSequential(config, dht=dht)
 
-        self.set_requires_grad(False)  # Forbid accumulate grads for embeddings and layernorm
+        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
         self.init_prompts(config)
 
     def forward(

@@ -33,7 +33,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         self.layers = RemoteSequential(config, dht=dht)
 
-        self.set_requires_grad(False)  # Forbid accumulate grads for embeddings and layernorm
+        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
         self.init_prompts(config)
 
     def forward(
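
In the last three hunks, the custom `PTuneMixin.set_requires_grad` helper is dropped because it duplicated PyTorch's built-in `nn.Module.requires_grad_`, which flips `requires_grad` on every parameter of a module in place; the models now call `self.requires_grad_(False)` directly. A small sketch of the equivalence, using a toy module rather than a Petals class:

```python
import torch.nn as nn

# Toy module standing in for the embeddings and layernorm that get frozen.
module = nn.Sequential(nn.Embedding(10, 4), nn.LayerNorm(4))

# What the removed helper did by hand:
# for p in module.parameters():
#     p.requires_grad = False

# The built-in equivalent used in the diff above:
module.requires_grad_(False)
assert all(not p.requires_grad for p in module.parameters())
```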
