You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
petals/src/petals/data_structures.py

118 lines
3.1 KiB
Python

import dataclasses
from enum import Enum
from typing import Any, Dict, Optional, Sequence, Tuple
import pydantic
from hivemind import PeerID
from hivemind.moe.expert_uid import ExpertUID
ModuleUID = str
UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
def parse_uid(uid: ModuleUID) -> Tuple[str, int]:
assert CHAIN_DELIMITER not in uid, "parse_uid() does not support chained UIDs"
dht_prefix, index = uid.split(UID_DELIMITER)
return dht_prefix, int(index)
@pydantic.dataclasses.dataclass
class ModelInfo:
num_blocks: pydantic.conint(ge=1, strict=True)
repository: Optional[str] = None
def to_dict(self) -> dict:
return dataclasses.asdict(self)
@classmethod
def from_dict(cls, source: dict):
return cls(**source)
class ServerState(Enum):
OFFLINE = 0
JOINING = 1
ONLINE = 2
RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
@pydantic.dataclasses.dataclass
class ServerInfo:
state: ServerState
throughput: RPS
start_block: Optional[pydantic.conint(ge=0, strict=True)] = None
end_block: Optional[pydantic.conint(ge=0, strict=True)] = None
public_name: Optional[str] = None
version: Optional[str] = None
network_rps: Optional[RPS] = None
forward_rps: Optional[RPS] = None
inference_rps: Optional[RPS] = None
adapters: Sequence[str] = ()
torch_dtype: Optional[str] = None
quant_type: Optional[str] = None
using_relay: Optional[bool] = None
cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None
next_pings: Optional[Dict[str, pydantic.confloat(ge=0, strict=True)]] = None
def to_tuple(self) -> Tuple[int, float, dict]:
extra_info = dataclasses.asdict(self)
del extra_info["state"], extra_info["throughput"]
return (self.state.value, self.throughput, extra_info)
@classmethod
def from_tuple(cls, source: tuple):
state, throughput = source[:2]
extra_info = source[2] if len(source) > 2 else {}
# pydantic will validate existing fields and ignore extra ones
return cls(state=ServerState(state), throughput=throughput, **extra_info)
@dataclasses.dataclass
class RemoteModuleInfo:
"""A remote module that is served by one or more servers"""
uid: ModuleUID
servers: Dict[PeerID, ServerInfo]
@dataclasses.dataclass
class RemoteSpanInfo:
"""A chain of remote blocks served by one specific remote peer"""
peer_id: PeerID
start: int
end: int
server_info: ServerInfo
@property
def length(self) -> int:
return self.end - self.start
@property
def state(self) -> ServerState:
return self.server_info.state
@property
def throughput(self) -> float:
return self.server_info.throughput
RPCInfo = Dict[str, Any]
Handle = int
@dataclasses.dataclass(frozen=True)
class InferenceMetadata:
uid: ExpertUID
prefix_length: int
cache_handles: Tuple[Handle, ...]
active_adapter: Optional[str]