standardize
justheuristic 2 years ago
parent 2e90ac30a0
commit 1c68670d06

@ -27,12 +27,11 @@ class DistributedBloomConfig(BloomConfig):
class DistributedBloomModel(BloomModel):
"""BloomModel, but all transformer layers are hosted by the swarm"""
config_class = DistributedBloomConfig
def __init__(self, config: DistributedBloomConfig):
assert self.config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
assert (
self.config.initial_peers or config.dht
), "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
n_layer, config.n_layer = config.n_layer, 0 # temporarily set n_layer to 0 to prevent layer initialization
super().__init__(config)
@ -50,9 +49,10 @@ class DistributedBloomModel(BloomModel):
class DistributedBloomForCausalLM(BloomForCausalLM):
"""DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
config_class = DistributedBloomConfig
def __init__(self, config: DistributedBloomConfig):
BloomPreTrainedModel().__init__(config)
BloomPreTrainedModel.__init__(self, config)
self.transformer = DistributedBloomModel(config)
# Initialize weights and apply final processing
self.post_init()

Loading…
Cancel
Save