mv set_requires_grad to remote_model

pull/19/head
dbaranchuk 2 years ago
parent 5168a3405a
commit 21e1f42f04

@@ -165,19 +165,12 @@ class BloomModel(BloomPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
# Forbid accumulate grads for embeddings and layernorm
self.set_requires_grad(False)
def get_input_embeddings(self):
return self.word_embeddings
def set_input_embeddings(self, new_embeddings):
self.word_embeddings = new_embeddings
def set_requires_grad(self, value):
for p in self.parameters():
p.requires_grad = value
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,

@@ -45,6 +45,13 @@ class DistributedBloomModel(BloomModel):
)
assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
self.h = RemoteSequential(config, dht, config.dht_prefix)
# Forbid accumulate grads for embeddings and layernorm
self.set_requires_grad(False)
def set_requires_grad(self, value):
for p in self.parameters():
p.requires_grad = value
class DistributedBloomForCausalLM(BloomForCausalLM):

Loading…
Cancel
Save