@@ -16,8 +16,9 @@ from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
+import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
@@ -29,7 +30,6 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
@@ -81,7 +81,7 @@ class Server:
         dht_client_mode: Optional[bool] = None,
         use_relay: bool = True,
         use_auto_relay: bool = True,
-        adapters: Optional[List[str]] = None,
+        adapters: Sequence[str] = (),
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -215,7 +215,15 @@ class Server:
             force_eval=(throughput == "eval"),
             cache_dir=cache_dir,
         )
-        self.throughput = throughput
+        self.server_info = ServerInfo(
+            state=ServerState.JOINING,
+            throughput=throughput,
+            adapters=tuple(adapters),
+            version=petals.__version__,
+            torch_dtype=str(torch_dtype).lstrip("torch."),
+            quant_type=quant_type.name.lower(),
+            using_relay=self.dht.client_mode,
+        )
 
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
@@ -283,7 +291,7 @@ class Server:
             block_config=self.block_config,
             attn_cache_bytes=self.attn_cache_bytes,
             alloc_timeout=self.alloc_timeout,
-            throughput=self.throughput,
+            server_info=self.server_info,
             block_indices=block_indices,
             num_handlers=self.num_handlers,
             min_batch_size=self.min_batch_size,
@@ -307,7 +315,6 @@ class Server:
             quant_type=self.quant_type,
             tensor_parallel_devices=self.tensor_parallel_devices,
             should_validate_reachability=self.should_validate_reachability,
-            adapters=self.adapters,
             start=True,
         )
         try:
@@ -385,7 +392,7 @@ class ModuleContainer(threading.Thread):
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
         alloc_timeout: float,
-        throughput: float,
+        server_info: ServerInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
@@ -401,16 +408,18 @@ class ModuleContainer(threading.Thread):
         quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,
-        adapters: Optional[List[str]] = None,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
+        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
+
+        server_info.state = ServerState.JOINING
         joining_announcer = ModuleAnnouncerThread(
             module_uids,
             dht,
-            ServerState.JOINING,
-            adapters=adapters,
-            throughput=throughput,
+            server_info,
+            block_config=block_config,
+            memory_cache=memory_cache,
             update_period=update_period,
             expiration=expiration,
             daemon=True,
@@ -420,7 +429,6 @@ class ModuleContainer(threading.Thread):
         assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
 
-        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
         blocks = {}
         try:
             for module_uid, block_index in zip(module_uids, block_indices):
@@ -441,7 +449,7 @@ class ModuleContainer(threading.Thread):
                     tensor_parallel_devices,
                     device,
                     quant_type,
-                    adapters=adapters,
+                    adapters=server_info.adapters,
                     freeze=True,
                     use_auth_token=use_auth_token,
                     cache_dir=cache_dir,
@@ -477,13 +485,12 @@ class ModuleContainer(threading.Thread):
 
             joining_announcer.stop.set()
             joining_announcer.join()
+            server_info.state = ServerState.OFFLINE
             declare_active_modules(
                 dht,
                 module_uids,
+                server_info,
                 expiration_time=get_dht_time() + expiration,
-                state=ServerState.OFFLINE,
-                throughput=throughput,
-                adapters=adapters,
             )
             logger.info(f"Announced that blocks {module_uids} are offline")
             raise
@@ -497,8 +504,9 @@ class ModuleContainer(threading.Thread):
             dht,
             dht_prefix,
             blocks,
-            adapters=adapters,
-            throughput=throughput,
+            block_config=block_config,
+            memory_cache=memory_cache,
+            server_info=server_info,
             update_period=update_period,
             expiration=expiration,
             **kwargs,
@@ -510,10 +518,11 @@ class ModuleContainer(threading.Thread):
         dht_prefix: str,
         module_backends: Dict[str, TransformerBackend],
         *,
+        block_config: PretrainedConfig,
+        memory_cache: MemoryCache,
         inference_max_length: int,
         num_handlers: int,
-        throughput: float,
-        adapters: Optional[Sequence[str]],
+        server_info: ServerInfo,
         update_period: float,
         expiration: Optional[float] = None,
         request_timeout: float,
@@ -525,7 +534,7 @@ class ModuleContainer(threading.Thread):
         super().__init__()
 
         self.dht, self.module_backends = dht, module_backends
-        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
+        self.server_info, self.update_period, self.expiration = server_info, update_period, expiration
 
         self.push_manager = mp.Manager()
         self.push_manager.__enter__()
@@ -534,7 +543,7 @@ class ModuleContainer(threading.Thread):
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
-                adapters=adapters,
+                adapters=server_info.adapters,
                 dht_prefix=dht_prefix,
                 push_manager=self.push_manager,
                 session_queues=session_queues,
@@ -548,12 +557,14 @@ class ModuleContainer(threading.Thread):
 
         self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
         # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
+
+        self.server_info.state = ServerState.ONLINE
         self.online_announcer = ModuleAnnouncerThread(
             list(self.module_backends.keys()),
             dht,
-            ServerState.ONLINE,
-            adapters=adapters,
-            throughput=throughput,
+            self.server_info,
+            block_config=block_config,
+            memory_cache=memory_cache,
             update_period=update_period,
             expiration=expiration,
             daemon=True,
@@ -613,12 +624,12 @@ class ModuleContainer(threading.Thread):
         self.online_announcer.stop.set()
         self.online_announcer.join()
 
+        self.server_info.state = ServerState.OFFLINE
         declare_active_modules(
             self.dht,
             self.module_backends.keys(),
+            self.server_info,
             expiration_time=get_dht_time() + self.expiration,
-            state=ServerState.OFFLINE,
-            throughput=self.throughput,
         )
         logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
 
@@ -651,10 +662,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self,
         module_uids: List[str],
         dht: DHT,
-        state: ServerState,
-        adapters: Optional[Sequence[str]],
+        server_info: ServerInfo,
         *,
-        throughput: float,
+        block_config: PretrainedConfig,
+        memory_cache: MemoryCache,
         update_period: float = 30,
         expiration: float,
         **kwargs,
@@ -662,22 +673,21 @@ class ModuleAnnouncerThread(threading.Thread):
         super().__init__(**kwargs)
         self.module_uids = module_uids
         self.dht = dht
-        self.state = state
-        self.adapters = adapters
-        self.throughput = throughput
+        self.server_info = server_info
+        self.memory_cache = memory_cache
+        self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
         self.update_period = update_period
         self.expiration = expiration
         self.stop = threading.Event()
 
     def run(self) -> None:
         while True:
+            self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token
             declare_active_modules(
                 self.dht,
                 self.module_uids,
+                self.server_info,
                 expiration_time=get_dht_time() + self.expiration,
-                state=self.state,
-                throughput=self.throughput,
-                adapters=self.adapters,
             )
             if self.stop.wait(self.update_period):
                 break
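
For context, a minimal, self-contained sketch of the announcement pattern this patch introduces. It is not part of the diff above and not the Petals API: `AnnouncerThread` and `publish` are placeholder names standing in for `ModuleAnnouncerThread` and `declare_active_modules`, and the `ServerInfo` below is a simplified stand-in for `petals.data_structures.ServerInfo`. The point it illustrates: one mutable `ServerInfo` object is shared between the server and its announcer thread, so state flips (JOINING -> ONLINE -> OFFLINE) and refreshed fields like `cache_tokens_left` are picked up on the next periodic re-announce without restarting the thread.

import dataclasses
import threading
import time
from typing import Callable, Optional, Tuple


@dataclasses.dataclass
class ServerInfo:  # simplified stand-in for petals.data_structures.ServerInfo
    state: str
    throughput: float
    adapters: Tuple[str, ...] = ()
    cache_tokens_left: Optional[int] = None


class AnnouncerThread(threading.Thread):  # hypothetical; mimics ModuleAnnouncerThread
    def __init__(self, server_info: ServerInfo, publish: Callable[[ServerInfo], None], update_period: float):
        super().__init__(daemon=True)
        self.server_info, self.publish, self.update_period = server_info, publish, update_period
        self.stop = threading.Event()

    def run(self) -> None:
        while True:
            # Re-publish the *current* contents of the shared ServerInfo;
            # in Petals this is declare_active_modules(dht, uids, server_info, ...).
            self.publish(self.server_info)
            if self.stop.wait(self.update_period):  # sleep, but wake early if stopped
                break


info = ServerInfo(state="JOINING", throughput=100.0)
announcer = AnnouncerThread(info, publish=print, update_period=0.1)
announcer.start()
info.state = "ONLINE"  # the mutation is visible to the announcer on its next cycle
time.sleep(0.25)
announcer.stop.set()
announcer.join()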