|
|
|
@ -23,15 +23,24 @@ logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
|
class Server(threading.Thread):
|
|
|
|
|
"""Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self, dht: DHT, module_backends: Dict[str, BloomBlockBackend], *,
|
|
|
|
|
device: torch.device, num_connection_handlers: int = 8,
|
|
|
|
|
update_period: float = 30, expiration: Optional[float] = None,
|
|
|
|
|
start: bool, **kwargs
|
|
|
|
|
self,
|
|
|
|
|
dht: DHT,
|
|
|
|
|
module_backends: Dict[str, BloomBlockBackend],
|
|
|
|
|
*,
|
|
|
|
|
device: torch.device,
|
|
|
|
|
num_connection_handlers: int = 8,
|
|
|
|
|
update_period: float = 30,
|
|
|
|
|
expiration: Optional[float] = None,
|
|
|
|
|
start: bool,
|
|
|
|
|
**kwargs,
|
|
|
|
|
):
|
|
|
|
|
threading.Thread.__init__(self)
|
|
|
|
|
self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
|
|
|
|
|
self.conn_handlers = [TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
|
|
|
|
|
self.conn_handlers = [
|
|
|
|
|
TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
|
|
|
|
|
]
|
|
|
|
|
self.runtime = Runtime(self.module_backends, device=device, **kwargs)
|
|
|
|
|
self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period, expiration, daemon=True)
|
|
|
|
|
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
|
|
|
|
@ -71,23 +80,23 @@ class Server(threading.Thread):
|
|
|
|
|
# noinspection PyMethodOverriding
|
|
|
|
|
@classmethod
|
|
|
|
|
def create(
|
|
|
|
|
cls,
|
|
|
|
|
num_blocks: int,
|
|
|
|
|
block_config: str,
|
|
|
|
|
num_handlers: Optional[int] = None,
|
|
|
|
|
min_batch_size: int = 1,
|
|
|
|
|
max_batch_size: int = 4096,
|
|
|
|
|
cache_size_bytes: Optional[int] = None,
|
|
|
|
|
device: Union[str, torch.device] = None,
|
|
|
|
|
initial_peers: Sequence[str] = (),
|
|
|
|
|
compression=CompressionType.NONE,
|
|
|
|
|
stats_report_interval: Optional[int] = None,
|
|
|
|
|
custom_module_path=None,
|
|
|
|
|
update_period: float = 30,
|
|
|
|
|
expiration: Optional[float] = None,
|
|
|
|
|
*,
|
|
|
|
|
start: bool,
|
|
|
|
|
**kwargs,
|
|
|
|
|
cls,
|
|
|
|
|
num_blocks: int,
|
|
|
|
|
block_config: str,
|
|
|
|
|
num_handlers: Optional[int] = None,
|
|
|
|
|
min_batch_size: int = 1,
|
|
|
|
|
max_batch_size: int = 4096,
|
|
|
|
|
cache_size_bytes: Optional[int] = None,
|
|
|
|
|
device: Union[str, torch.device] = None,
|
|
|
|
|
initial_peers: Sequence[str] = (),
|
|
|
|
|
compression=CompressionType.NONE,
|
|
|
|
|
stats_report_interval: Optional[int] = None,
|
|
|
|
|
custom_module_path=None,
|
|
|
|
|
update_period: float = 30,
|
|
|
|
|
expiration: Optional[float] = None,
|
|
|
|
|
*,
|
|
|
|
|
start: bool,
|
|
|
|
|
**kwargs,
|
|
|
|
|
) -> Server:
|
|
|
|
|
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
|
|
|
|
|
if custom_module_path is not None:
|
|
|
|
@ -181,4 +190,3 @@ class Server(threading.Thread):
|
|
|
|
|
|
|
|
|
|
self.runtime.shutdown()
|
|
|
|
|
logger.info("Server shutdown succesfully")
|
|
|
|
|
|
|
|
|
|