You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
petals/src/petals/data_structures.py

88 lines
2.3 KiB
Python

import dataclasses
from enum import Enum
from typing import Any, Dict, Optional, Sequence, Tuple
import pydantic
from hivemind import PeerID
from hivemind.moe.expert_uid import ExpertUID
# Type alias: a module uid is a plain string whose parts are joined by UID_DELIMITER.
ModuleUID = str
UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
class ServerState(Enum):
    """State of a server, represented by a small integer.

    The integer value is what ``ServerInfo.to_tuple`` emits (``state.value``) and what
    ``ServerInfo.from_tuple`` converts back into a member via ``ServerState(value)``.
    """

    OFFLINE = 0
    JOINING = 1
    ONLINE = 2
# Constrained float type used for throughput-like fields: must be >= 0, must be finite
# (inf/NaN are rejected), and is validated in pydantic's strict mode.
# NOTE(review): "RPS" presumably stands for requests per second — confirm.
RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
@pydantic.dataclasses.dataclass
class ServerInfo:
    """Validated description of a single server.

    Only ``state`` and ``throughput`` are required; every other field has a default.
    ``to_tuple`` and ``from_tuple`` convert to and from the compact
    ``(state_value, throughput, extra_info)`` representation.
    """

    state: ServerState
    throughput: RPS

    public_name: Optional[str] = None
    version: Optional[str] = None

    network_rps: Optional[RPS] = None
    forward_rps: Optional[RPS] = None
    inference_rps: Optional[RPS] = None

    adapters: Sequence[str] = ()
    torch_dtype: Optional[str] = None
    quant_type: Optional[str] = None
    using_relay: Optional[bool] = None
    cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None
    next_pings: Optional[Dict[str, pydantic.confloat(ge=0, strict=True)]] = None

    def to_tuple(self) -> Tuple[int, float, dict]:
        """Return ``(state.value, throughput, extra_info)``.

        ``extra_info`` is ``dataclasses.asdict(self)`` with the ``"state"`` and
        ``"throughput"`` keys removed, i.e. all remaining fields keyed by name.
        """
        extra_info = dataclasses.asdict(self)
        # state and throughput travel as the first two tuple items, not inside the dict
        del extra_info["state"], extra_info["throughput"]
        return (self.state.value, self.throughput, extra_info)

    @classmethod
    def from_tuple(cls, source: tuple) -> "ServerInfo":
        """Build a ``ServerInfo`` from the form produced by ``to_tuple``.

        :param source: a sequence ``(state, throughput)`` or ``(state, throughput, extra_info)``;
            a missing third item is treated as an empty ``extra_info``.
        :raises ValueError: if ``source`` has fewer than two items, or ``state`` is not
            a valid ``ServerState`` value.
        """
        if len(source) < 2:
            # Same exception type as the bare unpacking error, but with an actionable message
            raise ValueError(
                f"ServerInfo.from_tuple expects at least (state, throughput), got {len(source)} item(s)"
            )
        state, throughput = source[:2]
        extra_info = source[2] if len(source) > 2 else {}
        # pydantic will validate existing fields and ignore extra ones
        return cls(state=ServerState(state), throughput=throughput, **extra_info)
@dataclasses.dataclass
class RemoteModuleInfo:
    """A remote module that is served by one or more servers"""

    uid: ModuleUID  # string identifier of the module
    servers: Dict[PeerID, ServerInfo]  # ServerInfo for each peer serving this module, keyed by peer id
@dataclasses.dataclass
class RemoteSpanInfo:
    """A chain of remote blocks served by one specific remote peer"""

    peer_id: PeerID
    start: int
    end: int
    server_info: ServerInfo

    @property
    def length(self) -> int:
        """Size of the span, computed as ``end - start``."""
        return self.end - self.start
# Type alias: a string-keyed mapping with arbitrary values (its schema is not defined in this file).
RPCInfo = Dict[str, Any]
# Type alias: an integer handle; InferenceMetadata.cache_handles is a tuple of these.
Handle = int
@dataclasses.dataclass(frozen=True)
class InferenceMetadata:
    """Immutable (frozen) record of per-module inference parameters.

    All four fields are required; ``active_adapter`` may be ``None`` but has no default.
    """

    uid: ExpertUID
    prefix_length: int
    cache_handles: Tuple[Handle, ...]  # any number of integer handles
    active_adapter: Optional[str]