basic backend

justheuristic 2 years ago
parent 3215945882
commit 1c49bcb741

@ -0,0 +1,77 @@
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # add path to src
import configargparse
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src.server.server import BloomServer
logger = get_logger(__name__)
def main():
    """Parse server options from the command line / config file and run a BloomServer.

    Options are read with configargparse (``config.yml`` is used as a default config file
    when present). The ``config`` and ``increase_file_limit`` entries are consumed here;
    ``compression`` is converted from its string name to a ``CompressionType`` member; every
    remaining option is forwarded as a keyword argument to ``BloomServer.create``.
    """
    # fmt:off
    parser = configargparse.ArgParser(default_config_files=["config.yml"])
    parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
    parser.add_argument('--block_config', type=str, default='bigscience/bloom', help="name or path of model config")
    parser.add_argument('--num_blocks', type=int, default=1, help="The number of blocks to serve")
    # argparse applies `type` to every token separately, so `type=list` would turn each
    # multiaddr into a list of single characters; the addresses must stay plain strings.
    # NOTE(review): the default previously read '/ip4/', which is not a complete multiaddr;
    # restored to the "all IPv4 and TCP" form that the help text describes - please confirm.
    parser.add_argument('--host_maddrs', type=str, nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
                        help='Multiaddrs to listen for external connections from other p2p instances; '
                             'default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
    parser.add_argument('--announce_maddrs', type=str, nargs='+', default=None, required=False,
                        help='Visible multiaddrs the host announces for external connections from other p2p instances')
    parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
    parser.add_argument('--num_handlers', type=int, default=None, required=False,
                        help='server will use this many processes to handle incoming requests')
    parser.add_argument('--min_batch_size', type=int, default=1,
                        help='Minimum required batch size for all expert operations')
    parser.add_argument('--max_batch_size', type=int, default=16384,
                        help='The total number of examples in the same batch will not exceed this value')
    parser.add_argument('--cache_size_bytes', type=int, default=None,
                        help='The size of memory cache for storing past attention keys/values between inference steps')
    parser.add_argument('--device', type=str, default=None, required=False,
                        help='all experts will use this device in torch notation; default: cuda if available else cpu')
    parser.add_argument('--update_period', type=float, required=False, default=30,
                        help='Server will report experts to DHT once in this many seconds')
    parser.add_argument('--expiration', type=float, required=False, default=None,
                        help='DHT entries will expire after this many seconds')
    parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
                        help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
    parser.add_argument('--increase_file_limit', action='store_true',
                        help='On *nix, this will increase the max number of processes '
                             'a server can spawn before hitting "Too many open files"; Use at your own risk.')
    parser.add_argument('--stats_report_interval', type=int, required=False,
                        help='Interval between two reports of batch processing performance statistics')
    parser.add_argument('--custom_module_path', type=str, required=False,
                        help='Path of a file with custom nn.modules, wrapped into special decorator')
    # fmt:on

    args = vars(parser.parse_args())
    args.pop("config", None)  # only needed by configargparse itself; not a server option

    if args.pop("increase_file_limit"):
        increase_file_limit()

    # The CLI takes the compression as a string name; the server expects the enum member.
    compression_type = args.pop("compression")
    compression = getattr(CompressionType, compression_type)

    server = BloomServer.create(**args, start=True, compression=compression)

    # NOTE(review): the join/shutdown calls restore statements that were missing around the
    # surviving `except KeyboardInterrupt` clause - confirm against the Server API in use.
    try:
        server.join()
    except KeyboardInterrupt:
        logger.info("Caught KeyboardInterrupt, shutting down")
    finally:
        server.shutdown()


if __name__ == "__main__":
    main()

@ -1 +1 @@
from src.bloom.model import BloomModel, BloomForCausalLM, MemoryEfficientBloomConfig
from src.bloom.model import BloomModel, BloomForCausalLM, DistributedBloomConfig

@ -8,6 +8,7 @@ from typing import Tuple
import torch
import torch.utils.checkpoint
from hivemind import use_hivemind_log_handler
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from transformers.file_utils import (
@ -23,6 +24,7 @@ from transformers.utils import logging
from src.bloom.block import BloomBlock
from src.bloom.ops import build_alibi_tensor
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "bigscience/Bloom"
@ -30,7 +32,7 @@ _CONFIG_FOR_DOC = "MemoryEfficientBloomConfig"
_TOKENIZER_FOR_DOC = "BloomTokenizer"
class MemoryEfficientBloomConfig(_VanillaBloomConfig):
class DistributedBloomConfig(_VanillaBloomConfig):
compression: str = "none"
slow_but_exact: bool = False
@ -42,7 +44,7 @@ class BloomPreTrainedModel(PreTrainedModel):
config_class = MemoryEfficientBloomConfig
config_class = DistributedBloomConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["BloomBlock"]

@ -1,73 +0,0 @@
"""Code for serving bloom blocks via hivemind-server"""
import threading
from typing import AsyncIterator, Tuple, Optional
import torch
from hivemind import P2PContext, DHT
from import ConnectionHandler
from import DHTHandlerThread
from import ExpertBackend
from import Runtime
from import Server
from hivemind.proto import runtime_pb2
from torch import nn
from src.node.cache import AttentionCache
class BloomServer(Server):
    """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""

    def __init__(
        self, dht: DHT, device=torch.device, num_connection_handlers: int = 8, update_period: int = 30,
        attention_cache_size: Optional[int] = None, start=False, **kwargs,
    ):
        """
        :param dht: DHT instance shared with the connection handlers and the DHT handler thread
        :param device: passed to the attention cache as its device
        :param num_connection_handlers: how many ConnectionHandler instances to create
        :param update_period: forwarded to DHTHandlerThread as update_period
        :param attention_cache_size: first argument of AttentionCache
        :param start: if True, start the server right after construction
        :param kwargs: forwarded to Runtime
        """
        # NOTE(review): the default for `device` is the torch.device *class*, which looks like a
        # misplaced annotation; the signature is kept unchanged for callers - confirm and fix.
        # NOTE(review): assumes the parent constructor is bypassed on purpose and only the
        # thread machinery is initialised (the file imports threading for nothing else) - confirm.
        threading.Thread.__init__(self)

        self.attention_cache = AttentionCache(attention_cache_size, dtype=torch.bfloat16, device=device)

        # NOTE(review): LOAD_BLOOM_LAYERS_HERE is a placeholder that is not defined in the
        # visible code; loading of the actual layers still has to be implemented.
        expert_blocks = dict(LOAD_BLOOM_LAYERS_HERE)
        # The cache is stored as `self.attention_cache` (not `attention_kv_cache`), and
        # _BloomBlockBackend accepts it as a keyword-only argument.
        expert_backends = {
            name: _BloomBlockBackend(name, block, attention_cache=self.attention_cache)
            for name, block in expert_blocks.items()
        }

        self.dht, self.experts, self.update_period = dht, expert_backends, update_period
        self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(num_connection_handlers)]
        self.runtime = Runtime(self.experts, **kwargs)
        self.dht_handler_thread = DHTHandlerThread(self.experts, dht, update_period=update_period, daemon=True)
        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state

        if start:
            # NOTE(review): restored body of the empty `if start:` - confirm the intended call.
            self.run_in_background(await_ready=True)
class _BloomConnectionHandler(ConnectionHandler):
    """Handles three request types: forward, backward and forward-incremental (inference)"""

    async def rpc_forward_incremental(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Serve a stream of incremental-forward (inference) requests; not implemented yet.

        :param requests: asynchronous stream of incoming ExpertRequest messages
        :param context: P2P context of the call
        :returns: asynchronous stream of responses (annotated as ExpertResponse: the handler
            answers requests rather than echoing them)
        :raises NotImplementedError: always, until the plan below is implemented
        """
        # encode expert_uid as @model_name[starting_layer:finishing_layer]
        # - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
        # - receive and maintain a handle for attention cache here
        raise NotImplementedError()
class _BloomBlockBackend(ExpertBackend):
    """ExpertBackend subclass that additionally holds a reference to an attention cache."""

    def __init__(self, name: str, expert: nn.Module, *, attention_cache: AttentionCache, **kwargs):
        """Store the attention cache, then initialise the parent backend.

        :param name: forwarded to ExpertBackend
        :param expert: forwarded to ExpertBackend
        :param attention_cache: cache object later used by forward_incremental
        :param kwargs: forwarded to ExpertBackend
        """
        self.attention_cache = attention_cache
        super().__init__(name, expert, **kwargs)
        # BloomBackend serves a single layer
        # - ensure that parameters do not require grad!
        # - ensure that TaskPool for inference is NOT batched
        # - ensure that optimizer/scheduler is not created

    def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
        """Incremental forward pass using the cache entry behind the given handle; not implemented yet.

        :raises NotImplementedError: always, once the cache entry has been acquired
        """
        cache = self.attention_cache
        with cache.use_cache(attention_cache_handle) as (prefix_length, past_keys, past_values):
            raise NotImplementedError("TODO")
# later:
# - do not worry about OOM in cache for now! - just make sure that nothing except cache could oom.
# - contiguous attention cache with max size
# - select a subset of experts
# - priorities
# - option to backtrack a few tokens
# - ensure that backprop is performed optimally, does not accumulate grads wrt parameters
# - forget about length-adaptive forward/backward for now, use fixed length, maybe several fixed lengths - or better yet, forget finetuning for now

@ -0,0 +1,53 @@
"""Code for serving bloom blocks via hivemind-server"""
from typing import Tuple
import torch
from hivemind import BatchTensorDescriptor
from import ExpertBackend
from import TaskPool
from src.bloom.block import BloomBlock
from src.server.cache import MemoryCache
# BloomBackend serves a single layer
# - ensure that parameters do not require grad!
# - ensure that TaskPool for inference is NOT batched
# - ensure that optimizer/scheduler is not created
class BloomBlockBackend(ExpertBackend):
    """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""

    def __init__(self, name: str, module: BloomBlock, *, memory_cache: MemoryCache, **kwargs):
        """
        :param name: identifier of this backend; also used as the prefix of its task pool names
        :param module: the wrapped block; none of its parameters or buffers may require grad
        :param memory_cache: cache object used by forward_incremental
        :param kwargs: forwarded to both TaskPool constructors
        """
        object().__init__()  # to bypass super.__init__
        self.name, self.module = name, module
        self.memory_cache = memory_cache

        # Loop variables are named differently from the `name` argument so that it is not shadowed.
        for param_name, param in module.named_parameters():
            assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {param_name} does"
        for buffer_name, buf in module.named_buffers():
            assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {buffer_name} does"

        # NOTE(review): HARDCODCED_LENGTH is not defined anywhere in the visible code (and is
        # misspelled); it must be defined or replaced before this class can be instantiated.
        self.args_schema = (BatchTensorDescriptor(HARDCODCED_LENGTH, module.hidden_size),)
        self.kwargs_schema = {}
        self.outputs_schema = (BatchTensorDescriptor(HARDCODCED_LENGTH, module.hidden_size),)

        self.forward_schema = (self.args_schema, self.kwargs_schema)  # inputs for forward
        self.backward_schema = (self.forward_schema, self.outputs_schema)  # inputs to backward
        self.grad_inputs_schema = self.forward_schema  # outputs from backward have same shape as inputs for forward

        self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
        self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs)

    @property
    def expert(self):
        """The wrapped module, exposed under the `expert` name."""
        # TODO un-hardcode this naming from hivemind
        # NOTE(review): exposed as a property on the assumption that inherited code reads
        # `self.expert` as an attribute - confirm.
        return self.module

    def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
        """Incremental forward pass using the cache entry behind the given handle; not implemented yet.

        :raises NotImplementedError: always, once the cache entry has been acquired
        """
        with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
            raise NotImplementedError("TODO")

@ -5,6 +5,7 @@ For now, the only purpose of this code is to ensure that allocated memory will b
TODO In future, one could modify cache to implement, among other things,
- in allocate_cache, if there is not enough memory, wait for memory to be freed by existing tasks up to a given timeout.
-- note: this can be done using mp.Condition
- allocate cache as one contiguous buffer to avoid fragmentation
- quantize cached values using bitsandbytes
- LRU offloading from gpu to ram
@ -18,9 +19,11 @@ from typing import Dict, Optional, Union
import hivemind
import torch
from hivemind import use_hivemind_log_handler
from hivemind.utils import TensorDescriptor, get_logger
logger = get_logger(__file__)
logger = get_logger(__name__)
Handle = int

@ -0,0 +1,18 @@
from typing import AsyncIterator
from hivemind import P2PContext
from import ConnectionHandler
from hivemind.proto import runtime_pb2
class BloomConnectionHandler(ConnectionHandler):
    """Handles three request types: forward, backward and forward-incremental (inference)"""

    async def rpc_forward_incremental(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Serve a stream of incremental-forward (inference) requests; not implemented yet.

        :param requests: asynchronous stream of incoming ExpertRequest messages
        :param context: P2P context of the call
        :returns: asynchronous stream of responses (annotated as ExpertResponse: the handler
            answers requests rather than echoing them)
        :raises NotImplementedError: always, until the plan below is implemented
        """
        # encode expert_uid as @model_name[starting_layer:finishing_layer]
        # - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
        # - receive and maintain a handle for attention cache here
        raise NotImplementedError()

@ -0,0 +1,108 @@
import threading
from typing import Optional, Dict, Union, Sequence
import torch
from hivemind import Server, DHT
from import DHTHandlerThread
from import add_custom_models_from_file
from import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import use_hivemind_log_handler, get_logger
from src import DistributedBloomConfig
from src.bloom.block import BloomBlock
from src.server.cache import MemoryCache
from src.server.backend import BloomBlockBackend
from src.server.handler import BloomConnectionHandler
# Name the logger after the module (not the file path), as the other modules of this project do.
logger = get_logger(__name__)
class BloomServer(Server):
"""Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
def __init__(
    self, dht: DHT, module_backends: Dict[str, BloomBlockBackend], *,
    device: torch.device, num_connection_handlers: int = 8, update_period: float = 30,
    cache_size_bytes: Optional[int] = None, start: bool, **kwargs,
):
    """
    :param dht: DHT instance shared with the connection handlers and the DHT handler thread
    :param module_backends: backends to serve, keyed by module uid
    :param device: device for the memory cache and the Runtime
    :param num_connection_handlers: how many BloomConnectionHandler instances to create
    :param update_period: forwarded to DHTHandlerThread as update_period
    :param cache_size_bytes: forwarded to MemoryCache as max_size_bytes
    :param start: if True, start the server right after construction
    :param kwargs: forwarded to Runtime
    """
    # NOTE(review): assumes the parent constructor is bypassed on purpose and only the
    # thread machinery is initialised (the file imports threading for nothing else) - confirm.
    threading.Thread.__init__(self)

    self.attention_cache = MemoryCache(device=device, max_size_bytes=cache_size_bytes)
    self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
    self.conn_handlers = [BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
    self.runtime = Runtime(self.module_backends, device=device, **kwargs)
    # This constructor never assigns `self.experts`; the backends live in `self.module_backends`.
    # NOTE(review): inherited Server methods may still expect `self.experts` - verify.
    self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period=update_period, daemon=True)
    self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state

    if start:
        # NOTE(review): restored body of the empty `if start:` - confirm the intended call.
        self.run_in_background(await_ready=True)
# noinspection PyMethodOverriding
def create(
num_blocks: int,
block_config: str,
num_handlers: Optional[int] = None,
min_batch_size: int = 1,
max_batch_size: int = 4096,
cache_size_bytes: Optional[int] = None,
device: Union[str, torch.device] = None,
initial_peers: Sequence[str] = (),
stats_report_interval: Optional[int] = None,
update_period: float = 30,
expiration: Optional[float] = None,
start: bool,
) -> Server:
"""Create a server with one or more bloom blocks. See for documentation."""
if custom_module_path is not None:
dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
if isinstance(block_config, str):
block_config = DistributedBloomConfig
# initialize modules
module_backends = {}
for i in range(len(module_backends)):
module_uid = f"dummy_block.{i}"
block = BloomBlock(block_config, layer_number=i)
#TODO run the actual model
module_backends[module_uid] = BloomBlockBackend(
if checkpoint_dir is not None:
load_experts(experts, checkpoint_dir)
return cls(