|
|
|
@ -32,6 +32,7 @@ class DistributedBloomConfig(BloomConfig):
|
|
|
|
|
|
|
|
|
|
initial_peers: List[str] = PUBLIC_INITIAL_PEERS # a list of initial peers for hivemind DHT
|
|
|
|
|
dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
|
|
|
|
|
daemon_startup_timeout: int = 30  # passed as startup_timeout to hivemind.DHT when the model creates its own DHT (presumably seconds -- confirm)
|
|
|
|
|
dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models
|
|
|
|
|
chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
|
|
|
|
|
pre_seq_len: int = 0 # a number of tokens for prompt tuning.
|
|
|
|
@ -95,7 +96,13 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
|
|
|
|
|
dht = (
|
|
|
|
|
config.dht
|
|
|
|
|
if config.dht is not None
|
|
|
|
|
else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, num_workers=n_layer, start=True)
|
|
|
|
|
else hivemind.DHT(
|
|
|
|
|
initial_peers=config.initial_peers,
|
|
|
|
|
client_mode=True,
|
|
|
|
|
num_workers=n_layer,
|
|
|
|
|
startup_timeout=config.daemon_startup_timeout,
|
|
|
|
|
start=True,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
|
|
|
|
|
self.h = RemoteSequential(config, dht, config.dht_prefix, request_timeout=config.request_timeout)
|
|
|
|
|