|
|
|
@ -27,12 +27,11 @@ class DistributedBloomConfig(BloomConfig):
|
|
|
|
|
|
|
|
|
|
class DistributedBloomModel(BloomModel):
    """BloomModel, but all transformer layers are hosted by the swarm."""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        # Validate the *local* `config` argument: `self.config` does not exist
        # until super().__init__() runs, so reading `self.config` here would
        # raise AttributeError instead of the intended assertion message.
        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
        assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"

        # Temporarily set n_layer to 0 so BloomModel.__init__ builds no local
        # transformer layers — they are hosted remotely by the swarm.
        n_layer, config.n_layer = config.n_layer, 0
        super().__init__(config)
        # NOTE(review): the remainder of __init__ (e.g. restoring config.n_layer
        # and wiring the remote layer sequence) lies outside this hunk — confirm
        # against the full file before relying on this block in isolation.
@ -50,9 +49,10 @@ class DistributedBloomModel(BloomModel):
|
|
|
|
|
|
|
|
|
|
class DistributedBloomForCausalLM(BloomForCausalLM):
    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm."""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        # Call BloomPreTrainedModel.__init__ directly on `self`, deliberately
        # bypassing BloomForCausalLM.__init__ so no local BloomModel is built.
        # (The broken variant `BloomPreTrainedModel().__init__(config)` created
        # a throwaway instance with no config and never initialized `self`.)
        BloomPreTrainedModel.__init__(self, config)
        # Swarm-hosted backbone replaces the usual local transformer.
        self.transformer = DistributedBloomModel(config)
        # Initialize weights and apply final processing
        self.post_init()
|
|
|
|
|