Mirror of https://github.com/bigscience-workshop/petals, synced 2024-10-31 09:20:41 +00:00
switch to hivemind-master

commit 20497f81d1
parent 5a15c13ca7
@@ -15,11 +15,13 @@ def main():
     parser = configargparse.ArgParser(default_config_files=["config.yml"])
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
 
-    parser.add_argument('--block_config', type=str, default='bigscience/bloom', help="name or path of model config")
-    parser.add_argument('--num_blocks', type=int, default=1, help="The number of blocks to serve")
-    parser.add_argument('--host_maddrs', type=list, nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
+    parser.add_argument('--prefix', type=str, required=True, help="Announce all blocks with this prefix")
+    parser.add_argument('--block_config', type=str, default='bigscience/bloom-6b3', help="name or path of model config")
+    parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
+    parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
+    parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
                         help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
-    parser.add_argument('--announce_maddrs', type=list, nargs='+', default=None, required=False,
+    parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,
                         help='Visible multiaddrs the host announces for external connections from other p2p instances')
 
     parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
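For reference, a minimal sketch of how the new flags fit together (assumption: plain argparse stands in for configargparse here, and the invocation values are made up). Dropping type=list from --host_maddrs/--announce_maddrs also fixes a latent bug: with nargs='+', argparse already collects the values into a list, and type=list would split each multiaddr string into single characters.

import argparse

# Stand-in parser: mirrors only the new/changed arguments from the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument('--prefix', type=str, required=True, help="Announce all blocks with this prefix")
parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False)

# Example invocation: serve blocks 0..7 under the "bloom6b3" prefix.
args = parser.parse_args(['--prefix', 'bloom6b3', '--block_indices', '0:8'])
assert (args.block_indices is None) != (args.num_blocks is None)  # the server enforces exactly one of the two
print(args.host_maddrs)  # ['/ip4/0.0.0.0/tcp/0'] -- a list of strings, not a list of characters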
@@ -11,7 +11,7 @@ from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import use_hivemind_log_handler, get_logger
 import multiprocessing as mp
 
-from src import DistributedBloomConfig
+from src import DistributedBloomConfig, BloomForCausalLM
 from src.bloom.block import BloomBlock
 from src.server.cache import MemoryCache
 from src.server.backend import TransformerBackend
@@ -81,8 +81,10 @@ class Server(threading.Thread):
     @classmethod
     def create(
         cls,
-        num_blocks: int,
+        prefix: str,
         block_config: str,
+        num_blocks: Optional[int] = None,
+        block_indices: Optional[str] = None,
         num_handlers: Optional[int] = None,
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
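A minimal stand-in for the new argument contract (assumption: only the signature semantics are reproduced here; the real Server.create takes many more parameters than shown):

from typing import Optional

def create_stub(prefix: str, block_config: str,
                num_blocks: Optional[int] = None,
                block_indices: Optional[str] = None) -> range:
    # Mirrors only the new "exactly one of num_blocks / block_indices" contract.
    assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
    if block_indices is not None:
        start, end = map(int, map(str.strip, block_indices.split(':')))
        return range(start, end)
    return range(num_blocks)

print(create_stub("bloom6b3", "bigscience/bloom-6b3", block_indices="0:8"))  # range(0, 8)
print(create_stub("bloom6b3", "bigscience/bloom-6b3", num_blocks=2))         # range(0, 2)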
@@ -101,20 +103,37 @@ class Server(threading.Thread):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
 
+        assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
-        num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
         memory_cache = MemoryCache(device, cache_size_bytes)
+
+        if block_indices is not None:
+            try:
+                start, end = block_indices.split(':')
+                start, end = map(int, map(str.strip, (start, end)))
+            except Exception as e:
+                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:33)")
+                raise
+            block_indices = range(start, end)
+        else:
+            assert num_blocks is not None
+            block_indices = range(num_blocks)  # TODO replace with proper load balancing
+
+        ## TODO: the code below will load the entire model in RAM. Please replace with sliced model
+        block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
+        # model = BloomForCausalLM.from_pretrained(model, use_auth_token=True)
+        ## /TODO
+
 
         # initialize modules
         blocks = {}
-        for i in range(num_blocks):
-            module_uid = f"dummy_block.{i}"
-            block = BloomBlock(block_config, layer_number=i)
+        for block_index in block_indices:
+            module_uid = f"{prefix}.{block_index}"
+            block = BloomBlock(block_config, layer_number=block_index)
             for param in block.parameters():
                 param.requires_grad = False
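As a consequence of the loop change above, each served block is announced under a DHT key derived from the CLI prefix. A small illustrative sketch of the naming scheme (the f-string is taken from the diff; the example values are made up):

# Illustrative only: reproduces the module-UID naming used in the loop above.
prefix = "bloom6b3"          # example value of --prefix
block_indices = range(3, 6)  # parsed from --block_indices "3:6" (example)

module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
print(module_uids)  # ['bloom6b3.3', 'bloom6b3.4', 'bloom6b3.5']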
@@ -129,6 +148,8 @@ class Server(threading.Thread):
                 max_batch_size=max_batch_size,
             )
 
+        num_handlers = num_handlers if num_handlers is not None else len(blocks) * 4
+
         return cls(
             dht,
             blocks,