@@ -11,7 +11,7 @@ from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import use_hivemind_log_handler, get_logger
 import multiprocessing as mp
 
-from src import DistributedBloomConfig
+from src import DistributedBloomConfig, BloomForCausalLM
 from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
 from src.server.backend import TransformerBackend
@@ -81,8 +81,10 @@ class Server(threading.Thread):
     @classmethod
     def create(
         cls,
-        num_blocks: int,
+        prefix: str,
         block_config: str,
+        num_blocks: Optional[int] = None,
+        block_indices: Optional[str] = None,
         num_handlers: Optional[int] = None,
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
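
For reference, a minimal sketch of how the reworked create() signature above could be invoked, using only arguments visible in this diff; the prefix and model name are hypothetical, and run_server.py remains the real entry point. Exactly one of num_blocks / block_indices is expected (see the assert in the next hunk):

    # Hypothetical calls mirroring the new signature; values are illustrative only.
    server = Server.create(prefix="bloom6b3", block_config="bigscience/bloom-6b3", block_indices="3:6")
    # ...or let the server take the first num_blocks blocks instead:
    server = Server.create(prefix="bloom6b3", block_config="bigscience/bloom-6b3", num_blocks=3)
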
@@ -101,20 +103,37 @@ class Server(threading.Thread):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
 
+        assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
-        num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
         memory_cache = MemoryCache(device, cache_size_bytes)
 
+        if block_indices is not None:
+            try:
+                start, end = block_indices.split(':')
+                start, end = map(int, map(str.strip, (start, end)))
+            except Exception as e:
+                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:33)")
+                raise
+            block_indices = range(start, end)
+        else:
+            assert num_blocks is not None
+            block_indices = range(num_blocks)  # TODO replace with proper load balancing
+
         ## TODO: the code below will load the entire model in RAM. Please replace with sliced model
-        block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
         # model = BloomForCausalLM.from_pretrained(model, use_auth_token=True)
         ## /TODO
 
 
         # initialize modules
         blocks = {}
-        for i in range(num_blocks):
-            module_uid = f"dummy_block.{i}"
-            block = BloomBlock(block_config, layer_number=i)
+        for block_index in block_indices:
+            module_uid = f"{prefix}.{block_index}"
+            block = BloomBlock(block_config, layer_number=block_index)
             for param in block.parameters():
                 param.requires_grad = False
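
To make the new indexing concrete, a small standalone sketch of the block_indices parsing and module UID scheme used above; the prefix value is hypothetical:

    # Illustration only: "3:6" selects blocks 3, 4, 5 and names them under the given prefix.
    prefix, block_indices = "bloom6b3", "3:6"
    start, end = map(int, map(str.strip, block_indices.split(':')))
    module_uids = [f"{prefix}.{i}" for i in range(start, end)]
    print(module_uids)  # ['bloom6b3.3', 'bloom6b3.4', 'bloom6b3.5']
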
@@ -129,6 +148,8 @@ class Server(threading.Thread):
                 max_batch_size=max_batch_size,
             )
 
+        num_handlers = num_handlers if num_handlers is not None else len(blocks) * 4
+
         return cls(
             dht,
             blocks,