|
|
|
@ -61,7 +61,7 @@ class Server:
|
|
|
|
|
compression=CompressionType.NONE,
|
|
|
|
|
stats_report_interval: Optional[int] = None,
|
|
|
|
|
custom_module_path=None,
|
|
|
|
|
update_period: float = 30,
|
|
|
|
|
update_period: float = 150,
|
|
|
|
|
expiration: Optional[float] = None,
|
|
|
|
|
request_timeout: float = 3 * 60,
|
|
|
|
|
session_timeout: float = 30 * 60,
|
|
|
|
@ -106,7 +106,14 @@ class Server:
|
|
|
|
|
self.request_timeout = request_timeout
|
|
|
|
|
self.session_timeout, self.step_timeout = session_timeout, step_timeout
|
|
|
|
|
|
|
|
|
|
self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
|
|
|
|
|
self.block_config = BloomConfig.from_pretrained(
|
|
|
|
|
converted_model_name_or_path,
|
|
|
|
|
use_auth_token=use_auth_token,
|
|
|
|
|
revision=revision,
|
|
|
|
|
)
|
|
|
|
|
self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
|
|
|
|
|
|
|
|
|
|
self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
|
|
|
|
|
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
|
|
|
|
|
if initial_peers == PUBLIC_INITIAL_PEERS:
|
|
|
|
|
logger.info("Connecting to the public Petals swarm")
|
|
|
|
@ -124,13 +131,6 @@ class Server:
|
|
|
|
|
logger.info("Model weights will be loaded in 8-bit format")
|
|
|
|
|
self.load_in_8bit = load_in_8bit
|
|
|
|
|
|
|
|
|
|
self.block_config = BloomConfig.from_pretrained(
|
|
|
|
|
converted_model_name_or_path,
|
|
|
|
|
use_auth_token=use_auth_token,
|
|
|
|
|
revision=revision,
|
|
|
|
|
)
|
|
|
|
|
self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
|
|
|
|
|
|
|
|
|
|
assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
|
|
|
|
|
if num_blocks is None and block_indices is None:
|
|
|
|
|
num_blocks = self._choose_num_blocks()
|
|
|
|
|