switch to hivemind-master

inference_chain
justheuristic 2 years ago
parent 5a15c13ca7
commit 20497f81d1

@ -15,11 +15,13 @@ def main():
parser = configargparse.ArgParser(default_config_files=["config.yml"])
parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
parser.add_argument('--block_config', type=str, default='bigscience/bloom', help="name or path of model config")
parser.add_argument('--num_blocks', type=int, default=1, help="The number of blocks to serve")
parser.add_argument('--host_maddrs', type=list, nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
parser.add_argument('--prefix', type=str, required=True, help="Announce all blocks with this prefix")
parser.add_argument('--block_config', type=str, default='bigscience/bloom-6b3', help="name or path of model config")
parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
parser.add_argument('--announce_maddrs', type=list, nargs='+', default=None, required=False,
parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,
help='Visible multiaddrs the host announces for external connections from other p2p instances')
parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')

@ -11,7 +11,7 @@ from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import use_hivemind_log_handler, get_logger
import multiprocessing as mp
from src import DistributedBloomConfig
from src import DistributedBloomConfig, BloomForCausalLM
from src.bloom.block import BloomBlock
from src.server.cache import MemoryCache
from src.server.backend import TransformerBackend
@ -81,8 +81,10 @@ class Server(threading.Thread):
@classmethod
def create(
cls,
num_blocks: int,
prefix: str,
block_config: str,
num_blocks: Optional[int] = None,
block_indices: Optional[str] = None,
num_handlers: Optional[int] = None,
min_batch_size: int = 1,
max_batch_size: int = 4096,
@ -101,20 +103,37 @@ class Server(threading.Thread):
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
if custom_module_path is not None:
add_custom_models_from_file(custom_module_path)
assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
memory_cache = MemoryCache(device, cache_size_bytes)
if block_indices is not None:
try:
start, end = block_indices.split(':')
start, end = map(int, map(str.strip, (start, end)))
except Exception as e:
logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:33)")
raise
block_indices = range(start, end)
else:
assert num_blocks is not None
block_indices = range(num_blocks) # TODO replace with proper load balancing
## TODO: the code below will load the entire model in RAM. Please replace with sliced model
block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
# model = BloomForCausalLM.from_pretrained(model, use_auth_token=True)
## /TODO
# initialize modules
blocks = {}
for i in range(num_blocks):
module_uid = f"dummy_block.{i}"
block = BloomBlock(block_config, layer_number=i)
for block_index in block_indices:
module_uid = f"{prefix}.{block_index}"
block = BloomBlock(block_config, layer_number=block_index)
for param in block.parameters():
param.requires_grad = False
@ -129,6 +148,8 @@ class Server(threading.Thread):
max_batch_size=max_batch_size,
)
num_handlers = num_handlers if num_handlers is not None else len(blocks) * 4
return cls(
dht,
blocks,

Loading…
Cancel
Save