From 47a2b1ee65a6fe2cabc3e262c116c04149fe231b Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 24 Jun 2023 02:30:13 +0400 Subject: [PATCH] Fix llama's lm_head.weight.requires_grad (#330) By default, `llama's lm_head.weight.requires_grad` was True, but we expect it to be False. --- src/petals/client/lm_head.py | 3 ++- src/petals/client/ptune.py | 4 ---- src/petals/models/bloom/model.py | 2 +- src/petals/models/llama/model.py | 2 +- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/petals/client/lm_head.py b/src/petals/client/lm_head.py index ddd2887..938d6da 100644 --- a/src/petals/client/lm_head.py +++ b/src/petals/client/lm_head.py @@ -26,7 +26,8 @@ class LMHead(nn.Module): super().__init__() if not config.tie_word_embeddings: - self.weight = nn.Parameter(torch.zeros((config.vocab_size, config.hidden_size), requires_grad=False)) + self.weight = nn.Parameter(torch.zeros(config.vocab_size, config.hidden_size)) + self.weight.requires_grad = False else: self.weight = None # Will be set to get_input_embeddings().weight during loading the model self.bias = None diff --git a/src/petals/client/ptune.py b/src/petals/client/ptune.py index 5cf613c..684cc23 100644 --- a/src/petals/client/ptune.py +++ b/src/petals/client/ptune.py @@ -40,10 +40,6 @@ class PTuneMixin: elif config.tuning_mode: raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now") - def set_requires_grad(self, value): - for p in self.parameters(): - p.requires_grad = value - def get_prompt(self, batch_size): prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1) prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device) diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py index fae9faf..e4961d3 100644 --- a/src/petals/models/bloom/model.py +++ b/src/petals/models/bloom/model.py @@ -35,7 +35,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel): self.h = RemoteSequential(config, dht=dht) - self.set_requires_grad(False) # Forbid accumulate grads for embeddings and layernorm + self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm self.init_prompts(config) def forward( diff --git a/src/petals/models/llama/model.py b/src/petals/models/llama/model.py index 37b4683..244207b 100644 --- a/src/petals/models/llama/model.py +++ b/src/petals/models/llama/model.py @@ -33,7 +33,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel): self.layers = RemoteSequential(config, dht=dht) - self.set_requires_grad(False) # Forbid accumulate grads for embeddings and layernorm + self.requires_grad_(False) # Forbid accumulate grads for embeddings and layernorm self.init_prompts(config) def forward(