prototype remote sequential

This commit is contained in:
justheuristic 2022-06-29 10:45:08 +03:00
parent 6e3db6bed6
commit 5849cea28c
2 changed files with 94 additions and 0 deletions

View File

@@ -0,0 +1,28 @@
# this code is in active development, interfaces may change
from hivemind import DHT, use_hivemind_log_handler, get_logger

from src.bloom import DistributedBloomConfig, BloomForCausalLM
from src.client.remote_sequential import RemoteSequential
from src.data_structures import UID_DELIMITER

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class DistributedBloomForCausalLM(BloomForCausalLM):
    """BloomForCausalLM, but all transformer layers are hosted by the swarm"""

    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
        if prefix.endswith(UID_DELIMITER):
            logger.warning(
                f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'. "
                f"This will cause {self.__class__.__name__} to look for modules under "
                f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
            )

        # temporarily set n_layer to 0 so that the parent constructor builds no local transformer layers
        n_layer, config.n_layer = config.n_layer, 0
        super().__init__(config)
        assert len(self.transformer.h) == 0
        config.n_layer = n_layer

        # replace the (empty) local block list with a sequence of swarm-hosted blocks
        self.transformer.h = RemoteSequential(config, dht, prefix)
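
For context, a minimal sketch of how this class might be instantiated. Everything not in the diff is an assumption: the bootstrap multiaddress and the dht_prefix value are placeholders, constructing DistributedBloomConfig with defaults is assumed, and the import path src.client.remote_model is inferred from the logger import in the second file.

# hypothetical usage sketch; peer address, config defaults and prefix are placeholders
import hivemind

from src.bloom import DistributedBloomConfig
from src.client.remote_model import DistributedBloomForCausalLM  # inferred module path

dht = hivemind.DHT(
    initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/QmBootstrapPeer"],  # placeholder bootstrap peer
    client_mode=True,
    start=True,
)
config = DistributedBloomConfig()  # assumes defaults that match the blocks served by the swarm
model = DistributedBloomForCausalLM(config, dht, prefix="bloom")  # placeholder dht_prefix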

View File

@@ -0,0 +1,66 @@
import logging
from functools import partial
from typing import Tuple

import torch
from hivemind import DHT
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from torch import nn

from src import DistributedBloomConfig
from src.client.remote_model import logger
from src.data_structures import UID_DELIMITER, RemoteModuleInfo
from src.dht_utils import _get_remote_module_infos, _create_remote_modules_from_infos


class RemoteSequential(nn.Sequential):
    """A sequence of transformer blocks hosted by the swarm"""

    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
        super().__init__()
        self.config = config
        self.dht = dht
        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())

        self.prefix = prefix
        self.block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
        logger.debug(f"Remote block uids: {self.block_uids}")

        # look up which peers currently serve each block uid
        self.block_infos: Tuple[RemoteModuleInfo, ...] = tuple(
            dht.run_coroutine(
                partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
                return_future=False,
            )
        )

        self.max_retries = max_retries
        assert len(self.block_infos) == len(self.block_uids)
        for uid, info in zip(self.block_uids, self.block_infos):
            assert isinstance(info, (type(None), RemoteModuleInfo)), f"Unexpected dht entry for {uid}: {info}"
            assert info is not None, f"Found no active peers for block {uid}"
            assert isinstance(info.peer_ids, (list, tuple)), f"expected peer_ids to be list/tuple, got {info.peer_ids}"
            assert info.uid == uid, f"The DHT entry for {uid} actually points to {info.uid}"
            assert len(info.peer_ids) > 0, f"Found no active peers for block {uid}"

    def forward(self, inputs: torch.Tensor):
        assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
        for block_index in range(self.config.n_layer):
            for retry_index in range(self.max_retries):
                try:
                    block = self[block_index]
                    (outputs,) = block(inputs)
                    assert isinstance(outputs, torch.Tensor)
                    assert outputs.shape == inputs.shape, (
                        f"Expected output of {block} to have shape {inputs.shape}, got {outputs.shape}"
                    )
                    inputs = outputs
                    break
                except Exception as e:
                    if retry_index == self.max_retries - 1:
                        raise e
                    else:
                        # the next attempt re-creates the remote block via self[block_index]
                        logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
        return inputs

    def __getitem__(self, block_index: int):
        assert 0 <= block_index < self.config.n_layer
        # create a fresh remote module stub for the requested block
        (module,) = _create_remote_modules_from_infos([self.block_infos[block_index]], self.p2p)
        return module

    def __iter__(self):
        for block_index in range(self.config.n_layer):
            yield self[block_index]
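
And a minimal sketch of driving RemoteSequential directly, e.g. for a single forward pass or to inspect one block. It reuses the placeholder dht, config and prefix from the sketch after the first file, and reads n_embed from the config exactly as forward() does above; the batch and sequence sizes are arbitrary.

# hypothetical usage sketch; dht, config and prefix are the placeholders from the sketch above
import torch

sequential = RemoteSequential(config, dht, prefix="bloom")  # placeholder dht_prefix
dummy_inputs = torch.randn(1, 8, config.n_embed)  # [batch, seq_length, hidden_size]
with torch.no_grad():
    outputs = sequential(dummy_inputs)  # runs each remote block in turn, retrying failed blocks
assert outputs.shape == dummy_inputs.shape

first_block = sequential[0]  # a freshly created stub for the first remote block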