diff --git a/README.md b/README.md index b59606d..0e5547b 100644 --- a/README.md +++ b/README.md @@ -37,18 +37,18 @@ Then open a python notebook or console and run: ```python import torch import hivemind -from src import get_remote_module +from src import DistributedBloomConfig, get_remote_module dht = hivemind.DHT( initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS], # e.g. /ip4/127.0.0.1/... client_mode=True, start=True, ) - -layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4']) +config = DistributedBloomConfig.from_pretrained("bigscience/test-bloom-6b3") +layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'], config) assert layer3 is not None and layer4 is not None, "one or both layers were not found in DHT" # test forward/backward, two blocks -outputs, = layer4(*layer3(torch.randn(1, 64, 4096))) +outputs = layer4(layer3(torch.randn(1, 64, 4096))) loss = (outputs * torch.randn_like(outputs)).norm() loss.backward() @@ -74,18 +74,18 @@ python -m cli.convert_model --model bigscience/bloom-6b3 \ To test distributed inference, run one or more servers, then open a new shell and run pytest with environment variables: ```bash -# shell A: serve blocks 3 and 4 +# shell A: serve model python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 \ - --block_indices 3:5 --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337 + --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337 -# shell B: connect to the swarm and test individual blocks for exact match -export PYTHONPATH=. INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT" -BLOCK_UID=bigscience/test-bloomd-6b3.3 pytest tests/test_block_exact_match.py -BLOCK_UID=bigscience/test-bloomd-6b3.4 pytest tests/test_block_exact_match.py +# shell B: +export PYTHONPATH=. +export INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT" +export MODEL_NAME="bigscience/test-bloomd-6b3" -# the test below will fail because there is no server that serves layer 7 -# BLOCK_UID=bigscience/test-bloomd-6b3.7 pytest tests/test_block_exact_match.py +# test individual random blocks for exact match +pytest tests/test_block_exact_match.py -# test the full model (requires that servers collectively serve all model layers) -REF_NAME=bigscience/bloom-6b3 pytest tests/test_full_model.py +# test the full model +pytest tests/test_full_model.py ``` diff --git a/src/client/__init__.py b/src/client/__init__.py index 0335921..165de67 100644 --- a/src/client/__init__.py +++ b/src/client/__init__.py @@ -1,5 +1,4 @@ from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession -from src.client.remote_block import RemoteTransformerBlock from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel -from src.client.remote_sequential import RemoteSequential +from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock from src.client.sequence_manager import RemoteSequenceManager diff --git a/src/client/remote_block.py b/src/client/remote_block.py deleted file mode 100644 index 7d0f920..0000000 --- a/src/client/remote_block.py +++ /dev/null @@ -1,40 +0,0 @@ -# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me. 
-from __future__ import annotations - -import random - -import torch -from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker -from hivemind.moe.expert_uid import ExpertInfo -from hivemind.p2p import P2P, StubBase -from hivemind.utils import get_logger, use_hivemind_log_handler - -from src.client.inference_session import RemoteTransformerBlockInferenceSession -from src.data_structures import RemoteModuleInfo -from src.server.handler import TransformerConnectionHandler - -use_hivemind_log_handler("in_root_logger") -logger = get_logger(__file__) - - -class RemoteTransformerBlock(RemoteExpert): - """A class that interacts with a remote module on a specific server for forward/backward or inference""" - - def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P): - peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers.keys()))) # TODO replace this - super().__init__(peer_info, p2p) - - @property - def stub(self) -> StubBase: - return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id) - - def forward(self, inputs: torch.Tensor, **kwargs): - for k, v in kwargs.items(): - assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})" - return super().forward(inputs) - - def inference_session(self, **kwargs) -> RemoteTransformerBlockInferenceSession: - """Initialize a new inference session with the specified remote server""" - return RemoteExpertWorker.run_coroutine( - RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info, **kwargs) - ) diff --git a/src/client/remote_sequential.py b/src/client/remote_sequential.py index 86cca85..d9e63b2 100644 --- a/src/client/remote_sequential.py +++ b/src/client/remote_sequential.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging from typing import Optional, Union import torch @@ -10,11 +9,9 @@ from torch import nn import src from src.client.inference_session import RemoteSequentialInferenceSession -from src.client.remote_block import RemoteTransformerBlock from src.client.sequence_manager import RemoteSequenceManager from src.client.sequential_autograd import _RemoteSequentialAutogradFunction from src.data_structures import UID_DELIMITER -from src.dht_utils import _create_remote_modules_from_infos from src.utils.misc import DUMMY use_hivemind_log_handler("in_root_logger") @@ -57,12 +54,16 @@ class RemoteSequential(nn.Module): outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager) return outputs - def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]: + def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential: assert isinstance(ix, (int, slice)) if isinstance(ix, int): - assert 0 <= ix < len(self) - (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p) - return module + return RemoteTransformerBlock( + self.config, + self.dht, + dht_prefix=self.dht_prefix, + p2p=self.p2p, + sequence_manager=self.sequence_manager[ix], + ) else: return RemoteSequential( self.config, @@ -85,3 +86,18 @@ class RemoteSequential(nn.Module): def extra_repr(self) -> str: return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}" + + +class RemoteTransformerBlock(RemoteSequential): + """Single transformer block hosted by swarm + + This class is deprecated and kept for backward compatibility. + It will be removed soon in favor of using ``RemoteSequential`` directly. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert len(self) == 1, "Remote Block is a sequence size 1" + + def extra_repr(self): + return f"{self.sequence_manager.block_uids[0]}" diff --git a/src/client/sequence_manager.py b/src/client/sequence_manager.py index 777f070..af552dd 100644 --- a/src/client/sequence_manager.py +++ b/src/client/sequence_manager.py @@ -82,6 +82,7 @@ class RemoteSequenceManager: for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)): if info is None: logger.warning(f"Found no block info for block {uid}") + continue if not isinstance(info, RemoteModuleInfo): logger.warning(f"Unexpected dht entry type for {uid}: {info}") if not info.servers: @@ -95,22 +96,24 @@ class RemoteSequenceManager: closed_spans = [] active_spans = {} for block_index, info in enumerate(block_infos): - for peer_id, server in info.servers.items(): - if server.state != ServerState.ONLINE: - continue - if peer_id not in active_spans: - active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id) - else: # peer_id in active_spans - active_spans[peer_id].end = block_index + 1 + if info is not None: + for peer_id, server in info.servers.items(): + if server.state != ServerState.ONLINE: + continue + if peer_id not in active_spans: + active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id) + else: # peer_id in active_spans + active_spans[peer_id].end = block_index + 1 for peer_id in list(active_spans.keys()): if ( - peer_id not in info.servers + info is None + or peer_id not in info.servers or info.servers[peer_id].state != ServerState.ONLINE or block_index == len(block_infos) - 1 ): closed_spans.append(active_spans.pop(peer_id)) - assert not active_spans + assert not active_spans, f"spans: {active_spans}" closed_spans.sort(key=lambda span: span.end - span.start, reverse=True) diff --git a/src/client/sequential_autograd.py b/src/client/sequential_autograd.py index 081194c..1498236 100644 --- a/src/client/sequential_autograd.py +++ b/src/client/sequential_autograd.py @@ -110,7 +110,7 @@ async def sequential_forward( If some subsequence fails, reconstructs the remaining path and tries to finish the forward. 
""" - assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 + assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}" end_index = end_index if end_index is not None else len(sequence_manager.block_uids) assert start_index >= 0 and end_index <= len(sequence_manager.block_uids) diff --git a/src/dht_utils.py b/src/dht_utils.py index fe5df32..78ef083 100644 --- a/src/dht_utils.py +++ b/src/dht_utils.py @@ -9,7 +9,7 @@ from typing import Dict, List, Optional, Sequence, Union from hivemind.dht import DHT, DHTNode, DHTValue from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker -from hivemind.p2p import P2P, PeerID +from hivemind.p2p import PeerID from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler import src @@ -72,34 +72,63 @@ async def _declare_active_modules( ) +def get_remote_sequence( + dht: DHT, + start: int, + stop: int, + config: src.DistributedBloomConfig, + dht_prefix: Optional[str] = None, + return_future: bool = False, +) -> Union[src.RemoteSequential, MPFuture]: + return RemoteExpertWorker.run_coroutine( + _get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future + ) + + +async def _get_remote_sequence( + dht: DHT, + start: int, + stop: int, + config: src.DistributedBloomConfig, + dht_prefix: Optional[str] = None, +) -> src.RemoteSequential: + uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)] + p2p = await dht.replicate_p2p() + manager = src.RemoteSequenceManager(dht, uids, p2p) + return src.RemoteSequential(config, dht, dht_prefix, p2p, manager) + + def get_remote_module( dht: DHT, uid_or_uids: Union[ModuleUID, List[ModuleUID]], - expiration_time: Optional[DHTExpiration] = None, + config: src.DistributedBloomConfig, + dht_prefix: Optional[str] = None, return_future: bool = False, -) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]: +) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]: """ :param uid_or_uids: find one or more modules with these ids from across the DHT - :param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time) + :param config: model config, usualy taken by .from_pretrained(MODEL_NAME) :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background. 
- :returns: a list of [RemoteTransformerBlock if found else None] + :returns: a list of [RemoteTransformerBlock] """ - single_uid = isinstance(uid_or_uids, ModuleUID) - uids = [uid_or_uids] if single_uid else uid_or_uids - infos = dht.run_coroutine( - partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future + return RemoteExpertWorker.run_coroutine( + _get_remote_module(dht, uid_or_uids, config, dht_prefix), return_future=return_future ) - if return_future: - - async def _unpack(infos_future: MPFuture, dht: DHT): - p2p = await dht.replicate_p2p() - modules = _create_remote_modules_from_infos(await infos_future, p2p) - return modules[0] if single_uid else modules - return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future) - p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) - modules = _create_remote_modules_from_infos(infos, p2p) +async def _get_remote_module( + dht: DHT, + uid_or_uids: Union[ModuleUID, List[ModuleUID]], + config: src.DistributedBloomConfig, + dht_prefix: Optional[str] = None, +) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]: + single_uid = isinstance(uid_or_uids, ModuleUID) + uids = [uid_or_uids] if single_uid else uid_or_uids + p2p = await dht.replicate_p2p() + managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids) + modules = [ + src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers + ] return modules[0] if single_uid else modules @@ -149,15 +178,3 @@ async def _get_remote_module_infos( if servers: modules[i] = RemoteModuleInfo(uid, servers) return modules - - -def _create_remote_modules_from_infos( - infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P -) -> List[Optional[src.RemoteTransformerBlock]]: - modules: List[Optional[src.RemoteTransformerBlock]] = [] - for info in infos: - if info is not None: - modules.append(src.RemoteTransformerBlock(info, p2p)) - else: - modules.append(None) - return modules diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py index 4761aea..fad84ae 100644 --- a/tests/test_block_exact_match.py +++ b/tests/test_block_exact_match.py @@ -7,8 +7,10 @@ import transformers from hivemind import P2PHandlerError from test_utils import * +import src +from src import DistributedBloomConfig from src.bloom.from_pretrained import load_pretrained_block -from src.client.remote_block import RemoteTransformerBlock +from src.client.remote_sequential import RemoteTransformerBlock from src.data_structures import UID_DELIMITER from src.dht_utils import get_remote_module @@ -16,16 +18,14 @@ from src.dht_utils import get_remote_module @pytest.mark.forked def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3): dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) - config = transformers.AutoConfig.from_pretrained(MODEL_NAME) + config = DistributedBloomConfig.from_pretrained(MODEL_NAME) for block_index in random.sample(range(config.n_layer), 3): - block_uid = f"{MODEL_NAME}{UID_DELIMITER}{block_index}" - remote_block = get_remote_module(dht, block_uid) - assert remote_block is not None, f"Could not find {block_uid} in DHT" + remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config) assert isinstance(remote_block, RemoteTransformerBlock) inputs = torch.randn(1, 8, config.hidden_size) - (outputs_forward,) = remote_block(inputs) + outputs_forward = remote_block(inputs) outputs_inference = [] with 
remote_block.inference_session(max_length=inputs.shape[1]) as sess: diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py index 8148286..7cf6d44 100644 --- a/tests/test_chained_calls.py +++ b/tests/test_chained_calls.py @@ -7,25 +7,20 @@ import hivemind import pytest import torch -import transformers -from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo from test_utils import * +import src from src.bloom.from_pretrained import load_pretrained_block -from src.client.remote_block import RemoteTransformerBlock -from src.dht_utils import get_remote_module +from src.client.remote_sequential import RemoteSequential +from src.dht_utils import get_remote_sequence @pytest.mark.forked def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1): dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) - config = transformers.AutoConfig.from_pretrained(MODEL_NAME) - remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0") - assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT" - assert isinstance(remote_block, RemoteTransformerBlock) - - _ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by chaning _info - remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id) + config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME) + remote_blocks = get_remote_sequence(dht, 3, 6, config) + assert isinstance(remote_blocks, RemoteSequential) ref_blocks = [ load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32), @@ -33,7 +28,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32), ] inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True) - outputs_rpc = remote_block.forward(inputs)[0] + outputs_rpc = remote_blocks.forward(inputs) outputs_rpc.sum().backward() grads_rpc = inputs.grad @@ -52,18 +47,14 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq @pytest.mark.forked def test_chained_inference_exact_match(atol_inference=1e-4): dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) - config = transformers.AutoConfig.from_pretrained(MODEL_NAME) - remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0") - assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT" - assert isinstance(remote_block, RemoteTransformerBlock) - - _ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by chaning _info - remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id) + config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME) + remote_blocks = get_remote_sequence(dht, 3, 5, config) + assert isinstance(remote_blocks, RemoteSequential) inputs = torch.randn(1, 8, config.hidden_size) outputs_inference = [] - with remote_block.inference_session(max_length=inputs.shape[1]) as sess: + with remote_blocks.inference_session(max_length=inputs.shape[1]) as sess: for i in range(inputs.shape[1]): outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) outputs_inference = torch.cat(outputs_inference, dim=1)
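
A quick way to exercise the refactored client API end to end is sketched below; it combines the pieces this patch introduces (`get_remote_sequence`, the new `config` argument to `get_remote_module`, and the tensor-valued `forward`). It is a usage sketch, not part of the patch itself: it assumes a local swarm started as in the README, the `initial_peers` address is a placeholder to copy from a server's log, the model names follow the README (`bigscience/test-bloom-6b3` for the config, `bigscience/test-bloomd-6b3` for the served blocks), and the tolerance is illustrative rather than normative.

```python
# Usage sketch (assumes a running local swarm; replace the initial_peers
# address with the full multiaddress printed by one of your servers).
import hivemind
import torch

from src import DistributedBloomConfig
from src.dht_utils import get_remote_module, get_remote_sequence

dht = hivemind.DHT(
    initial_peers=["/ip4/127.0.0.1/..."],  # placeholder, copy from server output
    client_mode=True,
    start=True,
)
config = DistributedBloomConfig.from_pretrained("bigscience/test-bloom-6b3")

# Single block: get_remote_module now takes the model config and returns a
# (deprecated) RemoteTransformerBlock whose forward returns a plain tensor.
block = get_remote_module(dht, "bigscience/test-bloomd-6b3.3", config)
hidden = block(torch.randn(1, 8, config.hidden_size))

# Contiguous span of blocks [3, 6), returned as a RemoteSequential.
blocks = get_remote_sequence(dht, 3, 6, config)

inputs = torch.randn(1, 8, config.hidden_size)
outputs_forward = blocks(inputs)

# Token-by-token inference over the same span should closely match forward
# (compare tests/test_block_exact_match.py and tests/test_chained_calls.py).
with blocks.inference_session(max_length=inputs.shape[1]) as sess:
    steps = [sess.step(inputs[:, i : i + 1, :]) for i in range(inputs.shape[1])]
outputs_inference = torch.cat(steps, dim=1)
assert torch.allclose(outputs_forward, outputs_inference, atol=1e-3)  # illustrative tolerance
```

After this change, `RemoteTransformerBlock` is just a `RemoteSequential` of length one kept for backward compatibility, so indexing a `RemoteSequential` (e.g. `blocks[0]`) yields the same kind of object as calling `get_remote_module` with a single UID.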