Set dht.num_workers = n_layer, update_period = 150, expiration = 300 (#125)

pull/128/head
Alexander Borzunov 1 year ago committed by GitHub
parent 3ca8b4f082
commit 9dbf5e2e6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -73,7 +73,7 @@ def main():
'If set to "auto" (default), the script evaluates network and compute throughput '
'on the first run and uses these estimates for future runs. '
'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
parser.add_argument('--update_period', type=float, required=False, default=30,
parser.add_argument('--update_period', type=float, required=False, default=150,
help='Server will report blocks to DHT once in this many seconds')
parser.add_argument('--expiration', type=float, required=False, default=None,
help='DHT entries will expire after this many seconds')

@ -82,7 +82,7 @@ class DistributedBloomModel(BloomModel):
dht = (
config.dht
if config.dht is not None
else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, num_workers=n_layer, start=True)
)
assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
self.h = RemoteSequential(config, dht, config.dht_prefix, request_timeout=config.request_timeout)

@ -61,7 +61,7 @@ class Server:
compression=CompressionType.NONE,
stats_report_interval: Optional[int] = None,
custom_module_path=None,
update_period: float = 30,
update_period: float = 150,
expiration: Optional[float] = None,
request_timeout: float = 3 * 60,
session_timeout: float = 30 * 60,
@ -106,7 +106,14 @@ class Server:
self.request_timeout = request_timeout
self.session_timeout, self.step_timeout = session_timeout, step_timeout
self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
self.block_config = BloomConfig.from_pretrained(
converted_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
if initial_peers == PUBLIC_INITIAL_PEERS:
logger.info("Connecting to the public Petals swarm")
@ -124,13 +131,6 @@ class Server:
logger.info("Model weights will be loaded in 8-bit format")
self.load_in_8bit = load_in_8bit
self.block_config = BloomConfig.from_pretrained(
converted_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
if num_blocks is None and block_indices is None:
num_blocks = self._choose_num_blocks()

Loading…
Cancel
Save