Share more info about a server in DHT (#355)

Alexander Borzunov 10 months ago committed by GitHub
parent 37fdcb3fe0
commit 2c8959e713

@@ -38,7 +38,7 @@ install_requires =
tokenizers>=0.13.3
transformers>=4.30.1,<5.0.0
speedtest-cli==2.1.3
pydantic>=1.8.1,<2.0 # 2.0 is incompatible with hivemind==1.1.8
pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind==1.1.8
hivemind==1.1.8
tensor_parallel==1.0.23
humanfriendly

@@ -11,7 +11,7 @@ from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs
__version__ = "1.2.0.dev1"
__version__ = "1.2.0.dev2"
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

@@ -146,8 +146,9 @@ def main():
help="Skip checking this server's reachability via health.petals.ml "
"when connecting to the public swarm. If you connect to a private swarm, "
"the check is skipped by default. Use this option only if you know what you are doing")
parser.add_argument("--adapters", nargs='+', default=None, help="List of pretrained LoRA adapters that can be used for inference or training.")
parser.add_argument("--adapters", nargs='+', default=(),
help="List of pre-loaded LoRA adapters that can be used for inference or training")
# fmt:on
args = vars(parser.parse_args())

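For illustration only (not part of the commit): a standalone argparse sketch of the new --adapters semantics, with made-up adapter names. nargs='+' collects one or more repo ids, and the empty-tuple default means "no adapters" without a separate None check downstream.

import argparse

# hypothetical stand-in for the run_server parser; only the --adapters flag is reproduced
parser = argparse.ArgumentParser()
parser.add_argument("--adapters", nargs="+", default=(),
                    help="List of pre-loaded LoRA adapters that can be used for inference or training")

print(parser.parse_args([]).adapters)                                          # ()
print(parser.parse_args(["--adapters", "org/lora-a", "org/lora-b"]).adapters)  # ['org/lora-a', 'org/lora-b']
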
@@ -1,10 +1,8 @@
from __future__ import annotations
import dataclasses
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple
import pydantic
from hivemind import PeerID
from hivemind.moe.expert_uid import ExpertUID
@@ -21,13 +19,32 @@ class ServerState(Enum):
ONLINE = 2
@dataclass
@pydantic.dataclasses.dataclass
class ServerInfo:
state: ServerState
throughput: float
throughput: pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
adapters: Sequence[str] = ()
version: Optional[str] = None
torch_dtype: Optional[str] = None
quant_type: Optional[str] = None
using_relay: Optional[bool] = None
cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None
def to_tuple(self) -> Tuple[int, float, dict]:
extra_info = dataclasses.asdict(self)
del extra_info["state"], extra_info["throughput"]
return (self.state.value, self.throughput, extra_info)
@classmethod
def from_tuple(cls, source: tuple):
state, throughput = source[:2]
extra_info = source[2] if len(source) > 2 else {}
# pydantic will validate existing fields and ignore extra ones
return cls(state=ServerState(state), throughput=throughput, **extra_info)
@dataclass
@dataclasses.dataclass
class RemoteModuleInfo:
"""A remote module that is served by one or more servers"""
@@ -35,7 +52,7 @@ class RemoteModuleInfo:
servers: Dict[PeerID, ServerInfo]
@dataclass
@dataclasses.dataclass
class RemoteSpanInfo:
"""A chain of remote blocks served by one specific remote peer"""

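A minimal round-trip sketch (not part of the diff) of the ServerInfo tuple format defined above, assuming this commit's petals.data_structures is importable; the field values are made up.

from petals.data_structures import ServerInfo, ServerState

info = ServerInfo(
    state=ServerState.ONLINE,
    throughput=123.4,
    adapters=("org/lora-a",),   # placeholder adapter repo
    version="1.2.0.dev2",
    torch_dtype="float16",
    quant_type="nf4",
)
packed = info.to_tuple()        # (2, 123.4, {"adapters": ..., "version": ..., ...}): a DHT-friendly tuple
restored = ServerInfo.from_tuple(packed)
assert restored.state is ServerState.ONLINE and restored.throughput == 123.4
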
@@ -11,7 +11,7 @@ from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
logger = get_logger(__name__)
@@ -19,10 +19,8 @@ logger = get_logger(__name__)
def declare_active_modules(
dht: DHT,
uids: Sequence[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
state: ServerState,
throughput: float,
adapters: Optional[Sequence[str]] = None,
wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
"""
@@ -42,14 +40,7 @@
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
return dht.run_coroutine(
partial(
_declare_active_modules,
uids=uids,
expiration_time=expiration_time,
state=state,
throughput=throughput,
adapters=list(adapters or []),
),
partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
return_future=not wait,
)
@@ -58,16 +49,14 @@ async def _declare_active_modules(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
state: ServerState,
throughput: float,
adapters: List[str],
) -> Dict[ModuleUID, bool]:
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
return await node.store_many(
keys=uids,
subkeys=[dht.peer_id.to_base58()] * len(uids),
values=[(state.value, throughput, dict(adapters=adapters))] * len(uids),
values=[server_info.to_tuple()] * len(uids),
expiration_time=expiration_time,
num_workers=num_workers,
)
@@ -115,29 +104,21 @@ async def _get_remote_module_infos(
metadata = found[uid]
if metadata is None or not isinstance(metadata.value, dict):
if metadata is not None:
logger.error(f"Incorrect metadata for {uid}: {metadata}")
logger.warning(f"Incorrect metadata for {uid}: {metadata}")
continue
servers = {}
for peer_id, server_info in metadata.value.items():
try:
peer_id = PeerID.from_base58(peer_id)
state, throughput = server_info.value[:2]
extra_info = server_info.value[2] if len(server_info.value) > 2 else {}
adapters = extra_info.get("adapters", [])
if bool(active_adapter) and active_adapter not in adapters:
server_info = ServerInfo.from_tuple(server_info.value)
if active_adapter and active_adapter not in server_info.adapters:
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue
if not (
isinstance(state, int)
and isinstance(throughput, float)
and math.isfinite(throughput)
and throughput >= 0.0
):
raise ValueError(f"Invalid server info: {server_info}")
servers[peer_id] = ServerInfo(ServerState(state), throughput)
servers[peer_id] = server_info
except (TypeError, ValueError) as e:
logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
if servers:
modules[i] = RemoteModuleInfo(uid, servers)
return modules

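A sketch (not part of the diff) of announcing blocks with the refactored declare_active_modules signature, assuming this commit's petals is installed; the DHT here is a throwaway local node and the module UIDs are placeholders.

from hivemind import DHT
from hivemind.utils import get_dht_time

from petals.data_structures import ServerInfo, ServerState
from petals.dht_utils import declare_active_modules

dht = DHT(start=True)  # isolated local DHT; a real server would pass initial_peers to join a swarm
server_info = ServerInfo(state=ServerState.ONLINE, throughput=100.0, adapters=("org/lora-a",))

declare_active_modules(
    dht,
    uids=["bigscience/bloom-petals.0", "bigscience/bloom-petals.1"],  # placeholder block UIDs
    server_info=server_info,
    expiration_time=get_dht_time() + 60,
)
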
@@ -9,8 +9,6 @@ from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.bloom.block import WrappedBloomBlock
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.version import get_compatible_model_repo
logger = get_logger(__name__)

@@ -9,7 +9,6 @@ from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.llama.block import WrappedLlamaBlock
from petals.utils.auto_config import AutoDistributedConfig
logger = get_logger(__name__)
@@ -31,8 +30,7 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
if loading_from_repo and dht_prefix is None:
dht_prefix = str(model_name_or_path)
if "/" in dht_prefix: # If present, strip repository name to merge blocks hosted by different accounts
dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :]
dht_prefix = dht_prefix.split("/")[-1] # Use only repo name to merge blocks hosted by different accounts
if not dht_prefix.endswith("-hf"):
dht_prefix += "-hf"
logger.info(f"Using DHT prefix: {dht_prefix}")

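Worked example (not part of the diff) of the dht_prefix derivation above, using a placeholder repo id:

model_name_or_path = "enoch/llama-65b-hf"        # placeholder HF repo id
dht_prefix = model_name_or_path.split("/")[-1]   # "llama-65b-hf": the account name is dropped
if not dht_prefix.endswith("-hf"):
    dht_prefix += "-hf"
print(dht_prefix)  # "llama-65b-hf", the same prefix no matter which account hosts the weights
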
@@ -562,11 +562,10 @@ class TransformerConnectionHandler(ConnectionHandler):
"""Return metadata about stored block uids and current load"""
backend = self.module_backends[request.uid] if request.uid else next(iter(self.module_backends.values()))
cache_bytes_left = max(0, backend.memory_cache.max_size_bytes - backend.memory_cache.current_size_bytes)
result = {
"version": petals.__version__,
"dht_client_mode": self.dht.client_mode,
CACHE_TOKENS_AVAILABLE: cache_bytes_left // max(backend.cache_bytes_per_token.values()),
CACHE_TOKENS_AVAILABLE: backend.memory_cache.bytes_left // max(backend.cache_bytes_per_token.values()),
}
if request.uid:

@@ -47,6 +47,10 @@ class MemoryCache:
def current_size_bytes(self, value: int):
self._current_size.value = value
@property
def bytes_left(self) -> int:
return self.max_size_bytes - self.current_size_bytes
@property
def handle_counter(self) -> int:
return self._handle_counter.value

@@ -16,8 +16,9 @@ from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger
from transformers import PretrainedConfig
import petals
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
from petals.dht_utils import declare_active_modules, get_remote_module_infos
from petals.server import block_selection
from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
@@ -29,7 +30,6 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
from petals.server.throughput import get_dtype_name, get_server_throughput
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, check_device_balance, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
from petals.utils.version import get_compatible_model_repo
logger = get_logger(__name__)
@@ -81,7 +81,7 @@ class Server:
dht_client_mode: Optional[bool] = None,
use_relay: bool = True,
use_auto_relay: bool = True,
adapters: Optional[List[str]] = None,
adapters: Sequence[str] = (),
**kwargs,
):
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -215,7 +215,15 @@ class Server:
force_eval=(throughput == "eval"),
cache_dir=cache_dir,
)
self.throughput = throughput
self.server_info = ServerInfo(
state=ServerState.JOINING,
throughput=throughput,
adapters=tuple(adapters),
version=petals.__version__,
torch_dtype=str(torch_dtype).lstrip("torch."),
quant_type=quant_type.name.lower(),
using_relay=self.dht.client_mode,
)
self.balance_quality = balance_quality
self.mean_balance_check_period = mean_balance_check_period
@@ -283,7 +291,7 @@ class Server:
block_config=self.block_config,
attn_cache_bytes=self.attn_cache_bytes,
alloc_timeout=self.alloc_timeout,
throughput=self.throughput,
server_info=self.server_info,
block_indices=block_indices,
num_handlers=self.num_handlers,
min_batch_size=self.min_batch_size,
@@ -307,7 +315,6 @@
quant_type=self.quant_type,
tensor_parallel_devices=self.tensor_parallel_devices,
should_validate_reachability=self.should_validate_reachability,
adapters=self.adapters,
start=True,
)
try:
@@ -385,7 +392,7 @@ class ModuleContainer(threading.Thread):
block_config: PretrainedConfig,
attn_cache_bytes: int,
alloc_timeout: float,
throughput: float,
server_info: ServerInfo,
block_indices: List[int],
min_batch_size: int,
max_batch_size: int,
@@ -401,16 +408,18 @@
quant_type: QuantType,
tensor_parallel_devices: Sequence[torch.device],
should_validate_reachability: bool,
adapters: Optional[List[str]] = None,
**kwargs,
) -> ModuleContainer:
module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
server_info.state = ServerState.JOINING
joining_announcer = ModuleAnnouncerThread(
module_uids,
dht,
ServerState.JOINING,
adapters=adapters,
throughput=throughput,
server_info,
block_config=block_config,
memory_cache=memory_cache,
update_period=update_period,
expiration=expiration,
daemon=True,
@@ -420,7 +429,6 @@
assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
blocks = {}
try:
for module_uid, block_index in zip(module_uids, block_indices):
@@ -441,7 +449,7 @@
tensor_parallel_devices,
device,
quant_type,
adapters=adapters,
adapters=server_info.adapters,
freeze=True,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
@@ -477,13 +485,12 @@
joining_announcer.stop.set()
joining_announcer.join()
server_info.state = ServerState.OFFLINE
declare_active_modules(
dht,
module_uids,
server_info,
expiration_time=get_dht_time() + expiration,
state=ServerState.OFFLINE,
throughput=throughput,
adapters=adapters,
)
logger.info(f"Announced that blocks {module_uids} are offline")
raise
@@ -497,8 +504,9 @@
dht,
dht_prefix,
blocks,
adapters=adapters,
throughput=throughput,
block_config=block_config,
memory_cache=memory_cache,
server_info=server_info,
update_period=update_period,
expiration=expiration,
**kwargs,
@@ -510,10 +518,11 @@
dht_prefix: str,
module_backends: Dict[str, TransformerBackend],
*,
block_config: PretrainedConfig,
memory_cache: MemoryCache,
inference_max_length: int,
num_handlers: int,
throughput: float,
adapters: Optional[Sequence[str]],
server_info: ServerInfo,
update_period: float,
expiration: Optional[float] = None,
request_timeout: float,
@@ -525,7 +534,7 @@
super().__init__()
self.dht, self.module_backends = dht, module_backends
self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
self.server_info, self.update_period, self.expiration = server_info, update_period, expiration
self.push_manager = mp.Manager()
self.push_manager.__enter__()
@@ -534,7 +543,7 @@
TransformerConnectionHandler(
dht,
self.module_backends,
adapters=adapters,
adapters=server_info.adapters,
dht_prefix=dht_prefix,
push_manager=self.push_manager,
session_queues=session_queues,
@@ -548,12 +557,14 @@
self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
# note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
self.server_info.state = ServerState.ONLINE
self.online_announcer = ModuleAnnouncerThread(
list(self.module_backends.keys()),
dht,
ServerState.ONLINE,
adapters=adapters,
throughput=throughput,
self.server_info,
block_config=block_config,
memory_cache=memory_cache,
update_period=update_period,
expiration=expiration,
daemon=True,
@@ -613,12 +624,12 @@
self.online_announcer.stop.set()
self.online_announcer.join()
self.server_info.state = ServerState.OFFLINE
declare_active_modules(
self.dht,
self.module_backends.keys(),
self.server_info,
expiration_time=get_dht_time() + self.expiration,
state=ServerState.OFFLINE,
throughput=self.throughput,
)
logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
@@ -651,10 +662,10 @@
self,
module_uids: List[str],
dht: DHT,
state: ServerState,
adapters: Optional[Sequence[str]],
server_info: ServerInfo,
*,
throughput: float,
block_config: PretrainedConfig,
memory_cache: MemoryCache,
update_period: float = 30,
expiration: float,
**kwargs,
@@ -662,22 +673,21 @@
super().__init__(**kwargs)
self.module_uids = module_uids
self.dht = dht
self.state = state
self.adapters = adapters
self.throughput = throughput
self.server_info = server_info
self.memory_cache = memory_cache
self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
self.update_period = update_period
self.expiration = expiration
self.stop = threading.Event()
def run(self) -> None:
while True:
self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token
declare_active_modules(
self.dht,
self.module_uids,
self.server_info,
expiration_time=get_dht_time() + self.expiration,
state=self.state,
throughput=self.throughput,
adapters=self.adapters,
)
if self.stop.wait(self.update_period):
break

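Worked example (not part of the diff) of the per-token estimate that feeds cache_tokens_left above; hidden_size and the free-cache figure are placeholders.

import torch

hidden_size = 8192                                                     # e.g. a LLaMA-65B-sized block
bytes_per_token = hidden_size * torch.finfo(torch.float16).bits // 8   # 8192 * 16 // 8 = 16384 bytes
bytes_left = 2 * 1024**3                                               # pretend 2 GiB of attention cache is free
print(bytes_left // bytes_per_token)                                   # 131072 tokens announced as cache_tokens_left
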
@@ -2,7 +2,7 @@
Tools for converting transformer blocks, applying quantization and/or tensor parallelism
"""
import re
from typing import List, Optional, Sequence
from typing import Optional, Sequence
import tensor_parallel as tp
import torch
@@ -25,7 +25,7 @@ def convert_block(
output_device: torch.device,
quant_type: QuantType,
freeze: bool = True,
adapters: Optional[List[str]] = None,
adapters: Optional[Sequence[str]] = None,
**kwargs,
) -> tp.TensorParallel:
"""
