@@ -36,6 +36,7 @@ class DistributedBloomConfig(BloomConfig):
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # chunk size for the LM head, for efficient half-precision on CPU
     pre_seq_len: int = 0  # number of prefix tokens for prompt tuning
     tuning_mode: Optional[str] = None  # one of the finetuning options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
+    request_timeout: int = 20  # number of seconds to wait for a result from each node
 
 
 original_register_parameter = nn.Module.register_parameter
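For reference, a minimal usage sketch (not part of this diff): since DistributedBloomConfig inherits from a Hugging Face config, the new request_timeout field should be overridable as a from_pretrained keyword argument. The import path and model name below are assumptions and may differ between Petals versions.

# Sketch only: override the new `request_timeout` field at load time.
# `DistributedBloomForCausalLM` and the model name are assumed here.
from petals import DistributedBloomForCausalLM

model = DistributedBloomForCausalLM.from_pretrained(
    "bigscience/bloom-petals",
    request_timeout=60,  # wait up to 60 s per node instead of the default 20 s
)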
@@ -84,7 +85,7 @@ class DistributedBloomModel(BloomModel):
             else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
         )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
-        self.h = RemoteSequential(config, dht, config.dht_prefix)
+        self.h = RemoteSequential(config, dht, config.dht_prefix, request_timeout=config.request_timeout)
 
         # Forbid accumulating grads for embeddings and layernorm
         self.set_requires_grad(False)
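And a hedged sketch of how the new field flows through the constructor: DistributedBloomModel now forwards config.request_timeout into RemoteSequential, so lowering it makes hung remote calls fail faster. The import path and peer address below are placeholders, not values from this diff.

# Sketch only: the timeout now propagates into RemoteSequential.
from petals.client import DistributedBloomConfig, DistributedBloomModel  # assumed import path

config = DistributedBloomConfig.from_pretrained("bigscience/bloom-petals")
config.initial_peers = ["/ip4/127.0.0.1/tcp/31337/p2p/..."]  # placeholder multiaddr
config.request_timeout = 5  # each remote call now fails after 5 s instead of the default 20 s

model = DistributedBloomModel(config)
# internally: self.h = RemoteSequential(config, dht, config.dht_prefix,
#                                       request_timeout=config.request_timeout)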