|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import gc
|
|
|
|
|
import itertools
|
|
|
|
|
import math
|
|
|
|
|
import multiprocessing as mp
|
|
|
|
|
import random
|
|
|
|
@ -72,8 +73,8 @@ class Server:
|
|
|
|
|
prefetch_batches: int = 1,
|
|
|
|
|
sender_threads: int = 1,
|
|
|
|
|
balance_quality: float = 0.75,
|
|
|
|
|
mean_balance_check_period: float = 60,
|
|
|
|
|
mean_block_selection_delay: float = 0.5,
|
|
|
|
|
mean_balance_check_period: float = 120,
|
|
|
|
|
mean_block_selection_delay: float = 2.5,
|
|
|
|
|
use_auth_token: Optional[str] = None,
|
|
|
|
|
load_in_8bit: Optional[bool] = None,
|
|
|
|
|
**kwargs,
|
|
|
|
@ -119,7 +120,6 @@ class Server:
|
|
|
|
|
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
|
|
|
|
|
if initial_peers == PUBLIC_INITIAL_PEERS:
|
|
|
|
|
logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
|
|
|
|
|
logger.info("Please check that your server is reachable at http://health.petals.ml")
|
|
|
|
|
else:
|
|
|
|
|
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
|
|
|
|
|
|
|
|
|
@ -157,8 +157,8 @@ class Server:
|
|
|
|
|
if attn_cache_size is None:
|
|
|
|
|
# Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
|
|
|
|
|
attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
|
|
|
|
|
self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout
|
|
|
|
|
logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
|
|
|
|
|
self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
|
|
|
|
|
|
|
|
|
|
if cache_dir is None:
|
|
|
|
|
cache_dir = DEFAULT_CACHE_DIR
|
|
|
|
@ -211,7 +211,8 @@ class Server:
|
|
|
|
|
prefix=self.prefix,
|
|
|
|
|
converted_model_name_or_path=self.converted_model_name_or_path,
|
|
|
|
|
block_config=self.block_config,
|
|
|
|
|
memory_cache=self.memory_cache,
|
|
|
|
|
attn_cache_size=self.attn_cache_size,
|
|
|
|
|
alloc_timeout=self.alloc_timeout,
|
|
|
|
|
throughput=self.throughput,
|
|
|
|
|
block_indices=block_indices,
|
|
|
|
|
num_handlers=self.num_handlers,
|
|
|
|
@ -310,7 +311,8 @@ class ModuleContainer(threading.Thread):
|
|
|
|
|
prefix: str,
|
|
|
|
|
converted_model_name_or_path: str,
|
|
|
|
|
block_config: BloomConfig,
|
|
|
|
|
memory_cache: MemoryCache,
|
|
|
|
|
attn_cache_size: int,
|
|
|
|
|
alloc_timeout: float,
|
|
|
|
|
throughput: float,
|
|
|
|
|
block_indices: List[int],
|
|
|
|
|
min_batch_size: int,
|
|
|
|
@ -339,8 +341,9 @@ class ModuleContainer(threading.Thread):
|
|
|
|
|
joining_announcer.start()
|
|
|
|
|
logger.info(f"Announced that blocks {block_indices} are joining")
|
|
|
|
|
|
|
|
|
|
memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
|
|
|
|
|
blocks = {}
|
|
|
|
|
try:
|
|
|
|
|
blocks = {}
|
|
|
|
|
for module_uid, block_index in zip(module_uids, block_indices):
|
|
|
|
|
block = load_pretrained_block(
|
|
|
|
|
converted_model_name_or_path,
|
|
|
|
@ -380,6 +383,10 @@ class ModuleContainer(threading.Thread):
|
|
|
|
|
max_batch_size=max_batch_size,
|
|
|
|
|
)
|
|
|
|
|
except:
|
|
|
|
|
logger.debug("Shutting down backends")
|
|
|
|
|
for backend in blocks.values():
|
|
|
|
|
backend.shutdown()
|
|
|
|
|
|
|
|
|
|
joining_announcer.stop.set()
|
|
|
|
|
joining_announcer.join()
|
|
|
|
|
declare_active_modules(
|
|
|
|
@ -563,7 +570,7 @@ class ModuleAnnouncerThread(threading.Thread):
|
|
|
|
|
self.stop = threading.Event()
|
|
|
|
|
|
|
|
|
|
def run(self) -> None:
|
|
|
|
|
while True:
|
|
|
|
|
for iter_no in itertools.count():
|
|
|
|
|
declare_active_modules(
|
|
|
|
|
self.dht,
|
|
|
|
|
self.module_uids,
|
|
|
|
@ -571,5 +578,10 @@ class ModuleAnnouncerThread(threading.Thread):
|
|
|
|
|
state=self.state,
|
|
|
|
|
throughput=self.throughput,
|
|
|
|
|
)
|
|
|
|
|
if iter_no == 0 and self.state == ServerState.JOINING:
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Please ensure that your server is reachable. "
|
|
|
|
|
f"For public swarm, open http://health.petals.ml and find peer_id = {self.dht.peer_id}"
|
|
|
|
|
)
|
|
|
|
|
if self.stop.wait(self.update_period):
|
|
|
|
|
break
|
|
|
|
|