diff --git a/setup.cfg b/setup.cfg
index c6ae594..0053628 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -38,7 +38,7 @@ install_requires =
     tokenizers>=0.13.3
    transformers>=4.30.1,<5.0.0
    speedtest-cli==2.1.3
-    pydantic>=1.8.1,<2.0  # 2.0 is incompatible with hivemind==1.1.8
+    pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind==1.1.8
    hivemind==1.1.8
    tensor_parallel==1.0.23
    humanfriendly
diff --git a/src/petals/__init__.py b/src/petals/__init__.py
index 5658167..3e67633 100644
--- a/src/petals/__init__.py
+++ b/src/petals/__init__.py
@@ -11,7 +11,7 @@
 from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs

-__version__ = "1.2.0.dev1"
+__version__ = "1.2.0.dev2"


 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py
index 6b3fde8..b2480f5 100644
--- a/src/petals/cli/run_server.py
+++ b/src/petals/cli/run_server.py
@@ -146,8 +146,9 @@ def main():
                         help="Skip checking this server's reachability via health.petals.ml "
                              "when connecting to the public swarm. If you connect to a private swarm, "
                              "the check is skipped by default. Use this option only if you know what you are doing")
-
-    parser.add_argument("--adapters", nargs='+', default=None, help="List of pretrained LoRA adapters that can be used for inference or training.")
+
+    parser.add_argument("--adapters", nargs='+', default=(),
+                        help="List of pre-loaded LoRA adapters that can be used for inference or training")

     # fmt:on
     args = vars(parser.parse_args())
diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py
index 254faae..9e13ebe 100644
--- a/src/petals/data_structures.py
+++ b/src/petals/data_structures.py
@@ -1,10 +1,8 @@
-from __future__ import annotations
-
 import dataclasses
-from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple

+import pydantic
 from hivemind import PeerID
 from hivemind.moe.expert_uid import ExpertUID
@@ -21,13 +19,32 @@ class ServerState(Enum):
     ONLINE = 2


-@dataclass
+@pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState
-    throughput: float
+    throughput: pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+
+    adapters: Sequence[str] = ()
+    version: Optional[str] = None
+    torch_dtype: Optional[str] = None
+    quant_type: Optional[str] = None
+    using_relay: Optional[bool] = None
+    cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None
+
+    def to_tuple(self) -> Tuple[int, float, dict]:
+        extra_info = dataclasses.asdict(self)
+        del extra_info["state"], extra_info["throughput"]
+        return (self.state.value, self.throughput, extra_info)
+
+    @classmethod
+    def from_tuple(cls, source: tuple):
+        state, throughput = source[:2]
+        extra_info = source[2] if len(source) > 2 else {}
+        # pydantic will validate existing fields and ignore extra ones
+        return cls(state=ServerState(state), throughput=throughput, **extra_info)


-@dataclass
+@dataclasses.dataclass
 class RemoteModuleInfo:
     """A remote module that is served by one or more servers"""
@@ -35,7 +52,7 @@ class RemoteModuleInfo:
     servers: Dict[PeerID, ServerInfo]


-@dataclass
+@dataclasses.dataclass
 class RemoteSpanInfo:
     """A chain of remote blocks served by one specific remote peer"""
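The new `ServerInfo` is both the in-memory record and the DHT wire format: `to_tuple()` packs it into the backward-compatible `(state, throughput, extra_info)` triple, and `from_tuple()` re-validates whatever a peer published. A minimal round-trip sketch (the values are illustrative, not part of the patch):

```python
from petals.data_structures import ServerInfo, ServerState

info = ServerInfo(state=ServerState.ONLINE, throughput=123.4, adapters=("my-org/my-lora",))
stored = info.to_tuple()  # (2, 123.4, {"adapters": (...), "version": None, ...})
restored = ServerInfo.from_tuple(stored)
assert restored.state is ServerState.ONLINE
assert "my-org/my-lora" in restored.adapters
```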
diff --git a/src/petals/dht_utils.py b/src/petals/dht_utils.py
index 99316f2..0710f60 100644
--- a/src/petals/dht_utils.py
+++ b/src/petals/dht_utils.py
@@ -11,7 +11,7 @@
 from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger

-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo

 logger = get_logger(__name__)

@@ -19,10 +19,8 @@
 def declare_active_modules(
     dht: DHT,
     uids: Sequence[ModuleUID],
+    server_info: ServerInfo,
     expiration_time: DHTExpiration,
-    state: ServerState,
-    throughput: float,
-    adapters: Optional[Sequence[str]] = None,
     wait: bool = True,
 ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
     """
@@ -42,14 +40,7 @@
         assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid

     return dht.run_coroutine(
-        partial(
-            _declare_active_modules,
-            uids=uids,
-            expiration_time=expiration_time,
-            state=state,
-            throughput=throughput,
-            adapters=list(adapters or []),
-        ),
+        partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
         return_future=not wait,
     )

@@ -58,16 +49,14 @@
 async def _declare_active_modules(
     dht: DHT,
     node: DHTNode,
     uids: List[ModuleUID],
+    server_info: ServerInfo,
     expiration_time: DHTExpiration,
-    state: ServerState,
-    throughput: float,
-    adapters: List[str],
 ) -> Dict[ModuleUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     return await node.store_many(
         keys=uids,
         subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[(state.value, throughput, dict(adapters=adapters))] * len(uids),
+        values=[server_info.to_tuple()] * len(uids),
         expiration_time=expiration_time,
         num_workers=num_workers,
     )
@@ -115,29 +104,21 @@
         metadata = found[uid]
         if metadata is None or not isinstance(metadata.value, dict):
             if metadata is not None:
-                logger.error(f"Incorrect metadata for {uid}: {metadata}")
+                logger.warning(f"Incorrect metadata for {uid}: {metadata}")
             continue
         servers = {}
         for peer_id, server_info in metadata.value.items():
             try:
                 peer_id = PeerID.from_base58(peer_id)
-                state, throughput = server_info.value[:2]
-                extra_info = server_info.value[2] if len(server_info.value) > 2 else {}
-                adapters = extra_info.get("adapters", [])
-                if bool(active_adapter) and active_adapter not in adapters:
+                server_info = ServerInfo.from_tuple(server_info.value)
+
+                if active_adapter and active_adapter not in server_info.adapters:
                     logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
                     continue
-                if not (
-                    isinstance(state, int)
-                    and isinstance(throughput, float)
-                    and math.isfinite(throughput)
-                    and throughput >= 0.0
-                ):
-                    raise ValueError(f"Invalid server info: {server_info}")
-                servers[peer_id] = ServerInfo(ServerState(state), throughput)
+
+                servers[peer_id] = server_info
             except (TypeError, ValueError) as e:
-                logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
+                logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
         if servers:
             modules[i] = RemoteModuleInfo(uid, servers)
     return modules
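`ServerInfo.from_tuple` plus `confloat(ge=0, allow_inf_nan=False, strict=True)` subsumes the hand-rolled `isinstance`/`math.isfinite` checks that used to live here: malformed entries raise `pydantic.ValidationError` (a `ValueError` subclass), so the existing `except (TypeError, ValueError)` path still logs and skips the peer. This is also why setup.cfg now pins `pydantic>=1.10`, where `allow_inf_nan` first appeared. A hedged sketch of what a bad DHT entry now triggers:

```python
import pydantic

from petals.data_structures import ServerInfo

try:
    ServerInfo.from_tuple((2, float("inf"), {"adapters": []}))  # non-finite throughput
except pydantic.ValidationError as e:
    print("skipped malformed peer entry:", e)
```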
diff --git a/src/petals/models/bloom/config.py b/src/petals/models/bloom/config.py
index d6a8146..23521fc 100644
--- a/src/petals/models/bloom/config.py
+++ b/src/petals/models/bloom/config.py
@@ -9,8 +9,6 @@
 from petals.client.lm_head import LMHeadConfig
 from petals.client.ptune import PTuneConfig
 from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.models.bloom.block import WrappedBloomBlock
-from petals.utils.auto_config import AutoDistributedConfig
-from petals.utils.version import get_compatible_model_repo

 logger = get_logger(__name__)
diff --git a/src/petals/models/llama/config.py b/src/petals/models/llama/config.py
index 78443eb..b21fa9a 100644
--- a/src/petals/models/llama/config.py
+++ b/src/petals/models/llama/config.py
@@ -9,7 +9,6 @@
 from petals.client.lm_head import LMHeadConfig
 from petals.client.ptune import PTuneConfig
 from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.models.llama.block import WrappedLlamaBlock
-from petals.utils.auto_config import AutoDistributedConfig

 logger = get_logger(__name__)

@@ -31,8 +30,7 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
         loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
         if loading_from_repo and dht_prefix is None:
             dht_prefix = str(model_name_or_path)
-            if "/" in dht_prefix:  # If present, strip repository name to merge blocks hosted by different accounts
-                dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :]
+            dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
             if not dht_prefix.endswith("-hf"):
                 dht_prefix += "-hf"
             logger.info(f"Using DHT prefix: {dht_prefix}")
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index 12fd6eb..d0531de 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -562,11 +562,10 @@
         """Return metadata about stored block uids and current load"""
         backend = self.module_backends[request.uid] if request.uid else next(iter(self.module_backends.values()))
-        cache_bytes_left = max(0, backend.memory_cache.max_size_bytes - backend.memory_cache.current_size_bytes)

         result = {
             "version": petals.__version__,
             "dht_client_mode": self.dht.client_mode,
-            CACHE_TOKENS_AVAILABLE: cache_bytes_left // max(backend.cache_bytes_per_token.values()),
+            CACHE_TOKENS_AVAILABLE: backend.memory_cache.bytes_left // max(backend.cache_bytes_per_token.values()),
         }

         if request.uid:
diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py
index 7f00bae..a1e2f26 100644
--- a/src/petals/server/memory_cache.py
+++ b/src/petals/server/memory_cache.py
@@ -47,6 +47,10 @@
     def current_size_bytes(self, value: int):
         self._current_size.value = value

+    @property
+    def bytes_left(self) -> int:
+        return self.max_size_bytes - self.current_size_bytes
+
     @property
     def handle_counter(self) -> int:
         return self._handle_counter.value
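`MemoryCache.bytes_left` centralizes the `max_size_bytes - current_size_bytes` arithmetic that `rpc_info` previously inlined. Back-of-the-envelope math for the `CACHE_TOKENS_AVAILABLE` value, with assumed numbers rather than anything from this patch:

```python
bytes_left = 8 * 1024**3            # suppose 8 GiB of attention cache is free
cache_bytes_per_token = {0: 28672}  # hypothetical per-block KV-cache cost per token
print(bytes_left // max(cache_bytes_per_token.values()))  # 299593 tokens still fit
```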
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 83a94e3..bac93c5 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -16,8 +16,9 @@
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig

+import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
@@ -29,7 +30,6 @@
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
-from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 from petals.utils.version import get_compatible_model_repo

 logger = get_logger(__name__)
@@ -81,7 +81,7 @@
         dht_client_mode: Optional[bool] = None,
         use_relay: bool = True,
         use_auto_relay: bool = True,
-        adapters: Optional[List[str]] = None,
+        adapters: Sequence[str] = (),
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -215,7 +215,15 @@
             force_eval=(throughput == "eval"),
             cache_dir=cache_dir,
         )
-        self.throughput = throughput
+        self.server_info = ServerInfo(
+            state=ServerState.JOINING,
+            throughput=throughput,
+            adapters=tuple(adapters),
+            version=petals.__version__,
+            torch_dtype=str(torch_dtype).replace("torch.", ""),
+            quant_type=quant_type.name.lower(),
+            using_relay=self.dht.client_mode,
+        )

         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
@@ -283,7 +291,7 @@
             block_config=self.block_config,
             attn_cache_bytes=self.attn_cache_bytes,
             alloc_timeout=self.alloc_timeout,
-            throughput=self.throughput,
+            server_info=self.server_info,
             block_indices=block_indices,
             num_handlers=self.num_handlers,
             min_batch_size=self.min_batch_size,
@@ -307,7 +315,6 @@
             quant_type=self.quant_type,
             tensor_parallel_devices=self.tensor_parallel_devices,
             should_validate_reachability=self.should_validate_reachability,
-            adapters=self.adapters,
             start=True,
         )
         try:
@@ -385,7 +392,7 @@ class ModuleContainer(threading.Thread):
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
         alloc_timeout: float,
-        throughput: float,
+        server_info: ServerInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
@@ -401,16 +408,18 @@
         quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,
-        adapters: Optional[List[str]] = None,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
+        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
+
+        server_info.state = ServerState.JOINING
         joining_announcer = ModuleAnnouncerThread(
             module_uids,
             dht,
-            ServerState.JOINING,
-            adapters=adapters,
-            throughput=throughput,
+            server_info,
+            block_config=block_config,
+            memory_cache=memory_cache,
             update_period=update_period,
             expiration=expiration,
             daemon=True,
@@ -420,7 +429,6 @@

         assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)

-        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
         blocks = {}
         try:
             for module_uid, block_index in zip(module_uids, block_indices):
@@ -441,7 +449,7 @@
                     tensor_parallel_devices,
                     device,
                     quant_type,
-                    adapters=adapters,
+                    adapters=server_info.adapters,
                     freeze=True,
                     use_auth_token=use_auth_token,
                     cache_dir=cache_dir,
@@ -477,13 +485,12 @@
             joining_announcer.stop.set()
             joining_announcer.join()
+            server_info.state = ServerState.OFFLINE
             declare_active_modules(
                 dht,
                 module_uids,
+                server_info,
                 expiration_time=get_dht_time() + expiration,
-                state=ServerState.OFFLINE,
-                throughput=throughput,
-                adapters=adapters,
             )
             logger.info(f"Announced that blocks {module_uids} are offline")
             raise
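The ownership model changes here: one mutable `ServerInfo` is shared between `Server`, `ModuleContainer`, and the announcer threads, so flipping `.state` alone changes what the next `declare_active_modules` call publishes. A small sketch of that behavior, assuming `JOINING = 1` (consistent with `ONLINE = 2` in the enum above):

```python
from petals.data_structures import ServerInfo, ServerState

info = ServerInfo(state=ServerState.JOINING, throughput=100.0)
print(info.to_tuple()[0])  # 1 -- announced as JOINING while blocks load
info.state = ServerState.ONLINE
print(info.to_tuple()[0])  # 2 -- the same object now announces ONLINE
```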
@@ -497,8 +504,9 @@
             dht,
             dht_prefix,
             blocks,
-            adapters=adapters,
-            throughput=throughput,
+            block_config=block_config,
+            memory_cache=memory_cache,
+            server_info=server_info,
             update_period=update_period,
             expiration=expiration,
             **kwargs,
@@ -510,10 +518,11 @@
         dht_prefix: str,
         module_backends: Dict[str, TransformerBackend],
         *,
+        block_config: PretrainedConfig,
+        memory_cache: MemoryCache,
         inference_max_length: int,
         num_handlers: int,
-        throughput: float,
-        adapters: Optional[Sequence[str]],
+        server_info: ServerInfo,
         update_period: float,
         expiration: Optional[float] = None,
         request_timeout: float,
@@ -525,7 +534,7 @@
         super().__init__()

         self.dht, self.module_backends = dht, module_backends
-        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
+        self.server_info, self.update_period, self.expiration = server_info, update_period, expiration

         self.push_manager = mp.Manager()
         self.push_manager.__enter__()
@@ -534,7 +543,7 @@
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
-                adapters=adapters,
+                adapters=server_info.adapters,
                 dht_prefix=dht_prefix,
                 push_manager=self.push_manager,
                 session_queues=session_queues,
@@ -548,12 +557,14 @@

         self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
         # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
+
+        self.server_info.state = ServerState.ONLINE
         self.online_announcer = ModuleAnnouncerThread(
             list(self.module_backends.keys()),
             dht,
-            ServerState.ONLINE,
-            adapters=adapters,
-            throughput=throughput,
+            self.server_info,
+            block_config=block_config,
+            memory_cache=memory_cache,
             update_period=update_period,
             expiration=expiration,
             daemon=True,
@@ -613,12 +624,12 @@

         self.online_announcer.stop.set()
         self.online_announcer.join()
+        self.server_info.state = ServerState.OFFLINE
         declare_active_modules(
             self.dht,
             self.module_backends.keys(),
+            self.server_info,
             expiration_time=get_dht_time() + self.expiration,
-            state=ServerState.OFFLINE,
-            throughput=self.throughput,
         )
         logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
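In the announcer below, `bytes_per_token` is the per-block, per-token cache footprint: `hidden_size` times the dtype width in bytes. A worked example of that formula, with assumed model parameters:

```python
import torch

hidden_size = 14336                     # hypothetical, e.g. a BLOOM-176B block
bits = torch.finfo(torch.float16).bits  # 16
print(hidden_size * bits // 8)          # 28672 bytes of cache per token
```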
@@ -651,10 +662,10 @@
         self,
         module_uids: List[str],
         dht: DHT,
-        state: ServerState,
-        adapters: Optional[Sequence[str]],
+        server_info: ServerInfo,
         *,
-        throughput: float,
+        block_config: PretrainedConfig,
+        memory_cache: MemoryCache,
         update_period: float = 30,
         expiration: float,
         **kwargs,
@@ -662,22 +673,21 @@
         super().__init__(**kwargs)
         self.module_uids = module_uids
         self.dht = dht
-        self.state = state
-        self.adapters = adapters
-        self.throughput = throughput
+        self.server_info = server_info
+        self.memory_cache = memory_cache
+        self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
         self.update_period = update_period
         self.expiration = expiration
         self.stop = threading.Event()

     def run(self) -> None:
         while True:
+            self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token
             declare_active_modules(
                 self.dht,
                 self.module_uids,
+                self.server_info,
                 expiration_time=get_dht_time() + self.expiration,
-                state=self.state,
-                throughput=self.throughput,
-                adapters=self.adapters,
             )
             if self.stop.wait(self.update_period):
                 break
diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py
index 299e979..f8a4637 100644
--- a/src/petals/utils/convert_block.py
+++ b/src/petals/utils/convert_block.py
@@ -2,7 +2,7 @@
 Tools for converting transformer blocks, applying quantization and/or tensor parallelism
 """
 import re
-from typing import List, Optional, Sequence
+from typing import Optional, Sequence

 import tensor_parallel as tp
 import torch
@@ -25,7 +25,7 @@ def convert_block(
     output_device: torch.device,
     quant_type: QuantType,
     freeze: bool = True,
-    adapters: Optional[List[str]] = None,
+    adapters: Optional[Sequence[str]] = None,
     **kwargs,
 ) -> tp.TensorParallel:
     """
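Taken together: a server started with `--adapters` now announces those names inside `ServerInfo`, and clients drop peers that lack the adapter they need (see `_get_remote_module_infos` above). A hedged sketch of that filtering step, with a made-up adapter name:

```python
from petals.data_structures import ServerInfo

active_adapter = "some-org/some-lora"  # hypothetical LoRA repo
server_info = ServerInfo.from_tuple((2, 100.0, {"adapters": ["some-org/some-lora"]}))

if active_adapter and active_adapter not in server_info.adapters:
    print("skip: peer does not serve this adapter")
else:
    print("peer can serve the requested adapter")
```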