@@ -14,7 +14,7 @@ import multiprocessing as mp
 from src import DistributedBloomConfig
 from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
-from src.server.backend import TransformerBlockBackend
+from src.server.backend import BloomBlockBackend
 from src.server.handler import BloomConnectionHandler

 use_hivemind_log_handler("in_root_logger")
@@ -24,16 +24,14 @@ logger = get_logger(__file__)
 class Server(threading.Thread):
     """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
     def __init__(
-        self, dht: DHT, module_backends: Dict[str, TransformerBlockBackend], *,
+        self, dht: DHT, module_backends: Dict[str, BloomBlockBackend], *,
         device: torch.device, num_connection_handlers: int = 8,
         update_period: float = 30, expiration: Optional[float] = None,
         start: bool, **kwargs
     ):
         threading.Thread.__init__(self)
         self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
-        self.conn_handlers = [
-            BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
-        ]
+        self.conn_handlers = [BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
         self.runtime = Runtime(self.module_backends, device=device, **kwargs)
         self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period, expiration, daemon=True)
         self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
@@ -102,16 +100,17 @@ class Server(threading.Thread):
         num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)

         memory_cache = MemoryCache(device, cache_size_bytes)
         # initialize modules
         blocks = {}
         for i in range(num_blocks):
             module_uid = f"dummy_block.{i}"
             HARDCODCED_LENGTH = 2048

-            blocks[module_uid] = TransformerBlockBackend(
+            blocks[module_uid] = BloomBlockBackend(
                 module_uid,
                 BloomBlock(block_config, layer_number=i),
                 memory_cache=memory_cache,
                 args_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
                 kwargs_schema={},
                 outputs_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
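
For reference, a minimal sketch of how the renamed BloomBlockBackend plugs into the Server constructor shown in these hunks. The import paths and call signatures are taken from the diff itself; the model name, cache size, block count, and the src.server.server module path for Server are assumptions made purely for illustration.

```python
# Illustrative sketch only, not part of the diff: wiring BloomBlockBackend into Server
# using the signatures shown above. Model name, sizes, and the Server import path are placeholders.
import torch
from hivemind import DHT, BatchTensorDescriptor
from hivemind.proto.runtime_pb2 import CompressionType

from src import DistributedBloomConfig
from src.bloom.block import BloomBlock
from src.server.backend import BloomBlockBackend
from src.server.cache import MemoryCache
from src.server.server import Server  # assumed location of the Server class edited above

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
block_config = DistributedBloomConfig.from_pretrained("bigscience/bloom-6b3", use_auth_token=True)  # placeholder model
memory_cache = MemoryCache(device, 2 ** 30)  # ~1 GiB attention cache, arbitrary size
MAX_LENGTH = 2048  # mirrors HARDCODCED_LENGTH from the diff

# Build one backend per served block, mirroring the block-initialization loop in the diff
blocks = {}
for i in range(2):  # two dummy blocks, purely for illustration
    module_uid = f"dummy_block.{i}"
    blocks[module_uid] = BloomBlockBackend(
        module_uid,
        BloomBlock(block_config, layer_number=i),
        memory_cache=memory_cache,
        args_schema=(BatchTensorDescriptor(1, MAX_LENGTH, block_config.hidden_size, compression=CompressionType.NONE),),
        kwargs_schema={},
        outputs_schema=(BatchTensorDescriptor(1, MAX_LENGTH, block_config.hidden_size, compression=CompressionType.NONE),),
    )

dht = DHT(start=True)
server = Server(dht, blocks, device=device, num_connection_handlers=8, update_period=30, start=True)
```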