basic backend

pull/9/head
justheuristic 2 years ago
parent 1c49bcb741
commit 7ce7cd7a97

@ -7,7 +7,7 @@ from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src.server.server import BloomServer
from src.server.server import Server
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
@ -63,7 +63,7 @@ def main():
compression_type = args.pop("compression")
compression = getattr(CompressionType, compression_type)
server = BloomServer.create(**args, start=True, compression=compression)
server = Server.create(**args, start=True, compression=compression)
try:
server.join()

@ -16,10 +16,8 @@ from src.server.cache import MemoryCache
# - ensure that TaskPool for inference is NOT batched
# - ensure that optimizer/scheduler is not created
HARDCODCED_LENGTH = 2048
class BloomBlockBackend(ExpertBackend):
class TransformerBlockBackend(ExpertBackend):
"""A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
def __init__(self, name: str, module: BloomBlock, *, memory_cache: MemoryCache, **kwargs):
object().__init__() # to bypass super.__init__

@ -1,6 +1,6 @@
from typing import AsyncIterator
from typing import AsyncIterator, Dict
from hivemind import P2PContext
from hivemind import P2PContext, DHT
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.proto import runtime_pb2
@ -8,6 +8,13 @@ from hivemind.proto import runtime_pb2
class BloomConnectionHandler(ConnectionHandler):
"""Handles three request types: forward, backward and forward-incremental (inference)"""
def __init__(self, dht: DHT, experts: Dict[str, BloomBackend]):
super().__init__()
self.dht, self.experts = dht, experts
self._p2p: Optional[P2P] = None
self.ready = MPFuture()
async def rpc_forward_incremental(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertRequest]:

@ -1,43 +1,75 @@
from __future__ import annotations
import threading
from typing import Optional, Dict, Union, Sequence
import torch
from hivemind import Server, DHT
from hivemind import DHT, BatchTensorDescriptor
from hivemind.moe.server.dht_handler import DHTHandlerThread
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import use_hivemind_log_handler, get_logger
import multiprocessing as mp
from src import DistributedBloomConfig
from src.bloom.block import BloomBlock
from src.server.cache import MemoryCache
from src.server.backend import BloomBlockBackend
from src.server.backend import TransformerBlockBackend
from src.server.handler import BloomConnectionHandler
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
class BloomServer(Server):
class Server(threading.Thread):
"""Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
def __init__(
self, dht: DHT, module_backends: Dict[str, BloomBlockBackend], *,
device: torch.device, num_connection_handlers: int = 8, update_period: float = 30,
cache_size_bytes: Optional[int] = None, start: bool, **kwargs,
self, dht: DHT, module_backends: Dict[str, TransformerBlockBackend], *,
device: torch.device, num_connection_handlers: int = 8,
update_period: float = 30, expiration: Optional[float] = None,
start: bool, **kwargs
):
threading.Thread.__init__(self)
self.attention_cache = MemoryCache(device=device, max_size_bytes=cache_size_bytes)
self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
self.conn_handlers = [BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
self.conn_handlers = [
BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
]
self.runtime = Runtime(self.module_backends, device=device, **kwargs)
self.dht_handler_thread = DHTHandlerThread(self.experts, dht, update_period=update_period, daemon=True)
self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period, expiration, daemon=True)
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
if start:
self.run_in_background(await_ready=True)
def run(self):
"""
Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
runs Runtime (self.runtime) to process incoming requests.
"""
logger.info(f"Serving {len(self.module_backends)} blocks:")
for expert_name, backend in self.module_backends.items():
num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
if not self.dht.is_alive():
self.dht.run_in_background(await_ready=True)
if self.module_backends:
self.dht_handler_thread.start()
if self.checkpoint_saver is not None:
self.checkpoint_saver.start()
for process in self.conn_handlers:
if not process.is_alive():
process.start()
process.ready.result()
try:
self.runtime.run()
finally:
self.shutdown()
# noinspection PyMethodOverriding
@classmethod
def create(
@ -69,40 +101,84 @@ class BloomServer(Server):
num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
if isinstance(block_config, str):
block_config = DistributedBloomConfig
block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
# initialize modules
module_backends = {}
for i in range(len(module_backends)):
blocks = {}
for i in range(num_blocks):
module_uid = f"dummy_block.{i}"
block = BloomBlock(block_config, layer_number=i)
#TODO run the actual model
module_backends[module_uid] = BloomBlockBackend(
name=expert_uid,
expert=block,
args_schema=args_schema,
num_warmup_steps=num_warmup_steps,
num_total_steps=num_total_steps,
clip_grad_norm=clip_grad_norm,
HARDCODCED_LENGTH = 2048
blocks[module_uid] = TransformerBlockBackend(
module_uid,
BloomBlock(block_config, layer_number=i),
args_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
kwargs_schema={},
outputs_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
min_batch_size=min_batch_size,
max_batch_size=max_batch_size,
)
if checkpoint_dir is not None:
load_experts(experts, checkpoint_dir)
return cls(
dht,
experts,
blocks,
cache_size_bytes=cache_size_bytes,
num_connection_handlers=num_handlers,
device=device,
checkpoint_dir=checkpoint_dir,
stats_report_interval=stats_report_interval,
update_period=update_period,
expiration=expiration,
start=start,
)
def run_in_background(self, await_ready=True, timeout=None):
"""
Starts Server in a background thread. if await_ready, this method will wait until background server
is ready to process incoming requests or for :timeout: seconds max.
"""
self.start()
if await_ready and not self.ready.wait(timeout=timeout):
raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
@property
def ready(self) -> mp.synchronize.Event:
"""
An event (multiprocessing.Event) that is set when the server is ready to process requests.
Example
=======
>>> server.start()
>>> server.ready.wait(timeout=10)
>>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
"""
return self.runtime.ready # mp.Event that is true if self is ready to process batches
def shutdown(self):
"""
Gracefully terminate the server, process-safe.
Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
"""
self.ready.clear()
for process in self.conn_handlers:
process.terminate()
process.join()
logger.debug("Connection handlers terminated")
if self.module_backends:
self.dht_handler_thread.stop.set()
self.dht_handler_thread.join()
if self.checkpoint_saver is not None:
self.checkpoint_saver.stop.set()
self.checkpoint_saver.join()
self.dht.shutdown()
self.dht.join()
logger.debug(f"Shutting down runtime")
self.runtime.shutdown()
logger.info("Server shutdown succesfully")

Loading…
Cancel
Save