basic backend
parent
3215945882
commit
1c49bcb741
@ -0,0 +1,77 @@
|
||||
import os, sys
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # add path to src
|
||||
|
||||
import configargparse
|
||||
|
||||
from hivemind.proto.runtime_pb2 import CompressionType
|
||||
from hivemind.utils.limits import increase_file_limit
|
||||
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
||||
|
||||
from src.server.server import BloomServer
|
||||
|
||||
use_hivemind_log_handler("in_root_logger")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
# fmt:off
|
||||
parser = configargparse.ArgParser(default_config_files=["config.yml"])
|
||||
parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
|
||||
|
||||
parser.add_argument('--block_config', type=str, default='bigscience/bloom', help="name or path of model config")
|
||||
parser.add_argument('--num_blocks', type=int, default=1, help="The number of blocks to serve")
|
||||
parser.add_argument('--host_maddrs', type=list, nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
|
||||
help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
|
||||
parser.add_argument('--announce_maddrs', type=list, nargs='+', default=None, required=False,
|
||||
help='Visible multiaddrs the host announces for external connections from other p2p instances')
|
||||
|
||||
parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
|
||||
|
||||
parser.add_argument('--num_handlers', type=int, default=None, required=False,
|
||||
help='server will use this many processes to handle incoming requests')
|
||||
parser.add_argument('--min_batch_size', type=int, default=1,
|
||||
help='Minimum required batch size for all expert operations')
|
||||
parser.add_argument('--max_batch_size', type=int, default=16384,
|
||||
help='The total number of examples in the same batch will not exceed this value')
|
||||
parser.add_argument('--cache_size_bytes', type=int, default=None,
|
||||
help='The size of memory cache for storing past attention keys/values between inference steps')
|
||||
parser.add_argument('--device', type=str, default=None, required=False,
|
||||
help='all experts will use this device in torch notation; default: cuda if available else cpu')
|
||||
|
||||
parser.add_argument('--update_period', type=float, required=False, default=30,
|
||||
help='Server will report experts to DHT once in this many seconds')
|
||||
parser.add_argument('--expiration', type=float, required=False, default=None,
|
||||
help='DHT entries will expire after this many seconds')
|
||||
parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
|
||||
help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
|
||||
parser.add_argument('--increase_file_limit', action='store_true',
|
||||
help='On *nix, this will increase the max number of processes '
|
||||
'a server can spawn before hitting "Too many open files"; Use at your own risk.')
|
||||
parser.add_argument('--stats_report_interval', type=int, required=False,
|
||||
help='Interval between two reports of batch processing performance statistics')
|
||||
|
||||
parser.add_argument('--custom_module_path', type=str, required=False,
|
||||
help='Path of a file with custom nn.modules, wrapped into special decorator')
|
||||
|
||||
# fmt:on
|
||||
args = vars(parser.parse_args())
|
||||
args.pop("config", None)
|
||||
|
||||
if args.pop("increase_file_limit"):
|
||||
increase_file_limit()
|
||||
|
||||
compression_type = args.pop("compression")
|
||||
compression = getattr(CompressionType, compression_type)
|
||||
|
||||
server = BloomServer.create(**args, start=True, compression=compression)
|
||||
|
||||
try:
|
||||
server.join()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Caught KeyboardInterrupt, shutting down")
|
||||
finally:
|
||||
server.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1 +1 @@
|
||||
from src.bloom.model import BloomModel, BloomForCausalLM, MemoryEfficientBloomConfig
|
||||
from src.bloom.model import BloomModel, BloomForCausalLM, DistributedBloomConfig
|
@ -1,73 +0,0 @@
|
||||
"""Code for serving bloom blocks via hivemind-server"""
|
||||
import threading
|
||||
from typing import AsyncIterator, Tuple, Optional
|
||||
|
||||
import torch
|
||||
from hivemind import P2PContext, DHT
|
||||
from hivemind.moe.server.connection_handler import ConnectionHandler
|
||||
from hivemind.moe.server.dht_handler import DHTHandlerThread
|
||||
from hivemind.moe.server.expert_backend import ExpertBackend
|
||||
from hivemind.moe.server.runtime import Runtime
|
||||
from hivemind.moe.server.server import Server
|
||||
from hivemind.proto import runtime_pb2
|
||||
from torch import nn
|
||||
|
||||
from src.node.cache import AttentionCache
|
||||
|
||||
|
||||
class BloomServer(Server):
|
||||
"""Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
|
||||
def __init__(
|
||||
self, dht: DHT, device=torch.device, num_connection_handlers: int = 8, update_period: int = 30,
|
||||
attention_cache_size: Optional[int] = None, start=False, **kwargs,
|
||||
):
|
||||
threading.Thread.__init__(self)
|
||||
self.attention_cache = AttentionCache(attention_cache_size, dtype=torch.bfloat16, device=torch.)
|
||||
expert_blocks = dict(LOAD_BLOOM_LAYERS_HERE)
|
||||
|
||||
expert_backends = {name: _BloomBlockBackend(name, block, ..., self.attention_kv_cache) for name, block in expert_blocks.items()}
|
||||
self.dht, self.experts, self.update_period = dht, expert_backends, update_period
|
||||
self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(num_connection_handlers)]
|
||||
self.runtime = Runtime(self.experts, **kwargs)
|
||||
self.dht_handler_thread = DHTHandlerThread(self.experts, dht, update_period=update_period, daemon=True)
|
||||
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
|
||||
|
||||
if start:
|
||||
self.run_in_background(await_ready=True)
|
||||
|
||||
|
||||
class _BloomConnectionHandler(ConnectionHandler):
|
||||
"""Handles three request types: forward, backward and forward-incremental (inference)"""
|
||||
|
||||
async def rpc_forward_incremental(
|
||||
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
||||
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
|
||||
# encode expert_uid as @model_name[starting_layer:finishing_layer]
|
||||
# - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
|
||||
# - receive and maintain a handle for attention cache here
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class _BloomBlockBackend(ExpertBackend):
|
||||
def __init__(self, name: str, expert: nn.Module, *, attention_cache: AttentionCache, **kwargs):
|
||||
self.attention_cache = attention_cache
|
||||
super().__init__(name, expert, **kwargs)
|
||||
#TODO
|
||||
# BloomBackend serves a single layer
|
||||
# - ensure that parameters do not require grad!
|
||||
# - ensure that TaskPool for inference is NOT batched
|
||||
# - ensure that optimizer/scheduler is not created
|
||||
|
||||
def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
|
||||
with self.attention_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
|
||||
raise NotImplementedError("TODO")
|
||||
|
||||
# later:
|
||||
# - do not worry about OOM in cache for now! - just make sure that nothing except cache could oom.
|
||||
# - contiguous attention cache with max size
|
||||
# - select a subset of experts
|
||||
# - priorities
|
||||
# - option to backtrack a few tokens
|
||||
# - ensure that backprop is performed optimally, does not accumulate grads wrt parameters
|
||||
# - forget about length-adaptive forward/backward for now, use fixed length, maybe several fixed lengths - or better yet, forget finetuning for now
|
@ -0,0 +1,53 @@
|
||||
"""Code for serving bloom blocks via hivemind-server"""
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from hivemind import BatchTensorDescriptor
|
||||
from hivemind.moe.server.expert_backend import ExpertBackend
|
||||
from hivemind.moe.server.task_pool import TaskPool
|
||||
|
||||
from src.bloom.block import BloomBlock
|
||||
from src.server.cache import MemoryCache
|
||||
|
||||
|
||||
# TODO
|
||||
# BloomBackend serves a single layer
|
||||
# - ensure that parameters do not require grad!
|
||||
# - ensure that TaskPool for inference is NOT batched
|
||||
# - ensure that optimizer/scheduler is not created
|
||||
|
||||
HARDCODCED_LENGTH = 2048
|
||||
|
||||
|
||||
class BloomBlockBackend(ExpertBackend):
|
||||
"""A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
|
||||
def __init__(self, name: str, module: BloomBlock, *, memory_cache: MemoryCache, **kwargs):
|
||||
object().__init__() # to bypass super.__init__
|
||||
self.name, self.module = name, module
|
||||
self.memory_cache = memory_cache
|
||||
|
||||
for name, param in module.named_parameters():
|
||||
assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
|
||||
for name, buf in module.named_buffers():
|
||||
assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
|
||||
|
||||
self.args_schema = (BatchTensorDescriptor(HARDCODCED_LENGTH, module.hidden_size),)
|
||||
self.kwargs_schema = {}
|
||||
self.outputs_schema = (BatchTensorDescriptor(HARDCODCED_LENGTH, module.hidden_size),)
|
||||
|
||||
self.forward_schema = (self.args_schema, self.kwargs_schema) # inputs for forward
|
||||
self.backward_schema = (self.forward_schema, self.outputs_schema) # inputs to backward
|
||||
|
||||
self.grad_inputs_schema = self.forward_schema # outputs from backward have same shape as inputs for forward
|
||||
self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
|
||||
self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
|
||||
|
||||
@property
|
||||
def expert(self):
|
||||
#TODO un-hardcode this naming from hivemind
|
||||
return self.module
|
||||
|
||||
def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
|
||||
with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
|
||||
raise NotImplementedError("TODO")
|
||||
|
@ -0,0 +1,18 @@
|
||||
from typing import AsyncIterator
|
||||
|
||||
from hivemind import P2PContext
|
||||
from hivemind.moe.server.connection_handler import ConnectionHandler
|
||||
from hivemind.proto import runtime_pb2
|
||||
|
||||
|
||||
class BloomConnectionHandler(ConnectionHandler):
|
||||
"""Handles three request types: forward, backward and forward-incremental (inference)"""
|
||||
|
||||
async def rpc_forward_incremental(
|
||||
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
||||
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
|
||||
# encode expert_uid as @model_name[starting_layer:finishing_layer]
|
||||
# - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
|
||||
# - receive and maintain a handle for attention cache here
|
||||
|
||||
raise NotImplementedError()
|
@ -0,0 +1,108 @@
|
||||
import threading
|
||||
from typing import Optional, Dict, Union, Sequence
|
||||
|
||||
import torch
|
||||
from hivemind import Server, DHT
|
||||
from hivemind.moe.server.dht_handler import DHTHandlerThread
|
||||
from hivemind.moe.server.layers import add_custom_models_from_file
|
||||
from hivemind.moe.server.runtime import Runtime
|
||||
from hivemind.proto.runtime_pb2 import CompressionType
|
||||
from hivemind.utils.logging import use_hivemind_log_handler, get_logger
|
||||
|
||||
from src import DistributedBloomConfig
|
||||
from src.bloom.block import BloomBlock
|
||||
from src.server.cache import MemoryCache
|
||||
from src.server.backend import BloomBlockBackend
|
||||
from src.server.handler import BloomConnectionHandler
|
||||
|
||||
use_hivemind_log_handler("in_root_logger")
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class BloomServer(Server):
|
||||
"""Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
|
||||
def __init__(
|
||||
self, dht: DHT, module_backends: Dict[str, BloomBlockBackend], *,
|
||||
device: torch.device, num_connection_handlers: int = 8, update_period: float = 30,
|
||||
cache_size_bytes: Optional[int] = None, start: bool, **kwargs,
|
||||
):
|
||||
threading.Thread.__init__(self)
|
||||
self.attention_cache = MemoryCache(device=device, max_size_bytes=cache_size_bytes)
|
||||
|
||||
self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
|
||||
self.conn_handlers = [BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
|
||||
self.runtime = Runtime(self.module_backends, device=device, **kwargs)
|
||||
self.dht_handler_thread = DHTHandlerThread(self.experts, dht, update_period=update_period, daemon=True)
|
||||
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
|
||||
|
||||
if start:
|
||||
self.run_in_background(await_ready=True)
|
||||
|
||||
# noinspection PyMethodOverriding
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
num_blocks: int,
|
||||
block_config: str,
|
||||
num_handlers: Optional[int] = None,
|
||||
min_batch_size: int = 1,
|
||||
max_batch_size: int = 4096,
|
||||
cache_size_bytes: Optional[int] = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
initial_peers: Sequence[str] = (),
|
||||
compression=CompressionType.NONE,
|
||||
stats_report_interval: Optional[int] = None,
|
||||
custom_module_path=None,
|
||||
update_period: float = 30,
|
||||
expiration: Optional[float] = None,
|
||||
*,
|
||||
start: bool,
|
||||
**kwargs,
|
||||
) -> Server:
|
||||
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
|
||||
if custom_module_path is not None:
|
||||
add_custom_models_from_file(custom_module_path)
|
||||
|
||||
dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
|
||||
visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
|
||||
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
|
||||
|
||||
num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
|
||||
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if isinstance(block_config, str):
|
||||
block_config = DistributedBloomConfig
|
||||
|
||||
# initialize modules
|
||||
module_backends = {}
|
||||
for i in range(len(module_backends)):
|
||||
module_uid = f"dummy_block.{i}"
|
||||
block = BloomBlock(block_config, layer_number=i)
|
||||
#TODO run the actual model
|
||||
|
||||
module_backends[module_uid] = BloomBlockBackend(
|
||||
name=expert_uid,
|
||||
expert=block,
|
||||
args_schema=args_schema,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_total_steps=num_total_steps,
|
||||
clip_grad_norm=clip_grad_norm,
|
||||
min_batch_size=min_batch_size,
|
||||
max_batch_size=max_batch_size,
|
||||
)
|
||||
|
||||
if checkpoint_dir is not None:
|
||||
load_experts(experts, checkpoint_dir)
|
||||
|
||||
return cls(
|
||||
dht,
|
||||
experts,
|
||||
cache_size_bytes=cache_size_bytes,
|
||||
num_connection_handlers=num_handlers,
|
||||
device=device,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
stats_report_interval=stats_report_interval,
|
||||
update_period=update_period,
|
||||
expiration=expiration,
|
||||
start=start,
|
||||
)
|
||||
|
Loading…
Reference in New Issue